From 4f6971e646d4650ebf77e28a000159f37ed01949 Mon Sep 17 00:00:00 2001 From: prabhu Date: Fri, 10 Nov 2023 08:59:06 +0000 Subject: [PATCH] Convert to scala 3 format (#32) Signed-off-by: Prabhu Subramanian --- .scalafmt.conf | 13 +- build.sbt | 2 +- codemeta.json | 2 +- .../io/appthreat/console/BridgeBase.scala | 693 +- .../scala/io/appthreat/console/Commit.scala | 36 +- .../scala/io/appthreat/console/Console.scala | 1241 ++-- .../io/appthreat/console/ConsoleConfig.scala | 106 +- .../io/appthreat/console/CpgConverter.scala | 14 +- .../scala/io/appthreat/console/Help.scala | 83 +- .../scala/io/appthreat/console/JProduct.scala | 6 +- .../io/appthreat/console/PluginManager.scala | 161 +- .../io/appthreat/console/Reporting.scala | 55 +- .../main/scala/io/appthreat/console/Run.scala | 115 +- .../console/cpgcreation/AtomGenerator.scala | 45 +- .../console/cpgcreation/CCpgGenerator.scala | 28 +- .../console/cpgcreation/CdxGenerator.scala | 36 +- .../console/cpgcreation/CpgGenerator.scala | 77 +- .../cpgcreation/CpgGeneratorFactory.scala | 144 +- .../console/cpgcreation/ImportCode.scala | 356 +- .../cpgcreation/JavaSrcCpgGenerator.scala | 38 +- .../console/cpgcreation/JsCpgGenerator.scala | 24 +- .../cpgcreation/JsSrcCpgGenerator.scala | 36 +- .../cpgcreation/PythonSrcCpgGenerator.scala | 75 +- .../console/cpgcreation/package.scala | 193 +- .../scala/io/appthreat/console/package.scala | 73 +- .../console/workspacehandling/Project.scala | 77 +- .../console/workspacehandling/Workspace.scala | 43 +- .../workspacehandling/WorkspaceLoader.scala | 84 +- .../workspacehandling/WorkspaceManager.scala | 749 +- .../dataflowengineoss/DefaultSemantics.scala | 308 +- .../dotgenerator/DdgGenerator.scala | 204 +- .../dotgenerator/DotCpg14Generator.scala | 30 +- .../dotgenerator/DotDdgGenerator.scala | 19 +- .../dotgenerator/DotPdgGenerator.scala | 19 +- .../language/ExtendedCfgNode.scala | 177 +- .../dataflowengineoss/language/Path.scala | 395 +- .../language/dotextension/DdgNodeDot.scala 
| 41 +- .../nodemethods/ExpressionMethods.scala | 175 +- .../nodemethods/ExtendedCfgNodeMethods.scala | 141 +- .../dataflowengineoss/language/package.scala | 33 +- .../layers/dataflows/DumpCpg14.scala | 32 +- .../layers/dataflows/DumpDdg.scala | 32 +- .../layers/dataflows/DumpPdg.scala | 32 +- .../layers/dataflows/OssDataFlow.scala | 29 +- .../appthreat/dataflowengineoss/package.scala | 25 +- .../passes/reachingdef/DataFlowProblem.scala | 55 +- .../passes/reachingdef/DataFlowSolver.scala | 135 +- .../passes/reachingdef/DdgGenerator.scala | 697 +- .../passes/reachingdef/EdgeValidator.scala | 88 +- .../passes/reachingdef/ReachingDefPass.scala | 83 +- .../reachingdef/ReachingDefProblem.scala | 670 +- .../passes/reachingdef/package.scala | 5 +- .../queryengine/AccessPathUsage.scala | 75 +- .../queryengine/Engine.scala | 595 +- .../queryengine/HeldTaskCompletion.scala | 314 +- .../queryengine/SourcesToStartingPoints.scala | 401 +- .../queryengine/TaskCreator.scala | 376 +- .../queryengine/TaskSolver.scala | 404 +- .../queryengine/package.scala | 200 +- .../semanticsloader/Parser.scala | 295 +- .../scala/io/appthreat/console/Query.scala | 60 +- .../io/appthreat/console/QueryDatabase.scala | 151 +- .../io/appthreat/macros/QueryMacros.scala | 22 +- .../main/scala/io/appthreat/c2cpg/C2Cpg.scala | 30 +- .../main/scala/io/appthreat/c2cpg/Main.scala | 236 +- .../c2cpg/astcreation/AstCreator.scala | 121 +- .../c2cpg/astcreation/AstCreatorHelper.scala | 1060 ++- .../AstForExpressionsCreator.scala | 678 +- .../astcreation/AstForFunctionsCreator.scala | 457 +- .../astcreation/AstForPrimitivesCreator.scala | 224 +- .../astcreation/AstForStatementsCreator.scala | 437 +- .../astcreation/AstForTypesCreator.scala | 823 ++- .../c2cpg/astcreation/AstNodeBuilder.scala | 76 +- .../appthreat/c2cpg/astcreation/Defines.scala | 11 +- .../c2cpg/astcreation/MacroHandler.scala | 343 +- .../c2cpg/datastructures/CGlobal.scala | 15 +- .../io/appthreat/c2cpg/parser/CdtParser.scala | 275 +- 
.../parser/CustomFileContentProvider.scala | 68 +- .../c2cpg/parser/DefaultDefines.scala | 66 +- .../appthreat/c2cpg/parser/FileDefaults.scala | 30 +- .../c2cpg/parser/HeaderFileFinder.scala | 36 +- .../c2cpg/parser/ParseProblemsLogger.scala | 52 +- .../appthreat/c2cpg/parser/ParserConfig.scala | 57 +- .../parser/PreprocessorStatementsLogger.scala | 53 +- .../c2cpg/passes/AstCreationPass.scala | 78 +- .../c2cpg/passes/ConfigFileCreationPass.scala | 41 +- .../c2cpg/passes/PreprocessorPass.scala | 39 +- .../c2cpg/passes/TypeDeclNodePass.scala | 95 +- .../c2cpg/utils/ExternalCommand.scala | 40 +- .../c2cpg/utils/IncludeAutoDiscovery.scala | 179 +- .../io/appthreat/c2cpg/utils/Report.scala | 165 +- .../io/appthreat/c2cpg/utils/TimeUtils.scala | 103 +- .../appthreat/javasrc2cpg/JavaSrc2Cpg.scala | 111 +- .../scala/io/appthreat/javasrc2cpg/Main.scala | 218 +- .../jartypereader/JarTypeReader.scala | 228 +- .../descriptorparser/DescriptorParser.scala | 51 +- .../descriptorparser/TokenParser.scala | 99 +- .../descriptorparser/TypeParser.scala | 177 +- .../jartypereader/model/Model.scala | 179 +- .../jpastprinter/JavaParserAstPrinter.scala | 24 +- .../javasrc2cpg/passes/AstCreationPass.scala | 245 +- .../javasrc2cpg/passes/AstCreator.scala | 6368 +++++++++-------- .../passes/ConfigFileCreationPass.scala | 80 +- .../passes/JavaTypeHintCallLinker.scala | 21 +- .../passes/JavaTypeRecoveryPass.scala | 103 +- .../passes/TypeInferencePass.scala | 198 +- .../javasrc2cpg/scope/JavaScopeElement.scala | 122 +- .../appthreat/javasrc2cpg/scope/Scope.scala | 431 +- .../javasrc2cpg/scope/TypeDeclContainer.scala | 23 +- .../typesolvers/EagerSourceTypeSolver.scala | 107 +- .../typesolvers/JmodClassPath.scala | 76 +- .../typesolvers/NonCachingClassPool.scala | 13 +- .../SimpleCombinedTypeSolver.scala | 146 +- .../typesolvers/TypeInfoCalculator.scala | 478 +- .../typesolvers/TypeSizeReducer.scala | 22 +- .../noncaching/JdkJarTypeSolver.scala | 369 +- .../javasrc2cpg/util/BindingTable.scala | 
154 +- .../util/BindingTableAdapterImpls.scala | 228 +- .../appthreat/javasrc2cpg/util/Delombok.scala | 133 +- .../javasrc2cpg/util/NameConstants.scala | 9 +- .../javasrc2cpg/util/SourceParser.scala | 241 +- .../javasrc2cpg/util/SourceRootFinder.scala | 151 +- .../io/appthreat/javasrc2cpg/util/Util.scala | 105 +- .../io/appthreat/jimple2cpg/Jimple2Cpg.scala | 279 +- .../scala/io/appthreat/jimple2cpg/Main.scala | 101 +- .../jimple2cpg/passes/AstCreationPass.scala | 38 +- .../jimple2cpg/passes/AstCreator.scala | 2256 +++--- .../passes/ConfigFileCreationPass.scala | 61 +- .../passes/DeclarationRefPass.scala | 22 +- .../passes/SootAstCreationPass.scala | 34 +- .../jimple2cpg/util/ProgramHandlingUtil.scala | 358 +- .../io/appthreat/jssrc2cpg/JsSrc2Cpg.scala | 115 +- .../scala/io/appthreat/jssrc2cpg/Main.scala | 67 +- .../jssrc2cpg/astcreation/AstCreator.scala | 407 +- .../astcreation/AstCreatorHelper.scala | 590 +- .../AstForDeclarationsCreator.scala | 1553 ++-- .../AstForExpressionsCreator.scala | 1043 +-- .../astcreation/AstForFunctionsCreator.scala | 976 +-- .../astcreation/AstForPrimitivesCreator.scala | 185 +- .../astcreation/AstForStatementsCreator.scala | 1941 ++--- .../AstForTemplateDomCreator.scala | 201 +- .../astcreation/AstForTypesCreator.scala | 1093 +-- .../astcreation/AstNodeBuilder.scala | 517 +- .../jssrc2cpg/astcreation/TypeHelper.scala | 269 +- .../datastructures/PendingReference.scala | 28 +- .../jssrc2cpg/datastructures/Scope.scala | 200 +- .../datastructures/ScopeElement.scala | 15 +- .../appthreat/jssrc2cpg/parser/BabelAst.scala | 539 +- .../jssrc2cpg/parser/BabelJsonParser.scala | 49 +- .../jssrc2cpg/passes/AstCreationPass.scala | 91 +- .../jssrc2cpg/passes/BuiltinTypesPass.scala | 65 +- .../jssrc2cpg/passes/ConfigPass.scala | 74 +- .../jssrc2cpg/passes/ConstClosurePass.scala | 112 +- .../appthreat/jssrc2cpg/passes/Defines.scala | 55 +- .../jssrc2cpg/passes/DependenciesPass.scala | 39 +- .../jssrc2cpg/passes/EcmaBuiltins.scala | 5 +- 
.../jssrc2cpg/passes/GlobalBuiltins.scala | 2181 +++--- .../jssrc2cpg/passes/ImportResolverPass.scala | 221 +- .../jssrc2cpg/passes/ImportsPass.scala | 24 +- .../JavaScriptInheritanceNamePass.scala | 14 +- .../passes/JavaScriptTypeHintCallLinker.scala | 12 +- .../passes/JavaScriptTypeRecovery.scala | 393 +- .../jssrc2cpg/passes/JsMetaDataPass.scala | 15 +- .../jssrc2cpg/passes/PrivateKeyFilePass.scala | 22 +- .../jssrc2cpg/passes/TypeNodePass.scala | 20 +- .../preprocessing/EjsPreprocessor.scala | 135 +- .../jssrc2cpg/utils/AstGenRunner.scala | 531 +- .../jssrc2cpg/utils/PackageJsonParser.scala | 145 +- .../io/appthreat/pysrc2cpg/AutoIncIndex.scala | 12 +- .../io/appthreat/pysrc2cpg/CodeToCpg.scala | 51 +- .../pysrc2cpg/ConfigFileCreationPass.scala | 34 +- .../io/appthreat/pysrc2cpg/Constants.scala | 7 +- .../io/appthreat/pysrc2cpg/ContextStack.scala | 827 ++- .../DependenciesFromRequirementsTxtPass.scala | 33 +- .../DynamicTypeHintFullNamePass.scala | 177 +- .../io/appthreat/pysrc2cpg/EdgeBuilder.scala | 190 +- .../pysrc2cpg/ImportResolverPass.scala | 293 +- .../io/appthreat/pysrc2cpg/ImportsPass.scala | 29 +- .../scala/io/appthreat/pysrc2cpg/Main.scala | 68 +- .../io/appthreat/pysrc2cpg/NodeBuilder.scala | 629 +- .../io/appthreat/pysrc2cpg/NodeToCode.scala | 8 +- .../scala/io/appthreat/pysrc2cpg/Py2Cpg.scala | 56 +- .../pysrc2cpg/Py2CpgOnFileSystem.scala | 171 +- .../pysrc2cpg/PythonAstVisitor.scala | 4528 ++++++------ .../pysrc2cpg/PythonAstVisitorHelpers.scala | 1236 ++-- .../pysrc2cpg/PythonInheritanceNamePass.scala | 12 +- .../pysrc2cpg/PythonTypeHintCallLinker.scala | 35 +- .../pysrc2cpg/PythonTypeRecovery.scala | 432 +- .../memop/AstNodeToMemoryOperationMap.scala | 41 +- .../pysrc2cpg/memop/MemoryOperation.scala | 5 +- .../memop/MemoryOperationCalculator.scala | 957 ++- .../appthreat/pythonparser/AstPrinter.scala | 1607 ++--- .../appthreat/pythonparser/AstVisitor.scala | 480 +- .../pythonparser/CharStreamImpl.scala | 419 +- 
.../io/appthreat/pythonparser/PyParser.scala | 30 +- .../io/appthreat/pythonparser/ast/Ast.scala | 1671 ++--- .../pythonparser/ast/AttributeProvider.scala | 86 +- .../appthreat/pythonparser/ast/package.scala | 5 +- .../main/scala/io/appthreat/x2cpg/Ast.scala | 466 +- .../io/appthreat/x2cpg/AstCreatorBase.scala | 643 +- .../io/appthreat/x2cpg/AstNodeBuilder.scala | 589 +- .../scala/io/appthreat/x2cpg/Defines.scala | 50 +- .../scala/io/appthreat/x2cpg/Imports.scala | 29 +- .../io/appthreat/x2cpg/SourceFiles.scala | 286 +- .../main/scala/io/appthreat/x2cpg/X2Cpg.scala | 563 +- .../x2cpg/datastructures/Global.scala | 6 +- .../x2cpg/datastructures/Scope.scala | 49 +- .../x2cpg/datastructures/ScopeElement.scala | 8 +- .../x2cpg/datastructures/Stack.scala | 19 +- .../io/appthreat/x2cpg/layers/Base.scala | 58 +- .../io/appthreat/x2cpg/layers/CallGraph.scala | 39 +- .../appthreat/x2cpg/layers/ControlFlow.scala | 50 +- .../io/appthreat/x2cpg/layers/DumpAst.scala | 34 +- .../io/appthreat/x2cpg/layers/DumpCdg.scala | 33 +- .../io/appthreat/x2cpg/layers/DumpCfg.scala | 33 +- .../x2cpg/layers/TypeRelations.scala | 36 +- .../x2cpg/passes/base/AstLinkerPass.scala | 102 +- .../x2cpg/passes/base/ContainsEdgePass.scala | 73 +- .../x2cpg/passes/base/FileCreationPass.scala | 87 +- .../passes/base/MethodDecoratorPass.scala | 92 +- .../x2cpg/passes/base/MethodStubCreator.scala | 247 +- .../x2cpg/passes/base/NamespaceCreator.scala | 27 +- .../base/ParameterIndexCompatPass.scala | 23 +- .../passes/base/TypeDeclStubCreator.scala | 93 +- .../x2cpg/passes/base/TypeUsagePass.scala | 78 +- .../passes/callgraph/DynamicCallLinker.scala | 420 +- .../passes/callgraph/MethodRefLinker.scala | 35 +- .../passes/callgraph/NaiveCallLinker.scala | 32 +- .../passes/callgraph/StaticCallLinker.scala | 102 +- .../passes/controlflow/CfgCreationPass.scala | 22 +- .../passes/controlflow/cfgcreation/Cfg.scala | 316 +- .../controlflow/cfgcreation/CfgCreator.scala | 1167 +-- 
.../controlflow/cfgdominator/CfgAdapter.scala | 7 +- .../cfgdominator/CfgDominator.scala | 144 +- .../cfgdominator/CfgDominatorFrontier.scala | 59 +- .../cfgdominator/CfgDominatorPass.scala | 70 +- .../cfgdominator/CpgCfgAdapter.scala | 12 +- .../cfgdominator/DomTreeAdapter.scala | 14 +- .../cfgdominator/ReverseCpgCfgAdapter.scala | 12 +- .../controlflow/codepencegraph/CdgPass.scala | 104 +- .../CpgPostDomTreeAdapter.scala | 8 +- .../x2cpg/passes/frontend/Dereference.scala | 38 +- .../x2cpg/passes/frontend/MetaDataPass.scala | 56 +- .../x2cpg/passes/frontend/SymbolTable.scala | 208 +- .../x2cpg/passes/frontend/TypeNodePass.scala | 136 +- .../frontend/XConfigFileCreationPass.scala | 121 +- .../passes/frontend/XImportResolverPass.scala | 233 +- .../x2cpg/passes/frontend/XImportsPass.scala | 31 +- .../frontend/XInheritanceFullNamePass.scala | 270 +- .../passes/frontend/XTypeHintCallLinker.scala | 343 +- .../x2cpg/passes/frontend/XTypeRecovery.scala | 2289 +++--- .../x2cpg/passes/taggers/CdxPass.scala | 394 +- .../passes/taggers/ChennaiTagsPass.scala | 199 +- .../typerelations/AliasLinkerPass.scala | 32 +- .../typerelations/TypeHierarchyPass.scala | 38 +- .../x2cpg/utils/AstPropertiesUtil.scala | 32 +- .../appthreat/x2cpg/utils/Environment.scala | 55 +- .../x2cpg/utils/ExternalCommand.scala | 94 +- .../io/appthreat/x2cpg/utils/HashUtil.scala | 44 +- .../appthreat/x2cpg/utils/LinkingUtil.scala | 300 +- .../io/appthreat/x2cpg/utils/ListUtils.scala | 31 +- .../appthreat/x2cpg/utils/NodeBuilders.scala | 347 +- .../io/appthreat/x2cpg/utils/Report.scala | 165 +- .../appthreat/x2cpg/utils/StringUtils.scala | 8 +- .../io/appthreat/x2cpg/utils/TimeUtils.scala | 106 +- .../utils/dependency/DependencyResolver.scala | 225 +- .../utils/dependency/GradleDependencies.scala | 494 +- .../utils/dependency/MavenCoordinates.scala | 39 +- .../utils/dependency/MavenDependencies.scala | 63 +- .../io/appthreat/chencli/ChenExport.scala | 416 +- .../scala/io/appthreat/chencli/ChenFlow.scala | 
172 +- .../io/appthreat/chencli/ChenParse.scala | 351 +- .../io/appthreat/chencli/ChenVectors.scala | 315 +- .../io/appthreat/chencli/CpgBasedTool.scala | 110 +- .../appthreat/chencli/DefaultOverlays.scala | 37 +- .../chencli/console/ChenConsole.scala | 63 +- .../chencli/console/Predefined.scala | 54 +- .../chencli/console/ReplBridge.scala | 25 +- pyproject.toml | 2 +- .../io/shiftleft/semanticcpg/Overlays.scala | 67 +- .../accesspath/AccessElement.scala | 57 +- .../semanticcpg/accesspath/AccessPath.scala | 815 +-- .../semanticcpg/accesspath/TrackedBase.scala | 62 +- .../semanticcpg/codedumper/CodeDumper.scala | 214 +- .../codedumper/SourceHighlighter.scala | 57 +- .../dotgenerator/AstGenerator.scala | 27 +- .../dotgenerator/CallGraphGenerator.scala | 54 +- .../dotgenerator/CdgGenerator.scala | 12 +- .../dotgenerator/CfgGenerator.scala | 109 +- .../dotgenerator/DotAstGenerator.scala | 15 +- .../dotgenerator/DotCallGraphGenerator.scala | 11 +- .../dotgenerator/DotCdgGenerator.scala | 15 +- .../dotgenerator/DotCfgGenerator.scala | 15 +- .../dotgenerator/DotSerializer.scala | 274 +- .../DotTypeHierarchyGenerator.scala | 11 +- .../dotgenerator/TypeHierarchyGenerator.scala | 68 +- .../language/AccessPathHandling.scala | 261 +- .../semanticcpg/language/HasLocation.scala | 5 +- .../semanticcpg/language/ICallResolver.scala | 156 +- .../language/LocationCreator.scala | 124 +- .../semanticcpg/language/NewNodeSteps.scala | 19 +- .../language/NewTagNodePairTraversal.scala | 27 +- .../language/NodeExtensionFinder.scala | 72 +- .../semanticcpg/language/NodeOrdering.scala | 95 +- .../semanticcpg/language/NodeSteps.scala | 156 +- .../language/NodeTypeStarters.scala | 614 +- .../shiftleft/semanticcpg/language/Show.scala | 63 +- .../semanticcpg/language/Steps.scala | 191 +- .../semanticcpg/language/TagTraversal.scala | 27 +- .../android/ConfigFileTraversal.scala | 175 +- .../language/android/Constants.scala | 7 +- .../language/android/LocalTraversal.scala | 38 +- 
.../language/android/MethodTraversal.scala | 7 +- .../language/android/NodeTypeStarters.scala | 66 +- .../language/android/package.scala | 33 +- .../bindingextension/MethodTraversal.scala | 19 +- .../bindingextension/TypeDeclTraversal.scala | 20 +- .../callgraphextension/CallTraversal.scala | 18 +- .../callgraphextension/MethodTraversal.scala | 107 +- .../language/dotextension/AstNodeDot.scala | 11 +- .../language/dotextension/CfgNodeDot.scala | 18 +- .../dotextension/InterproceduralNodeDot.scala | 8 +- .../language/dotextension/Shared.scala | 55 +- .../language/nodemethods/AstNodeMethods.scala | 206 +- .../language/nodemethods/CallMethods.scala | 39 +- .../language/nodemethods/CfgNodeMethods.scala | 285 +- .../nodemethods/ExpressionMethods.scala | 96 +- .../nodemethods/IdentifierMethods.scala | 17 +- .../language/nodemethods/LiteralMethods.scala | 11 +- .../language/nodemethods/LocalMethods.scala | 16 +- .../language/nodemethods/MethodMethods.scala | 95 +- .../MethodParameterInMethods.scala | 9 +- .../MethodParameterOutMethods.scala | 15 +- .../nodemethods/MethodRefMethods.scala | 20 +- .../nodemethods/MethodReturnMethods.scala | 28 +- .../language/nodemethods/NodeMethods.scala | 15 +- .../nodemethods/StoredNodeMethods.scala | 16 +- .../ArrayAccessTraversal.scala | 32 +- .../AssignmentTraversal.scala | 11 +- .../FieldAccessTraversal.scala | 26 +- .../operatorextension/Implicits.scala | 56 +- .../operatorextension/NodeTypeStarters.scala | 55 +- .../OpAstNodeTraversal.scala | 39 +- .../language/operatorextension/OpNodes.scala | 11 +- .../operatorextension/TargetTraversal.scala | 18 +- .../nodemethods/ArrayAccessMethods.scala | 36 +- .../nodemethods/AssignmentMethods.scala | 21 +- .../nodemethods/FieldAccessMethods.scala | 26 +- .../nodemethods/OpAstNodeMethods.scala | 55 +- .../nodemethods/TargetMethods.scala | 22 +- .../language/operatorextension/package.scala | 107 +- .../semanticcpg/language/package.scala | 573 +- .../types/expressions/CallTraversal.scala | 69 +- 
.../ControlStructureTraversal.scala | 104 +- .../expressions/IdentifierTraversal.scala | 13 +- .../generalizations/AstNodeTraversal.scala | 424 +- .../generalizations/CfgNodeTraversal.scala | 190 +- .../DeclarationTraversal.scala | 33 +- .../generalizations/ExpressionTraversal.scala | 113 +- .../propertyaccessors/EvalTypeAccessors.scala | 75 +- .../propertyaccessors/ModifierAccessors.scala | 64 +- .../AnnotationParameterAssignTraversal.scala | 26 +- .../types/structure/AnnotationTraversal.scala | 46 +- .../types/structure/DependencyTraversal.scala | 5 +- .../types/structure/FileTraversal.scala | 15 +- .../types/structure/ImportTraversal.scala | 14 +- .../types/structure/LocalTraversal.scala | 21 +- .../types/structure/MemberTraversal.scala | 22 +- .../MethodParameterOutTraversal.scala | 55 +- .../structure/MethodParameterTraversal.scala | 55 +- .../structure/MethodReturnTraversal.scala | 34 +- .../types/structure/MethodTraversal.scala | 411 +- .../structure/NamespaceBlockTraversal.scala | 23 +- .../types/structure/NamespaceTraversal.scala | 52 +- .../types/structure/TypeDeclTraversal.scala | 211 +- .../types/structure/TypeTraversal.scala | 175 +- .../semanticcpg/layers/LayerCreator.scala | 98 +- .../io/shiftleft/semanticcpg/package.scala | 46 +- .../semanticcpg/testing/DummyNode.scala | 88 +- .../semanticcpg/testing/package.scala | 410 +- .../semanticcpg/utils/Fingerprinting.scala | 9 +- .../semanticcpg/utils/MemberAccess.scala | 61 +- .../semanticcpg/utils/SecureXmlParsing.scala | 32 +- .../semanticcpg/utils/Statements.scala | 9 +- .../shiftleft/semanticcpg/utils/Torch.scala | 151 +- 385 files changed, 41737 insertions(+), 40862 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 885f129d..160d141f 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,5 +1,14 @@ -version = 3.7.11 +version = 3.7.15 runner.dialect = scala3 preset = IntelliJ -maxColumn = 120 +maxColumn = 100 align.preset = true + +indent.main = 4 + +newlines.source = keep 
+rewrite.scala3.convertToNewSyntax = true +rewrite.scala3.removeOptionalBraces = yes +rewrite.scala3.insertEndMarkerMinLines = 20 +rewrite.scala3.removeEndMarkerMaxLines = 18 + diff --git a/build.sbt b/build.sbt index d42369e1..30f005da 100644 --- a/build.sbt +++ b/build.sbt @@ -1,6 +1,6 @@ name := "chen" ThisBuild / organization := "io.appthreat" -ThisBuild / version := "0.6.3" +ThisBuild / version := "1.0.0" ThisBuild / scalaVersion := "3.3.1" val cpgVersion = "1.4.22" diff --git a/codemeta.json b/codemeta.json index 43e102be..cee7d5cb 100644 --- a/codemeta.json +++ b/codemeta.json @@ -7,7 +7,7 @@ "downloadUrl": "https://github.com/AppThreat/chen", "issueTracker": "https://github.com/AppThreat/chen/issues", "name": "chen", - "version": "0.6.3", + "version": "1.0.0", "description": "Code Hierarchy Exploration Net (chen) is an advanced exploration toolkit for your application source code and its dependency hierarchy.", "applicationCategory": "code-analysis", "keywords": [ diff --git a/console/src/main/scala/io/appthreat/console/BridgeBase.scala b/console/src/main/scala/io/appthreat/console/BridgeBase.scala index 9ad70666..8a070b8b 100644 --- a/console/src/main/scala/io/appthreat/console/BridgeBase.scala +++ b/console/src/main/scala/io/appthreat/console/BridgeBase.scala @@ -41,303 +41,302 @@ case class Config( /** Base class for ReplBridge, split by topic into multiple self types. 
*/ -trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandling with ServerHandling { - - def jProduct: JProduct - - protected def parseConfig(args: Array[String]): Config = { - val parser = new scopt.OptionParser[Config](jProduct.name) { - override def errorOnUnknownArgument = false - - note("Script execution") - - opt[Path]("script") - .action((x, c) => c.copy(scriptFile = Some(x))) - .text("path to script file: will execute and exit") - - opt[String]("param") - .valueName("param1=value1") - .unbounded() - .optional() - .action { (x, c) => - x.split("=", 2) match { - case Array(key, value) => c.copy(params = c.params + (key -> value)) - case _ => throw new IllegalArgumentException(s"unable to parse param input $x") - } - } - .text("key/value pair for main function in script - may be passed multiple times") - - opt[Path]("import") - .valueName("script1.sc") - .unbounded() - .optional() - .action((x, c) => c.copy(additionalImports = c.additionalImports :+ x)) - .text("import (and run) additional script(s) on startup - may be passed multiple times") - - opt[String]('d', "dep") - .valueName("com.michaelpollmeier:versionsort:1.0.7") - .unbounded() - .optional() - .action((x, c) => c.copy(dependencies = c.dependencies :+ x)) - .text( - "add artifacts (including transitive dependencies) for given maven coordinate to classpath - may be passed multiple times" - ) - - opt[String]('r', "repo") - .valueName("https://repository.apache.org/content/groups/public/") - .unbounded() - .optional() - .action((x, c) => c.copy(resolvers = c.resolvers :+ x)) - .text("additional repositories to resolve dependencies - may be passed multiple times") - - opt[String]("command") - .action((x, c) => c.copy(command = Some(x))) - .text("select one of multiple @main methods") - - note("Plugin Management") - - opt[String]("add-plugin") - .action((x, c) => c.copy(addPlugin = Some(x))) - .text("Plugin zip file to add to the installation") - - opt[String]("remove-plugin") - 
.action((x, c) => c.copy(rmPlugin = Some(x))) - .text("Name of plugin to remove from the installation") - - opt[Unit]("plugins") - .action((_, c) => c.copy(listPlugins = true)) - .text("List available plugins and layer creators") - - opt[String]("run") - .action((x, c) => c.copy(pluginToRun = Some(x))) - .text("Run layer creator. Get a list via --plugins") - - opt[String]("src") - .action((x, c) => c.copy(src = Some(x))) - .text("Source code directory to run layer creator on") - - opt[String]("language") - .action((x, c) => c.copy(language = Some(x))) - .text("Language to use in CPG creation") - - opt[Unit]("overwrite") - .action((_, c) => c.copy(overwrite = true)) - .text("Overwrite CPG if it already exists") - - opt[Unit]("store") - .action((_, c) => c.copy(store = true)) - .text("Store graph changes made by layer creator") - - note("REST server mode") - - opt[Unit]("server") - .action((_, c) => c.copy(server = true)) - .text("run as HTTP server") - - opt[String]("server-host") - .action((x, c) => c.copy(serverHost = x)) - .text("Hostname on which to expose the Chen server") - - opt[Int]("server-port") - .action((x, c) => c.copy(serverPort = x)) - .text("Port on which to expose the Chen server") - - opt[String]("server-auth-username") - .action((x, c) => c.copy(serverAuthUsername = Option(x))) - .text("Basic auth username for the Chen server") - - opt[String]("server-auth-password") - .action((x, c) => c.copy(serverAuthPassword = Option(x))) - .text("Basic auth password for the Chen server") - - note("Misc") - - arg[java.io.File]("") - .optional() - .action((x, c) => c.copy(cpgToLoad = Some(x.toScala))) - .text("Atom to load") - - opt[String]("for-input-path") - .action((x, c) => c.copy(forInputPath = Some(x))) - .text("Open CPG for given input path - overrides ") - - opt[Unit]("nocolors") - .action((_, c) => c.copy(nocolors = true)) - .text("turn off colors") - - opt[Unit]("verbose") - .action((_, c) => c.copy(verbose = true)) - .text("enable verbose output 
(predef, resolved dependency jars, ...)") - - opt[Int]("maxHeight") - .action((x, c) => c.copy(maxHeight = Some(x))) - .text("Maximum number lines to print before output gets truncated (default: no limit)") - - help("help") - .text("Print this help text") - } - - // note: if config is really `None` an error message would have been displayed earlier - parser.parse(args, Config()).get - } - - /** Entry point for Chen's integrated REPL and plugin manager */ - protected def run(config: Config): Unit = { - if (config.listPlugins) { - printPluginsAndLayerCreators(config) - } else if (config.addPlugin.isDefined) { - new PluginManager(InstallConfig().rootPath).add(config.addPlugin.get) - } else if (config.rmPlugin.isDefined) { - new PluginManager(InstallConfig().rootPath).rm(config.rmPlugin.get) - } else if (config.scriptFile.isDefined) { - val scriptReturn = runScript(config) - if (scriptReturn.isFailure) { - println(scriptReturn.failed.get.getMessage) - System.exit(1) - } - } else if (config.server) { - GlobalReporting.enable() - startHttpServer(config) - } else if (config.pluginToRun.isDefined) { - runPlugin(config, jProduct.name) - } else { - startInteractiveShell(config) - } - } - - protected def createPredefFile(additionalLines: Seq[String] = Nil): Path = { - val tmpFile = Files.createTempFile("chen-predef", "sc") - Files.write(tmpFile, (predefLines ++ additionalLines).asJava) - tmpFile.toAbsolutePath - } - - /** code that is executed on startup */ - protected def predefLines: Seq[String] - - protected def greeting: String - - protected def promptStr: String - - protected def onExitCode: String -} - -trait InteractiveShell { this: BridgeBase => - protected def startInteractiveShell(config: Config) = { - val replConfig = config.cpgToLoad.map { cpgFile => - "importCpg(\"" + cpgFile + "\")" - } ++ config.forInputPath.map { name => - s""" +trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandling + with ServerHandling: + + def jProduct: JProduct 
+ + protected def parseConfig(args: Array[String]): Config = + val parser = new scopt.OptionParser[Config](jProduct.name): + override def errorOnUnknownArgument = false + + note("Script execution") + + opt[Path]("script") + .action((x, c) => c.copy(scriptFile = Some(x))) + .text("path to script file: will execute and exit") + + opt[String]("param") + .valueName("param1=value1") + .unbounded() + .optional() + .action { (x, c) => + x.split("=", 2) match + case Array(key, value) => c.copy(params = c.params + (key -> value)) + case _ => + throw new IllegalArgumentException(s"unable to parse param input $x") + } + .text("key/value pair for main function in script - may be passed multiple times") + + opt[Path]("import") + .valueName("script1.sc") + .unbounded() + .optional() + .action((x, c) => c.copy(additionalImports = c.additionalImports :+ x)) + .text( + "import (and run) additional script(s) on startup - may be passed multiple times" + ) + + opt[String]('d', "dep") + .valueName("com.michaelpollmeier:versionsort:1.0.7") + .unbounded() + .optional() + .action((x, c) => c.copy(dependencies = c.dependencies :+ x)) + .text( + "add artifacts (including transitive dependencies) for given maven coordinate to classpath - may be passed multiple times" + ) + + opt[String]('r', "repo") + .valueName("https://repository.apache.org/content/groups/public/") + .unbounded() + .optional() + .action((x, c) => c.copy(resolvers = c.resolvers :+ x)) + .text( + "additional repositories to resolve dependencies - may be passed multiple times" + ) + + opt[String]("command") + .action((x, c) => c.copy(command = Some(x))) + .text("select one of multiple @main methods") + + note("Plugin Management") + + opt[String]("add-plugin") + .action((x, c) => c.copy(addPlugin = Some(x))) + .text("Plugin zip file to add to the installation") + + opt[String]("remove-plugin") + .action((x, c) => c.copy(rmPlugin = Some(x))) + .text("Name of plugin to remove from the installation") + + opt[Unit]("plugins") + 
.action((_, c) => c.copy(listPlugins = true)) + .text("List available plugins and layer creators") + + opt[String]("run") + .action((x, c) => c.copy(pluginToRun = Some(x))) + .text("Run layer creator. Get a list via --plugins") + + opt[String]("src") + .action((x, c) => c.copy(src = Some(x))) + .text("Source code directory to run layer creator on") + + opt[String]("language") + .action((x, c) => c.copy(language = Some(x))) + .text("Language to use in CPG creation") + + opt[Unit]("overwrite") + .action((_, c) => c.copy(overwrite = true)) + .text("Overwrite CPG if it already exists") + + opt[Unit]("store") + .action((_, c) => c.copy(store = true)) + .text("Store graph changes made by layer creator") + + note("REST server mode") + + opt[Unit]("server") + .action((_, c) => c.copy(server = true)) + .text("run as HTTP server") + + opt[String]("server-host") + .action((x, c) => c.copy(serverHost = x)) + .text("Hostname on which to expose the Chen server") + + opt[Int]("server-port") + .action((x, c) => c.copy(serverPort = x)) + .text("Port on which to expose the Chen server") + + opt[String]("server-auth-username") + .action((x, c) => c.copy(serverAuthUsername = Option(x))) + .text("Basic auth username for the Chen server") + + opt[String]("server-auth-password") + .action((x, c) => c.copy(serverAuthPassword = Option(x))) + .text("Basic auth password for the Chen server") + + note("Misc") + + arg[java.io.File]("") + .optional() + .action((x, c) => c.copy(cpgToLoad = Some(x.toScala))) + .text("Atom to load") + + opt[String]("for-input-path") + .action((x, c) => c.copy(forInputPath = Some(x))) + .text("Open CPG for given input path - overrides ") + + opt[Unit]("nocolors") + .action((_, c) => c.copy(nocolors = true)) + .text("turn off colors") + + opt[Unit]("verbose") + .action((_, c) => c.copy(verbose = true)) + .text("enable verbose output (predef, resolved dependency jars, ...)") + + opt[Int]("maxHeight") + .action((x, c) => c.copy(maxHeight = Some(x))) + .text( + 
"Maximum number lines to print before output gets truncated (default: no limit)" + ) + + help("help") + .text("Print this help text") + + // note: if config is really `None` an error message would have been displayed earlier + parser.parse(args, Config()).get + end parseConfig + + /** Entry point for Chen's integrated REPL and plugin manager */ + protected def run(config: Config): Unit = + if config.listPlugins then + printPluginsAndLayerCreators(config) + else if config.addPlugin.isDefined then + new PluginManager(InstallConfig().rootPath).add(config.addPlugin.get) + else if config.rmPlugin.isDefined then + new PluginManager(InstallConfig().rootPath).rm(config.rmPlugin.get) + else if config.scriptFile.isDefined then + val scriptReturn = runScript(config) + if scriptReturn.isFailure then + println(scriptReturn.failed.get.getMessage) + System.exit(1) + else if config.server then + GlobalReporting.enable() + startHttpServer(config) + else if config.pluginToRun.isDefined then + runPlugin(config, jProduct.name) + else + startInteractiveShell(config) + + protected def createPredefFile(additionalLines: Seq[String] = Nil): Path = + val tmpFile = Files.createTempFile("chen-predef", "sc") + Files.write(tmpFile, (predefLines ++ additionalLines).asJava) + tmpFile.toAbsolutePath + + /** code that is executed on startup */ + protected def predefLines: Seq[String] + + protected def greeting: String + + protected def promptStr: String + + protected def onExitCode: String +end BridgeBase + +trait InteractiveShell: + this: BridgeBase => + protected def startInteractiveShell(config: Config) = + val replConfig = config.cpgToLoad.map { cpgFile => + "importCpg(\"" + cpgFile + "\")" + } ++ config.forInputPath.map { name => + s""" |openForInputPath(\"$name\") |""".stripMargin - } - - val predefFile = createPredefFile(replConfig.toSeq) - - replpp.InteractiveShell.run( - replpp.Config( - predefFiles = predefFile +: config.additionalImports, - nocolors = config.nocolors, - dependencies = 
config.dependencies, - resolvers = config.resolvers, - verbose = config.verbose, - greeting = Option(greeting), - prompt = Option(promptStr), - onExitCode = Option(onExitCode), - maxHeight = config.maxHeight - ) - ) - } - -} - -trait ScriptExecution { this: BridgeBase => - - def runScript(config: Config): Try[Unit] = { - val scriptFile = config.scriptFile.getOrElse(throw new AssertionError("no script file configured")) - if (!Files.exists(scriptFile)) { - Try(throw new AssertionError(s"given script file `$scriptFile` does not exist")) - } else { - val predefFile = createPredefFile(importCpgCode(config)) - val scriptReturn = ScriptRunner.exec( - replpp.Config( - predefFiles = predefFile +: config.additionalImports, - scriptFile = Option(scriptFile), - command = config.command, - params = config.params, - dependencies = config.dependencies, - resolvers = config.resolvers, - verbose = config.verbose + } + + val predefFile = createPredefFile(replConfig.toSeq) + + replpp.InteractiveShell.run( + replpp.Config( + predefFiles = predefFile +: config.additionalImports, + nocolors = config.nocolors, + dependencies = config.dependencies, + resolvers = config.resolvers, + verbose = config.verbose, + greeting = Option(greeting), + prompt = Option(promptStr), + onExitCode = Option(onExitCode), + maxHeight = config.maxHeight + ) ) - ) - if (config.verbose && scriptReturn.isFailure) { - println(scriptReturn.failed.get.getMessage) - } - scriptReturn - } - } - - /** For the given config, generate a list of commands to import the CPG - */ - private def importCpgCode(config: Config): List[String] = { - config.cpgToLoad.map { cpgFile => - "importCpg(\"" + cpgFile + "\")" - }.toList ++ config.forInputPath.map { name => - s""" + end startInteractiveShell +end InteractiveShell + +trait ScriptExecution: + this: BridgeBase => + + def runScript(config: Config): Try[Unit] = + val scriptFile = + config.scriptFile.getOrElse(throw new AssertionError("no script file configured")) + if 
!Files.exists(scriptFile) then + Try(throw new AssertionError(s"given script file `$scriptFile` does not exist")) + else + val predefFile = createPredefFile(importCpgCode(config)) + val scriptReturn = ScriptRunner.exec( + replpp.Config( + predefFiles = predefFile +: config.additionalImports, + scriptFile = Option(scriptFile), + command = config.command, + params = config.params, + dependencies = config.dependencies, + resolvers = config.resolvers, + verbose = config.verbose + ) + ) + if config.verbose && scriptReturn.isFailure then + println(scriptReturn.failed.get.getMessage) + scriptReturn + end runScript + + /** For the given config, generate a list of commands to import the CPG + */ + private def importCpgCode(config: Config): List[String] = + config.cpgToLoad.map { cpgFile => + "importCpg(\"" + cpgFile + "\")" + }.toList ++ config.forInputPath.map { name => + s""" |openForInputPath(\"$name\") |""".stripMargin - } - } -} - -trait PluginHandling { this: BridgeBase => - - /** Print a summary of the available plugins and layer creators to the terminal. - */ - protected def printPluginsAndLayerCreators(config: Config): Unit = { - println("Installed plugins:") - println("==================") - new PluginManager(InstallConfig().rootPath).listPlugins().foreach(println) - println("Available layer creators") - println() - withTemporaryScript(codeToListPlugins(), jProduct.name) { file => - runScript(config.copy(scriptFile = Some(file.path))).get - } - } - - private def codeToListPlugins(): String = { - """ + } +end ScriptExecution + +trait PluginHandling: + this: BridgeBase => + + /** Print a summary of the available plugins and layer creators to the terminal. 
+ */ + protected def printPluginsAndLayerCreators(config: Config): Unit = + println("Installed plugins:") + println("==================") + new PluginManager(InstallConfig().rootPath).listPlugins().foreach(println) + println("Available layer creators") + println() + withTemporaryScript(codeToListPlugins(), jProduct.name) { file => + runScript(config.copy(scriptFile = Some(file.path))).get + } + + private def codeToListPlugins(): String = + """ |println(run) | |""".stripMargin - } - - /** Run plugin by generating a temporary script based on the given config and execute the script */ - protected def runPlugin(config: Config, productName: String): Unit = { - if (config.src.isEmpty) { - println("You must supply a source directory with the --src flag") - return - } - val code = loadOrCreateCpg(config, productName) - withTemporaryScript(code, productName) { file => - runScript(config.copy(scriptFile = Some(file.path))).get - } - } - - /** Create a command that loads an existing CPG or creates it, based on the given `config`. 
- */ - private def loadOrCreateCpg(config: Config, productName: String): String = { - - val bundleName = config.pluginToRun.get - val src = better.files.File(config.src.get).path.toAbsolutePath.toString - val language = languageFromConfig(config, src) - - val storeCode = if (config.store) { "save" } - else { "" } - val runDataflow = if (productName == "ocular") { "run.dataflow" } - else { "run.ossdataflow" } - val argsString = argsStringFromConfig(config) - - s""" + + /** Run plugin by generating a temporary script based on the given config and execute the script + */ + protected def runPlugin(config: Config, productName: String): Unit = + if config.src.isEmpty then + println("You must supply a source directory with the --src flag") + return + val code = loadOrCreateCpg(config, productName) + withTemporaryScript(code, productName) { file => + runScript(config.copy(scriptFile = Some(file.path))).get + } + + /** Create a command that loads an existing CPG or creates it, based on the given `config`. 
+ */ + private def loadOrCreateCpg(config: Config, productName: String): String = + + val bundleName = config.pluginToRun.get + val src = better.files.File(config.src.get).path.toAbsolutePath.toString + val language = languageFromConfig(config, src) + + val storeCode = if config.store then "save" + else "" + val runDataflow = if productName == "ocular" then "run.dataflow" + else "run.ossdataflow" + val argsString = argsStringFromConfig(config) + + s""" | if (${config.overwrite} || !workspace.projectExists("$src")) { | workspace.projects | .filter(_.inputPath == "$src") @@ -352,66 +351,60 @@ trait PluginHandling { this: BridgeBase => | run.$bundleName | $storeCode |""".stripMargin + end loadOrCreateCpg + + private def languageFromConfig(config: Config, src: String): String = + config.language.getOrElse( + io.appthreat.console.cpgcreation + .guessLanguage(src) + .map { + case Languages.C | Languages.NEWC => "c" + case Languages.JAVA => "jvm" + case Languages.JAVASRC => "java" + case lang => lang.toLowerCase + } + .getOrElse("c") + ) - } - - private def languageFromConfig(config: Config, src: String): String = { - config.language.getOrElse( - io.appthreat.console.cpgcreation - .guessLanguage(src) - .map { - case Languages.C | Languages.NEWC => "c" - case Languages.JAVA => "jvm" - case Languages.JAVASRC => "java" - case lang => lang.toLowerCase + private def argsStringFromConfig(config: Config): String = + config.frontendArgs match + case Array() => "" + case args => + val quotedArgs = args.map { arg => + "\"" ++ arg ++ "\"" + } + val argsString = quotedArgs.mkString(", ") + s", args=List($argsString)" + + private def withTemporaryScript(code: String, prefix: String)(f: File => Unit): Unit = + File.usingTemporaryDirectory(prefix + "-bundle") { dir => + val file = dir / "script.sc" + file.write(code) + f(file) } - .getOrElse("c") - ) - } - - private def argsStringFromConfig(config: Config): String = { - config.frontendArgs match { - case Array() => "" - case args => - 
val quotedArgs = args.map { arg => - "\"" ++ arg ++ "\"" - } - val argsString = quotedArgs.mkString(", ") - s", args=List($argsString)" - } - } - - private def withTemporaryScript(code: String, prefix: String)(f: File => Unit): Unit = { - File.usingTemporaryDirectory(prefix + "-bundle") { dir => - val file = dir / "script.sc" - file.write(code) - f(file) - } - } - -} - -trait ServerHandling { this: BridgeBase => - - protected def startHttpServer(config: Config): Unit = { - val predefFile = createPredefFile(Nil) - - val baseConfig = replpp.Config( - predefFiles = predefFile +: config.additionalImports, - dependencies = config.dependencies, - resolvers = config.resolvers, - verbose = false - ) - - replpp.server.ReplServer.startHttpServer( - replpp.server.Config( - baseConfig, - serverHost = config.serverHost, - serverPort = config.serverPort, - serverAuthUsername = config.serverAuthUsername, - serverAuthPassword = config.serverAuthPassword - ) - ) - } - -} +end PluginHandling + +trait ServerHandling: + this: BridgeBase => + + protected def startHttpServer(config: Config): Unit = + val predefFile = createPredefFile(Nil) + + val baseConfig = replpp.Config( + predefFiles = predefFile +: config.additionalImports, + dependencies = config.dependencies, + resolvers = config.resolvers, + verbose = false + ) + + replpp.server.ReplServer.startHttpServer( + replpp.server.Config( + baseConfig, + serverHost = config.serverHost, + serverPort = config.serverPort, + serverAuthUsername = config.serverAuthUsername, + serverAuthPassword = config.serverAuthPassword + ) + ) + end startHttpServer +end ServerHandling diff --git a/console/src/main/scala/io/appthreat/console/Commit.scala b/console/src/main/scala/io/appthreat/console/Commit.scala index cf9fb3e0..0e8a892f 100644 --- a/console/src/main/scala/io/appthreat/console/Commit.scala +++ b/console/src/main/scala/io/appthreat/console/Commit.scala @@ -4,29 +4,23 @@ import io.shiftleft.passes.CpgPass import 
io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} import overflowdb.BatchedUpdate.DiffGraphBuilder -object Commit { - val overlayName: String = "commit" - val description: String = "Apply current custom diffgraph" - def defaultOpts = new CommitOptions(new DiffGraphBuilder) -} +object Commit: + val overlayName: String = "commit" + val description: String = "Apply current custom diffgraph" + def defaultOpts = new CommitOptions(new DiffGraphBuilder) class CommitOptions(var diffGraphBuilder: DiffGraphBuilder) extends LayerCreatorOptions -class Commit(opts: CommitOptions) extends LayerCreator { +class Commit(opts: CommitOptions) extends LayerCreator: - override val overlayName: String = Commit.overlayName - override val description: String = Commit.description - override val storeOverlayName: Boolean = false + override val overlayName: String = Commit.overlayName + override val description: String = Commit.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val pass: CpgPass = new CpgPass(context.cpg) { - override val name = "commit" - override def run(builder: DiffGraphBuilder): Unit = { - builder.absorb(opts.diffGraphBuilder) - } - } - runPass(pass, context, storeUndoInfo) - opts.diffGraphBuilder = new DiffGraphBuilder - } - -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val pass: CpgPass = new CpgPass(context.cpg): + override val name = "commit" + override def run(builder: DiffGraphBuilder): Unit = + builder.absorb(opts.diffGraphBuilder) + runPass(pass, context, storeUndoInfo) + opts.diffGraphBuilder = new DiffGraphBuilder diff --git a/console/src/main/scala/io/appthreat/console/Console.scala b/console/src/main/scala/io/appthreat/console/Console.scala index 86ba2a35..3f9d2c14 100644 --- a/console/src/main/scala/io/appthreat/console/Console.scala +++ 
b/console/src/main/scala/io/appthreat/console/Console.scala @@ -22,64 +22,66 @@ import scala.sys.process.Process import scala.util.control.NoStackTrace import scala.util.{Failure, Success, Try} -class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.currentWorkingDirectory) - extends Reporting { - - import Console.* - - private val _config = new ConsoleConfig() - def config: ConsoleConfig = _config - def console: Console[T] = this - - protected var workspaceManager: WorkspaceManager[T] = _ - switchWorkspace(baseDir.path.resolve("workspace").toString) - protected def workspacePathName: String = workspaceManager.getPath - - private val nameOfCpgInProject = "cpg.bin" - implicit val resolver: ICallResolver = NoResolve - implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder - - implicit val pyGlobal: me.shadaj.scalapy.py.Dynamic.global.type = py.Dynamic.global - var richTableLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richTreeLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richProgressLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richConsole: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richAvailable = true - try { - richTableLib = py.module("rich.table") - richTreeLib = py.module("rich.tree") - richProgressLib = py.module("rich.progress") - richConsole = py.module("chenpy.logger").console - } catch { - case _: Exception => richAvailable = false - } - - implicit object ConsoleImageViewer extends ImageViewer { - def view(imagePathStr: String): Try[String] = { - // We need to copy the file as the original one is only temporary - // and gets removed immediately after running this viewer instance asynchronously via .run(). 
- val tmpFile = File(imagePathStr).copyTo(File.newTemporaryFile(suffix = ".svg"), overwrite = true) - tmpFile.deleteOnExit(swallowIOExceptions = true) - Try { - val command = if (scala.util.Properties.isWin) { Seq("cmd.exe", "/C", config.tools.imageViewer) } - else { Seq(config.tools.imageViewer) } - Process(command :+ tmpFile.path.toAbsolutePath.toString).run() - } match { - case Success(_) => - // We never handle the actual result anywhere. - // Hence, we just pass a success message. - Success(s"Running viewer for '$tmpFile' finished.") - case Failure(exc) => - System.err.println("Executing image viewer failed. Is it installed? ") - System.err.println(exc) - Failure(exc) - } - } - } - - @Doc( - info = "Access to the workspace directory", - longInfo = """ +class Console[T <: Project]( + loader: WorkspaceLoader[T], + baseDir: File = File.currentWorkingDirectory +) extends Reporting: + + import Console.* + + private val _config = new ConsoleConfig() + def config: ConsoleConfig = _config + def console: Console[T] = this + + protected var workspaceManager: WorkspaceManager[T] = _ + switchWorkspace(baseDir.path.resolve("workspace").toString) + protected def workspacePathName: String = workspaceManager.getPath + + private val nameOfCpgInProject = "cpg.bin" + implicit val resolver: ICallResolver = NoResolve + implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder + + implicit val pyGlobal: me.shadaj.scalapy.py.Dynamic.global.type = py.Dynamic.global + var richTableLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richTreeLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richProgressLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richConsole: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richAvailable = true + try + richTableLib = py.module("rich.table") + richTreeLib = py.module("rich.tree") + richProgressLib = py.module("rich.progress") + richConsole = py.module("chenpy.logger").console + catch + case 
_: Exception => richAvailable = false + + implicit object ConsoleImageViewer extends ImageViewer: + def view(imagePathStr: String): Try[String] = + // We need to copy the file as the original one is only temporary + // and gets removed immediately after running this viewer instance asynchronously via .run(). + val tmpFile = + File(imagePathStr).copyTo(File.newTemporaryFile(suffix = ".svg"), overwrite = true) + tmpFile.deleteOnExit(swallowIOExceptions = true) + Try { + val command = if scala.util.Properties.isWin then + Seq("cmd.exe", "/C", config.tools.imageViewer) + else Seq(config.tools.imageViewer) + Process(command :+ tmpFile.path.toAbsolutePath.toString).run() + } match + case Success(_) => + // We never handle the actual result anywhere. + // Hence, we just pass a success message. + Success(s"Running viewer for '$tmpFile' finished.") + case Failure(exc) => + System.err.println("Executing image viewer failed. Is it installed? ") + System.err.println(exc) + Failure(exc) + end view + end ConsoleImageViewer + + @Doc( + info = "Access to the workspace directory", + longInfo = """ |All auditing projects are stored in a workspace directory, and `workspace` |provides programmatic access to this directory. Entering `workspace` provides |a list of all projects, indicating which code the project makes accessible, @@ -108,13 +110,13 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |workspace directory | |""", - example = "workspace" - ) - def workspace: WorkspaceManager[T] = workspaceManager + example = "workspace" + ) + def workspace: WorkspaceManager[T] = workspaceManager - @Doc( - info = "Close current workspace and open a different one", - longInfo = """ | By default, the workspace in $INSTALL_DIR/workspace is used. + @Doc( + info = "Close current workspace and open a different one", + longInfo = """ | By default, the workspace in $INSTALL_DIR/workspace is used. 
| This method allows specifying a different workspace directory | via the `pathName` parameter. | Before changing the workspace, the current workspace will be @@ -122,24 +124,22 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur | If `pathName` points to a non-existing directory, then a new | workspace is first created. |""" - ) - def switchWorkspace(pathName: String): Unit = { - if (workspaceManager != null) { - report("Saving current workspace before changing workspace") - workspaceManager.projects.foreach { p => - p.close - } - } - workspaceManager = new WorkspaceManager[T](pathName, loader) - } - - @Doc(info = "Currently active project", example = "project") - def project: T = - workspace.projectByCpg(cpg).getOrElse(throw new RuntimeException("No active project")) - - @Doc( - info = "CPG of the active project", - longInfo = """ + ) + def switchWorkspace(pathName: String): Unit = + if workspaceManager != null then + report("Saving current workspace before changing workspace") + workspaceManager.projects.foreach { p => + p.close + } + workspaceManager = new WorkspaceManager[T](pathName, loader) + + @Doc(info = "Currently active project", example = "project") + def project: T = + workspace.projectByCpg(cpg).getOrElse(throw new RuntimeException("No active project")) + + @Doc( + info = "CPG of the active project", + longInfo = """ |Upon importing code, a project is created that holds |an intermediate representation called `Code Property Graph`. This |graph is a composition of low-level program representations such @@ -153,34 +153,31 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |`cpg.method.l` lists all methods, while `cpg.finding.l` lists all findings |of potentially vulnerable code. 
|""", - example = "cpg.method.l" - ) - implicit def cpg: Cpg = workspace.cpg - def atom: Cpg = workspace.cpg - - /** All cpgs loaded in the workspace - */ - def cpgs: Iterator[Cpg] = { - if (workspace.projects.lastOption.isEmpty) { - Iterator() - } else { - val activeProjectName = project.name - (workspace.projects.filter(_.cpg.isDefined).iterator.flatMap { project => - open(project.name) - Some(project.cpg) - } ++ Iterator({ open(activeProjectName); None })).flatten - } - } - - // Provide `.l` on iterators, specifically so - // that `cpgs.flatMap($query).l` is possible - implicit class ItExtend[X](it: Iterator[X]) { - def l: List[X] = it.toList - } - - @Doc( - info = "Open project by name", - longInfo = """ + example = "cpg.method.l" + ) + implicit def cpg: Cpg = workspace.cpg + def atom: Cpg = workspace.cpg + + /** All cpgs loaded in the workspace + */ + def cpgs: Iterator[Cpg] = + if workspace.projects.lastOption.isEmpty then + Iterator() + else + val activeProjectName = project.name + (workspace.projects.filter(_.cpg.isDefined).iterator.flatMap { project => + open(project.name) + Some(project.cpg) + } ++ Iterator({ open(activeProjectName); None })).flatten + + // Provide `.l` on iterators, specifically so + // that `cpgs.flatMap($query).l` is possible + implicit class ItExtend[X](it: Iterator[X]): + def l: List[X] = it.toList + + @Doc( + info = "Open project by name", + longInfo = """ |open([projectName]) | |Opens the project named `name` and make it the active project. @@ -192,18 +189,18 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |can be queried via `cpg`. Returns an optional reference to the |project, which is empty on error. 
|""", - example = """open("projectName")""" - ) - def open(name: String): Option[Project] = { - val projectName = fixProjectNameAndComplainOnFix(name) - workspace.openProject(projectName).map { project => - project - } - } - - @Doc( - info = "Open project for input path", - longInfo = """ + example = """open("projectName")""" + ) + def open(name: String): Option[Project] = + val projectName = fixProjectNameAndComplainOnFix(name) + workspace.openProject(projectName).map { project => + project + } + end open + + @Doc( + info = "Open project for input path", + longInfo = """ |openForInputPath([input-path]) | |Opens the project of the CPG generated for the input path `input-path`. @@ -212,79 +209,73 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |can be queried via `cpg`. Returns an optional reference to the |project, which is empty on error. |""" - ) - def openForInputPath(inputPath: String): Option[Project] = { - val absInputPath = File(inputPath).path.toAbsolutePath.toString - workspace.projects - .filter(x => x.inputPath == absInputPath) - .map(_.name) - .map(open) - .headOption - .flatten - } - - /** Open the active project - */ - def open: Option[Project] = { - workspace.projects.lastOption.flatMap { p => - open(p.name) - } - } - - /** Delete project from disk and remove it from the workspace manager. Returns the (now invalid) project. 
- * @param name - * the name of the project - */ - @Doc(info = "Close and remove project from disk", example = "delete(projectName)") - def delete(name: String): Option[Unit] = { - workspaceManager.getActiveProject.foreach(_.cpg.foreach(_.close())) - defaultProjectNameIfEmpty(name).flatMap(workspace.deleteProject) - } - - @Doc(info = "Exit the REPL") - def exit: Unit = { - workspace.projects.foreach(_.close) - System.exit(0) - } - - /** Delete the active project - */ - def delete: Option[Unit] = delete("") - - protected def defaultProjectNameIfEmpty(name: String): Option[String] = { - if (name.isEmpty) { - val projectNameOpt = workspace.projectByCpg(cpg).map(_.name) - if (projectNameOpt.isEmpty) { - report("Fatal: cannot find project for active CPG") - } - projectNameOpt - } else { - Some(fixProjectNameAndComplainOnFix(name)) - } - } - - @Doc( - info = "Write all changes to disk", - longInfo = """ + ) + def openForInputPath(inputPath: String): Option[Project] = + val absInputPath = File(inputPath).path.toAbsolutePath.toString + workspace.projects + .filter(x => x.inputPath == absInputPath) + .map(_.name) + .map(open) + .headOption + .flatten + end openForInputPath + + /** Open the active project + */ + def open: Option[Project] = + workspace.projects.lastOption.flatMap { p => + open(p.name) + } + + /** Delete project from disk and remove it from the workspace manager. Returns the (now invalid) + * project. 
+ * @param name + * the name of the project + */ + @Doc(info = "Close and remove project from disk", example = "delete(projectName)") + def delete(name: String): Option[Unit] = + workspaceManager.getActiveProject.foreach(_.cpg.foreach(_.close())) + defaultProjectNameIfEmpty(name).flatMap(workspace.deleteProject) + + @Doc(info = "Exit the REPL") + def exit: Unit = + workspace.projects.foreach(_.close) + System.exit(0) + + /** Delete the active project + */ + def delete: Option[Unit] = delete("") + + protected def defaultProjectNameIfEmpty(name: String): Option[String] = + if name.isEmpty then + val projectNameOpt = workspace.projectByCpg(cpg).map(_.name) + if projectNameOpt.isEmpty then + report("Fatal: cannot find project for active CPG") + projectNameOpt + else + Some(fixProjectNameAndComplainOnFix(name)) + + @Doc( + info = "Write all changes to disk", + longInfo = """ |Close and reopen all loaded CPGs. This ensures |that changes have been flushed to disk. | |Returns list of affected projects |""", - example = "save" - ) - def save: List[Project] = { - report("Saving graphs on disk. This may take a while.") - workspace.projects.collect { - case p: Project if p.cpg.isDefined => - p.close - workspace.openProject(p.name) - }.flatten - } - - @Doc( - info = "Create new project from code", - longInfo = """ + example = "save" + ) + def save: List[Project] = + report("Saving graphs on disk. This may take a while.") + workspace.projects.collect { + case p: Project if p.cpg.isDefined => + p.close + workspace.openProject(p.name) + }.flatten + + @Doc( + info = "Create new project from code", + longInfo = """ |importCode(, [projectName], [namespaces], [language]) | |Type `importCode` alone to get a list of all supported languages @@ -321,13 +312,13 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |the filename found and possibly by looking into the file/directory. 
| |""", - example = """importCode("git url or path")""" - ) - def importCode = new ImportCode(this) + example = """importCode("git url or path")""" + ) + def importCode = new ImportCode(this) - @Doc( - info = "Create new project from existing atom", - longInfo = """ + @Doc( + info = "Create new project from existing atom", + longInfo = """ |importAtom(, [projectName]) | |Import an existing atom. @@ -341,15 +332,15 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |is omitted, the path is derived from `inputPath` | |""", - example = """importAtom("app.atom")""" - ) - def importAtom(inputPath: String, projectName: String = ""): Unit = { - importCpg(inputPath, projectName, false) - summary - } - @Doc( - info = "Create new project from existing CPG", - longInfo = """ + example = """importAtom("app.atom")""" + ) + def importAtom(inputPath: String, projectName: String = ""): Unit = + importCpg(inputPath, projectName, false) + summary + end importAtom + @Doc( + info = "Create new project from existing CPG", + longInfo = """ |importCpg(, [projectName], [enhance]) | |Import an existing CPG. The CPG is stored as part @@ -368,122 +359,129 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |enhance: run default overlays and post-processing passes. Defaults to `true`. |Pass `enhance=false` to disable the enhancements. |""", - example = """importCpg("app.atom")""" - ) - def importCpg(inputPath: String, projectName: String = "", enhance: Boolean = true): Option[Cpg] = { - val name = - Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath(inputPath, workspace)) - val cpgFile = File(inputPath) - - if (!cpgFile.exists) { - report(s"CPG at $inputPath does not exist. 
Bailing out.") - return None - } - - val pathToProject = workspace.createProject(inputPath, name) - val cpgDestinationPathOpt = pathToProject.map(_.resolve(nameOfCpgInProject)) - - if (cpgDestinationPathOpt.isEmpty) { - report(s"Error creating project for input path: `$inputPath`") - return None - } - - val cpgDestinationPath = cpgDestinationPathOpt.get - - if (CpgLoader.isLegacyCpg(cpgFile)) { - report("You have provided a legacy proto CPG. Attempting conversion.") - try { - CpgConverter.convertProtoCpgToOverflowDb(cpgFile.path.toString, cpgDestinationPath.toString) - } catch { - case exc: Exception => - report("Error converting legacy CPG: " + exc.getMessage) - return None - } - } else { - cpgFile.copyTo(cpgDestinationPath, overwrite = true) - } - - val cpgOpt = open(name).flatMap(_.cpg) - - if (cpgOpt.isEmpty) { - workspace.deleteProject(name) - } - - cpgOpt - .filter(_.metaData.hasNext) - .foreach { cpg => - if (enhance) applyDefaultOverlays(cpg) - applyPostProcessingPasses(cpg) - } - cpgOpt - } - - @Doc( - info = "Close project by name", - longInfo = """|Close project. Resources are freed but the project remains on disk. + example = """importCpg("app.atom")""" + ) + def importCpg( + inputPath: String, + projectName: String = "", + enhance: Boolean = true + ): Option[Cpg] = + val name = + Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( + inputPath, + workspace + )) + val cpgFile = File(inputPath) + + if !cpgFile.exists then + report(s"CPG at $inputPath does not exist. Bailing out.") + return None + + val pathToProject = workspace.createProject(inputPath, name) + val cpgDestinationPathOpt = pathToProject.map(_.resolve(nameOfCpgInProject)) + + if cpgDestinationPathOpt.isEmpty then + report(s"Error creating project for input path: `$inputPath`") + return None + + val cpgDestinationPath = cpgDestinationPathOpt.get + + if CpgLoader.isLegacyCpg(cpgFile) then + report("You have provided a legacy proto CPG. 
Attempting conversion.") + try + CpgConverter.convertProtoCpgToOverflowDb( + cpgFile.path.toString, + cpgDestinationPath.toString + ) + catch + case exc: Exception => + report("Error converting legacy CPG: " + exc.getMessage) + return None + else + cpgFile.copyTo(cpgDestinationPath, overwrite = true) + + val cpgOpt = open(name).flatMap(_.cpg) + + if cpgOpt.isEmpty then + workspace.deleteProject(name) + + cpgOpt + .filter(_.metaData.hasNext) + .foreach { cpg => + if enhance then applyDefaultOverlays(cpg) + applyPostProcessingPasses(cpg) + } + cpgOpt + end importCpg + + @Doc( + info = "Close project by name", + longInfo = """|Close project. Resources are freed but the project remains on disk. |The project remains active, that is, calling `cpg` now raises an |exception. A different project can now be activated using `open`. |""", - example = "close(projectName)" - ) - def close(name: String): Option[Project] = defaultProjectNameIfEmpty(name).flatMap(workspace.closeProject) - - def close: Option[Project] = close("") - - /** Close the project and open it again. - * - * @param name - * the name of the project - */ - def reload(name: String): Option[Project] = { - close(name).flatMap(p => open(p.name)) - } - - @Doc( - info = "Display summary information", - longInfo = """|Displays summary about the loaded atom such as the number of files, methods, annotations etc. + example = "close(projectName)" + ) + def close(name: String): Option[Project] = + defaultProjectNameIfEmpty(name).flatMap(workspace.closeProject) + + def close: Option[Project] = close("") + + /** Close the project and open it again. + * + * @param name + * the name of the project + */ + def reload(name: String): Option[Project] = + close(name).flatMap(p => open(p.name)) + + @Doc( + info = "Display summary information", + longInfo = + """|Displays summary about the loaded atom such as the number of files, methods, annotations etc. |Requires the python modules to be installed. 
|""", - example = "summary" - ) - def summary(as_text: Boolean): String = { - val table = richTableLib.Table(title = "Atom Summary") - table.add_column("Node Type") - table.add_column("Count") - table.add_row("Files", "" + atom.file.size) - table.add_row("Methods", "" + atom.method.size) - table.add_row("Annotations", "" + atom.annotation.size) - table.add_row("Imports", "" + atom.imports.size) - table.add_row("Literals", "" + atom.literal.size) - table.add_row("Config Files", "" + atom.configFile.size) - val appliedOverlays = Overlays.appliedOverlays(atom) - if (appliedOverlays.nonEmpty) table.add_row("Overlays", "" + appliedOverlays.size) - richConsole.clear() - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - def summary: String = summary(as_text = false) - - @Doc( - info = "List files", - longInfo = """|Lists the files from the loaded atom. + example = "summary" + ) + def summary(as_text: Boolean): String = + val table = richTableLib.Table(title = "Atom Summary") + table.add_column("Node Type") + table.add_column("Count") + table.add_row("Files", "" + atom.file.size) + table.add_row("Methods", "" + atom.method.size) + table.add_row("Annotations", "" + atom.annotation.size) + table.add_row("Imports", "" + atom.imports.size) + table.add_row("Literals", "" + atom.literal.size) + table.add_row("Config Files", "" + atom.configFile.size) + val appliedOverlays = Overlays.appliedOverlays(atom) + if appliedOverlays.nonEmpty then table.add_row("Overlays", "" + appliedOverlays.size) + richConsole.clear() + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end summary + def summary: String = summary(as_text = false) + + @Doc( + info = "List files", + longInfo = """|Lists the files from the loaded atom. |Requires the python modules to be installed. 
|""", - example = "files" - ) - def files(title: String = "Files", as_text: Boolean): String = { - val table = richTableLib.Table(title = title, highlight = true) - table.add_column("File Name") - table.add_column("Method Count") - atom.file.whereNot(_.name("")).foreach { f => table.add_row(f.name, "" + f.method.size) } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - def files: String = files("Files", as_text = false) - - @Doc( - info = "List methods", - longInfo = """|Lists the methods by files from the loaded atom. + example = "files" + ) + def files(title: String = "Files", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true) + table.add_column("File Name") + table.add_column("Method Count") + atom.file.whereNot(_.name("")).foreach { f => + table.add_row(f.name, "" + f.method.size) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + def files: String = files("Files", as_text = false) + + @Doc( + info = "List methods", + longInfo = """|Lists the methods by files from the loaded atom. |Requires the python modules to be installed. | |Parameters: @@ -491,323 +489,330 @@ class Console[T <: Project](loader: WorkspaceLoader[T], baseDir: File = File.cur |title: Title for the table. Default Methods. 
|tree: Display as a tree instead of table |""", - example = "methods('Methods', includeCalls=true, tree=true)" - ) - def methods( - title: String = "Methods", - includeCalls: Boolean = false, - tree: Boolean = false, - as_text: Boolean = false - ): String = { - if (tree) { - val rootTree = richTreeLib.Tree(title, highlight = true) - atom.file.whereNot(_.name("")).foreach { f => - val childTree = richTreeLib.Tree(f.name, highlight = true) - f.method.foreach(m => { - val mtree = childTree.add(m.fullName) - if (includeCalls) - m.call - .filterNot(_.name.startsWith(" - mtree - .add( - c.methodFullName + (if (c.callee(NoResolve).head.isExternal) " :right_arrow_curving_up:" else "") - ) - ) - }) - rootTree.add(childTree) - } - richConsole.print(rootTree) - if (as_text) richConsole.export_text().as[String] else "" - } else { - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Methods") - atom.file.whereNot(_.name("")).foreach { f => - table.add_row(f.name, f.method.fullName.l.mkString("\n")) - } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - } - def methods: String = methods("Methods", as_text = false) - - @Doc( - info = "List annotations", - longInfo = """|Lists the method annotations by files from the loaded atom. 
+ example = "methods('Methods', includeCalls=true, tree=true)" + ) + def methods( + title: String = "Methods", + includeCalls: Boolean = false, + tree: Boolean = false, + as_text: Boolean = false + ): String = + if tree then + val rootTree = richTreeLib.Tree(title, highlight = true) + atom.file.whereNot(_.name("")).foreach { f => + val childTree = richTreeLib.Tree(f.name, highlight = true) + f.method.foreach(m => + val mtree = childTree.add(m.fullName) + if includeCalls then + m.call + .filterNot(_.name.startsWith(" + mtree + .add( + c.methodFullName + (if c.callee(NoResolve).head.isExternal + then " :right_arrow_curving_up:" + else "") + ) + ) + ) + rootTree.add(childTree) + } + richConsole.print(rootTree) + if as_text then richConsole.export_text().as[String] else "" + else + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Methods") + atom.file.whereNot(_.name("")).foreach { f => + table.add_row(f.name, f.method.fullName.l.mkString("\n")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + def methods: String = methods("Methods", as_text = false) + + @Doc( + info = "List annotations", + longInfo = """|Lists the method annotations by files from the loaded atom. |Requires the python modules to be installed. 
|""", - example = "annotations" - ) - def annotations(title: String = "Annotations", as_text: Boolean): String = { - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Methods") - table.add_column("Annotations") - atom.file.whereNot(_.name("")).method.filter(_.annotation.nonEmpty).foreach { m => - table.add_row(m.location.filename, m.fullName, m.annotation.fullName.l.mkString("\n")) - } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - - def annotations: String = annotations("Annotations", as_text = false) - - @Doc( - info = "List imports", - longInfo = """|Lists the imports by files from the loaded atom. + example = "annotations" + ) + def annotations(title: String = "Annotations", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Methods") + table.add_column("Annotations") + atom.file.whereNot(_.name("")).method.filter(_.annotation.nonEmpty).foreach { m => + table.add_row(m.location.filename, m.fullName, m.annotation.fullName.l.mkString("\n")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + + def annotations: String = annotations("Annotations", as_text = false) + + @Doc( + info = "List imports", + longInfo = """|Lists the imports by files from the loaded atom. |Requires the python modules to be installed. 
|""", - example = "imports" - ) - def imports(title: String = "Imports", as_text: Boolean): String = { - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Import") - atom.imports.foreach { i => - table.add_row(i.file.name.l.mkString("\n"), i.importedEntity.getOrElse("")) - } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - - def imports: String = imports("Imports", as_text = false) - - @Doc( - info = "List declarations", - longInfo = """|Lists the declarations by files from the loaded atom. + example = "imports" + ) + def imports(title: String = "Imports", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Import") + atom.imports.foreach { i => + table.add_row(i.file.name.l.mkString("\n"), i.importedEntity.getOrElse("")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + + def imports: String = imports("Imports", as_text = false) + + @Doc( + info = "List declarations", + longInfo = """|Lists the declarations by files from the loaded atom. |Requires the python modules to be installed. 
|""", - example = "declarations" - ) - def declarations(title: String = "Declarations", as_text: Boolean): String = { - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Declarations") - atom.file.whereNot(_.name("")).foreach { f => - val dec: Set[Declaration] = - (f.assignment.argument(1).filterNot(_.code == "this").isIdentifier.refsTo ++ f.method.parameter - .filterNot(_.code == "this") - .filter(_.typeFullName != "ANY")).toSet - table.add_row(f.name, dec.name.toSet.mkString("\n")) - } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - - def declarations: String = declarations("Declarations", as_text = false) - - @Doc( - info = "List sensitive literals", - longInfo = """|Lists the sensitive literals by files from the loaded atom. + example = "declarations" + ) + def declarations(title: String = "Declarations", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Declarations") + atom.file.whereNot(_.name("")).foreach { f => + val dec: Set[Declaration] = + (f.assignment.argument(1).filterNot( + _.code == "this" + ).isIdentifier.refsTo ++ f.method.parameter + .filterNot(_.code == "this") + .filter(_.typeFullName != "ANY")).toSet + table.add_row(f.name, dec.name.toSet.mkString("\n")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end declarations + + def declarations: String = declarations("Declarations", as_text = false) + + @Doc( + info = "List sensitive literals", + longInfo = """|Lists the sensitive literals by files from the loaded atom. |Requires the python modules to be installed. 
|""", - example = "sensitive" - ) - def sensitive( - title: String = "Sensitive Literals", - pattern: String = "(secret|password|token|key|admin|root)", - as_text: Boolean = false - ): String = { - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Sensitive Literals") - atom.file.whereNot(_.name("")).foreach { f => - val slits: Set[Literal] = - f.assignment.where(_.argument.order(1).code(s"(?i).*${pattern}.*")).argument.order(2).isLiteral.toSet - table.add_row(f.name, if (slits.nonEmpty) slits.code.mkString("\n") else "N/A") - } - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - - def sensitive: String = sensitive("Sensitive Literals", as_text = false) - - @Doc( - info = "Show graph edit distance from the source method to the comparison methods", - longInfo = """|Compute graph edit distance from the source method to the comparison methods. - |""", - example = "distance(source method iterator, comparison method iterators)" - ) - def distance(sourceTrav: Iterator[Method], sourceTravs: Iterator[Method]*): Seq[Double] = { - val first_method = new MethodTraversal(sourceTrav.iterator).gml - sourceTravs.map { compareTrav => - val second_method = new MethodTraversal(compareTrav.iterator).gml - Torch.edit_distance(first_method, second_method) - } - } - - case class MethodDistance(filename: String, fullName: String, editDistance: Double) - - @Doc( - info = "Show methods similar to the given method", - longInfo = """|List methods to similar to the one based on graph edit distance. 
+ example = "sensitive" + ) + def sensitive( + title: String = "Sensitive Literals", + pattern: String = "(secret|password|token|key|admin|root)", + as_text: Boolean = false + ): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Sensitive Literals") + atom.file.whereNot(_.name("")).foreach { f => + val slits: Set[Literal] = + f.assignment.where(_.argument.order(1).code(s"(?i).*${pattern}.*")).argument.order( + 2 + ).isLiteral.toSet + table.add_row(f.name, if slits.nonEmpty then slits.code.mkString("\n") else "N/A") + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end sensitive + + def sensitive: String = sensitive("Sensitive Literals", as_text = false) + + @Doc( + info = "Show graph edit distance from the source method to the comparison methods", + longInfo = """|Compute graph edit distance from the source method to the comparison methods. |""", - example = "showSimilar(method full name)" - ) - def showSimilar( - methodFullName: String, - comparePattern: String = "", - upper_bound: Int = 500, - timeout: Int = 5, - as_text: Boolean = false - ): String = { - val table = - richTableLib.Table(title = s"Similarity analysis for `${methodFullName}`", highlight = true, show_lines = true) - val progress = richProgressLib.Progress(transient = true) - val first_method = atom.method.fullNameExact(methodFullName).gml - table.add_column("File Name") - table.add_column("Method Name") - table.add_column("Edit Distance") - val methodDistances = mutable.ArrayBuffer[MethodDistance]() - val base = if (comparePattern.nonEmpty) atom.method.fullName(s".*${comparePattern}.*") else atom.method.internal - base.whereNot(_.fullNameExact(methodFullName)).foreach { method => - py.`with`(progress) { mprogress => - val task = mprogress.add_task(s"Analyzing ${method.fullName}", start = false, total = 100) - val edit_distance = - 
Torch.edit_distance(first_method, method.iterator.gml, upper_bound = upper_bound, timeout = timeout) - if (edit_distance != -1) { - methodDistances += MethodDistance(method.location.filename, method.fullName, edit_distance) + example = "distance(source method iterator, comparison method iterators)" + ) + def distance(sourceTrav: Iterator[Method], sourceTravs: Iterator[Method]*): Seq[Double] = + val first_method = new MethodTraversal(sourceTrav.iterator).gml + sourceTravs.map { compareTrav => + val second_method = new MethodTraversal(compareTrav.iterator).gml + Torch.edit_distance(first_method, second_method) } - mprogress.stop_task(task) - mprogress.update(task, completed = 100) - } - } - methodDistances.sortInPlaceBy[Double](x => x.editDistance) - methodDistances.foreach(row => table.add_row(row.filename, row.fullName, "" + row.editDistance)) - richConsole.print(table) - if (as_text) richConsole.export_text().as[String] else "" - } - - def printDashes(count: Int) = { - var tabStr = "+--- " - var i = 0 - while (i < count) { - tabStr = "| " + tabStr - i += 1 - } - tabStr - } - - @Doc( - info = "Show call tree for the given method", - longInfo = """|Show the call tree for the given method. + + case class MethodDistance(filename: String, fullName: String, editDistance: Double) + + @Doc( + info = "Show methods similar to the given method", + longInfo = """|List methods to similar to the one based on graph edit distance. 
|""", - example = "callTree(method full name)" - ) - def callTree(callerFullName: String, tree: ListBuffer[String] = new ListBuffer[String](), depth: Int = 3)(implicit - atom: Cpg - ): ListBuffer[String] = { - var dashCount = 0 - var lastCallerMethod = callerFullName - var lastDashCount = 0 - tree += callerFullName - - def findCallee(methodName: String, tree: ListBuffer[String]): ListBuffer[String] = { - val calleeList = atom.method.fullNameExact(methodName).callee.whereNot(_.name(".* - tree += s"${printDashes(dashCount)}${c.fullName}~~${c.location.filename}#${c.lineNumber.getOrElse(0)}" - findCallee(c.fullName, tree) + example = "showSimilar(method full name)" + ) + def showSimilar( + methodFullName: String, + comparePattern: String = "", + upper_bound: Int = 500, + timeout: Int = 5, + as_text: Boolean = false + ): String = + val table = + richTableLib.Table( + title = s"Similarity analysis for `${methodFullName}`", + highlight = true, + show_lines = true + ) + val progress = richProgressLib.Progress(transient = true) + val first_method = atom.method.fullNameExact(methodFullName).gml + table.add_column("File Name") + table.add_column("Method Name") + table.add_column("Edit Distance") + val methodDistances = mutable.ArrayBuffer[MethodDistance]() + val base = if comparePattern.nonEmpty then atom.method.fullName(s".*${comparePattern}.*") + else atom.method.internal + base.whereNot(_.fullNameExact(methodFullName)).foreach { method => + py.`with`(progress) { mprogress => + val task = + mprogress.add_task(s"Analyzing ${method.fullName}", start = false, total = 100) + val edit_distance = + Torch.edit_distance( + first_method, + method.iterator.gml, + upper_bound = upper_bound, + timeout = timeout + ) + if edit_distance != -1 then + methodDistances += MethodDistance( + method.location.filename, + method.fullName, + edit_distance + ) + mprogress.stop_task(task) + mprogress.update(task, completed = 100) + } } - } - tree - } - - findCallee(lastCallerMethod, tree) - } - - def 
applyPostProcessingPasses(cpg: Cpg): Cpg = { - new CpgGeneratorFactory(_config).forLanguage(cpg.metaData.language.l.head) match { - case Some(frontend) => frontend.applyPostProcessingPasses(cpg) - case None => cpg - } - } - - def applyDefaultOverlays(cpg: Cpg): Cpg = { - val appliedOverlays = Overlays.appliedOverlays(cpg) - if (appliedOverlays.isEmpty) { - report("Adding default overlays to base CPG") - _runAnalyzer(defaultOverlayCreators(): _*) - } - cpg - } - - def _runAnalyzer(overlayCreators: LayerCreator*): Cpg = { - - overlayCreators.foreach { creator => - val overlayDirName = - workspace.getNextOverlayDirName(cpg, creator.overlayName) - - val projectOpt = workspace.projectByCpg(cpg) - if (projectOpt.isEmpty) { - throw new RuntimeException("No record for atom. Please use `importCode`/`importAtom/open`") - } - - if (projectOpt.get.appliedOverlays.contains(creator.overlayName)) { - report(s"Overlay ${creator.overlayName} already exists - skipping") - } else { - File(overlayDirName).createDirectories() - runCreator(creator, Some(overlayDirName)) - } - } - report( - "The graph has been modified. You may want to use the `save` command to persist changes to disk. All changes will also be saved collectively on exit" + methodDistances.sortInPlaceBy[Double](x => x.editDistance) + methodDistances.foreach(row => + table.add_row(row.filename, row.fullName, "" + row.editDistance) + ) + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end showSimilar + + def printDashes(count: Int) = + var tabStr = "+--- " + var i = 0 + while i < count do + tabStr = "| " + tabStr + i += 1 + tabStr + + @Doc( + info = "Show call tree for the given method", + longInfo = """|Show the call tree for the given method. 
+ |""", + example = "callTree(method full name)" ) - cpg - } - - protected def runCreator(creator: LayerCreator, overlayDirName: Option[String]): Unit = { - val context = new LayerCreatorContext(cpg, overlayDirName) - creator.run(context, storeUndoInfo = true) - } - - // We still tie the project name to the input path here - // if no project name has been provided. - - def fixProjectNameAndComplainOnFix(name: String): String = { - val projectName = Some(name) - .filter(_.contains(java.io.File.separator)) - .map(x => deriveNameFromInputPath(x, workspace)) - .getOrElse(name) - if (name != projectName) { - System.err.println("Passing paths to `loadCpg` is deprecated, please use a project name") - } - projectName - } - -} - -object Console { - val nameOfLegacyCpgInProject = "app.atom" - - def deriveNameFromInputPath[T <: Project](inputPath: String, workspace: WorkspaceManager[T]): String = { - val name = File(inputPath).name - val project = workspace.project(name) - if (project.isDefined && project.exists(_.inputPath != inputPath)) { - var i = 1 - while (workspace.project(name + i).isDefined) { - i += 1 - } - name + i - } else { - name - } - } -} + def callTree( + callerFullName: String, + tree: ListBuffer[String] = new ListBuffer[String](), + depth: Int = 3 + )(implicit atom: Cpg): ListBuffer[String] = + var dashCount = 0 + var lastCallerMethod = callerFullName + var lastDashCount = 0 + tree += callerFullName + + def findCallee(methodName: String, tree: ListBuffer[String]): ListBuffer[String] = + val calleeList = + atom.method.fullNameExact(methodName).callee.whereNot(_.name(".* + tree += s"${printDashes(dashCount)}${c.fullName}~~${c.location.filename}#${c.lineNumber.getOrElse(0)}" + findCallee(c.fullName, tree) + } + tree + + findCallee(lastCallerMethod, tree) + end callTree + + def applyPostProcessingPasses(cpg: Cpg): Cpg = + new CpgGeneratorFactory(_config).forLanguage(cpg.metaData.language.l.head) match + case Some(frontend) => 
frontend.applyPostProcessingPasses(cpg) + case None => cpg + + def applyDefaultOverlays(cpg: Cpg): Cpg = + val appliedOverlays = Overlays.appliedOverlays(cpg) + if appliedOverlays.isEmpty then + report("Adding default overlays to base CPG") + _runAnalyzer(defaultOverlayCreators()*) + cpg + + def _runAnalyzer(overlayCreators: LayerCreator*): Cpg = + + overlayCreators.foreach { creator => + val overlayDirName = + workspace.getNextOverlayDirName(cpg, creator.overlayName) + + val projectOpt = workspace.projectByCpg(cpg) + if projectOpt.isEmpty then + throw new RuntimeException( + "No record for atom. Please use `importCode`/`importAtom/open`" + ) + + if projectOpt.get.appliedOverlays.contains(creator.overlayName) then + report(s"Overlay ${creator.overlayName} already exists - skipping") + else + File(overlayDirName).createDirectories() + runCreator(creator, Some(overlayDirName)) + } + report( + "The graph has been modified. You may want to use the `save` command to persist changes to disk. All changes will also be saved collectively on exit" + ) + cpg + end _runAnalyzer + + protected def runCreator(creator: LayerCreator, overlayDirName: Option[String]): Unit = + val context = new LayerCreatorContext(cpg, overlayDirName) + creator.run(context, storeUndoInfo = true) + + // We still tie the project name to the input path here + // if no project name has been provided. 
+ + def fixProjectNameAndComplainOnFix(name: String): String = + val projectName = Some(name) + .filter(_.contains(java.io.File.separator)) + .map(x => deriveNameFromInputPath(x, workspace)) + .getOrElse(name) + if name != projectName then + System.err.println( + "Passing paths to `loadCpg` is deprecated, please use a project name" + ) + projectName +end Console + +object Console: + val nameOfLegacyCpgInProject = "app.atom" + + def deriveNameFromInputPath[T <: Project]( + inputPath: String, + workspace: WorkspaceManager[T] + ): String = + val name = File(inputPath).name + val project = workspace.project(name) + if project.isDefined && project.exists(_.inputPath != inputPath) then + var i = 1 + while workspace.project(name + i).isDefined do + i += 1 + name + i + else + name class ConsoleException(message: String, cause: Option[Throwable]) extends RuntimeException(message, cause.orNull) - with NoStackTrace { - def this(message: String) = this(message, None) - def this(message: String, cause: Throwable) = this(message, Option(cause)) -} + with NoStackTrace: + def this(message: String) = this(message, None) + def this(message: String, cause: Throwable) = this(message, Option(cause)) diff --git a/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala b/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala index 9fc3ebe2..f7ee467a 100644 --- a/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala +++ b/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala @@ -1,6 +1,6 @@ package io.appthreat.console -import better.files._ +import better.files.* import scala.annotation.tailrec import scala.collection.mutable @@ -10,49 +10,52 @@ import scala.collection.mutable * @param environment * A map of system environment variables. 
*/ -class InstallConfig(environment: Map[String, String] = sys.env) { +class InstallConfig(environment: Map[String, String] = sys.env): - /** determining the root path of the joern/ocular installation is rather complex unfortunately, because we support a - * variety of use cases: - * - running the installed distribution from the install dir - * - running the installed distribution anywhere else on the system - * - running a locally staged ocular/joern build (via `sbt stage` and then either `./joern` or `cd - * platform/target/universal/stage; ./joern`) - * - running a unit/integration test (note: the jars would be in the local cache, e.g. in ~/.coursier/cache) - */ - lazy val rootPath: File = { - if (environment.contains("CHEN_INSTALL_DIR")) { - environment("CHEN_INSTALL_DIR").toFile - } else { - val uriToLibDir = classOf[InstallConfig].getProtectionDomain.getCodeSource.getLocation.toURI - val pathToLibDir = File(uriToLibDir).parent - findRootDirectory(pathToLibDir).getOrElse { - val cwd = File.currentWorkingDirectory - findRootDirectory(cwd).getOrElse(throw new AssertionError(s"""unable to find root installation directory + /** determining the root path of the joern/ocular installation is rather complex unfortunately, + * because we support a variety of use cases: + * - running the installed distribution from the install dir + * - running the installed distribution anywhere else on the system + * - running a locally staged ocular/joern build (via `sbt stage` and then either `./joern` + * or `cd platform/target/universal/stage; ./joern`) + * - running a unit/integration test (note: the jars would be in the local cache, e.g. 
in + * ~/.coursier/cache) + */ + lazy val rootPath: File = + if environment.contains("CHEN_INSTALL_DIR") then + environment("CHEN_INSTALL_DIR").toFile + else + val uriToLibDir = + classOf[InstallConfig].getProtectionDomain.getCodeSource.getLocation.toURI + val pathToLibDir = File(uriToLibDir).parent + findRootDirectory(pathToLibDir).getOrElse { + val cwd = File.currentWorkingDirectory + findRootDirectory(cwd).getOrElse( + throw new AssertionError(s"""unable to find root installation directory | context: tried to find marker file `$rootDirectoryMarkerFilename` | started search in both $pathToLibDir and $cwd and searched - | $maxSearchDepth directories upwards""".stripMargin)) - } - } - } + | $maxSearchDepth directories upwards""".stripMargin) + ) + } - private val rootDirectoryMarkerFilename = ".installation_root" - private val maxSearchDepth = 10 + private val rootDirectoryMarkerFilename = ".installation_root" + private val maxSearchDepth = 10 - @tailrec - private def findRootDirectory(currentSearchDir: File, currentSearchDepth: Int = 0): Option[File] = { - if (currentSearchDir.list.map(_.name).contains(rootDirectoryMarkerFilename)) - Some(currentSearchDir) - else if (currentSearchDepth < maxSearchDepth && currentSearchDir.parentOption.isDefined) - findRootDirectory(currentSearchDir.parent) - else - None - } -} + @tailrec + private def findRootDirectory( + currentSearchDir: File, + currentSearchDepth: Int = 0 + ): Option[File] = + if currentSearchDir.list.map(_.name).contains(rootDirectoryMarkerFilename) then + Some(currentSearchDir) + else if currentSearchDepth < maxSearchDepth && currentSearchDir.parentOption.isDefined then + findRootDirectory(currentSearchDir.parent) + else + None +end InstallConfig -object InstallConfig { - def apply(): InstallConfig = new InstallConfig() -} +object InstallConfig: + def apply(): InstallConfig = new InstallConfig() class ConsoleConfig( val install: InstallConfig = InstallConfig(), @@ -60,25 +63,20 @@ class ConsoleConfig( val 
tools: ToolsConfig = ToolsConfig() ) {} -object ToolsConfig { +object ToolsConfig: - private val osSpecificOpenCmd: String = { - if (scala.util.Properties.isWin) "start" - else if (scala.util.Properties.isMac) "open" - else "xdg-open" - } + private val osSpecificOpenCmd: String = + if scala.util.Properties.isWin then "start" + else if scala.util.Properties.isMac then "open" + else "xdg-open" - def apply(): ToolsConfig = new ToolsConfig() -} + def apply(): ToolsConfig = new ToolsConfig() class ToolsConfig(var imageViewer: String = ToolsConfig.osSpecificOpenCmd) -class FrontendConfig(var cmdLineParams: Iterable[String] = mutable.Buffer()) { - def withArgs(args: Iterable[String]): FrontendConfig = { - new FrontendConfig(cmdLineParams ++ args) - } -} +class FrontendConfig(var cmdLineParams: Iterable[String] = mutable.Buffer()): + def withArgs(args: Iterable[String]): FrontendConfig = + new FrontendConfig(cmdLineParams ++ args) -object FrontendConfig { - def apply(): FrontendConfig = new FrontendConfig() -} +object FrontendConfig: + def apply(): FrontendConfig = new FrontendConfig() diff --git a/console/src/main/scala/io/appthreat/console/CpgConverter.scala b/console/src/main/scala/io/appthreat/console/CpgConverter.scala index cbe67dd5..d6ed755c 100644 --- a/console/src/main/scala/io/appthreat/console/CpgConverter.scala +++ b/console/src/main/scala/io/appthreat/console/CpgConverter.scala @@ -3,12 +3,10 @@ package io.appthreat.console import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, CpgLoaderConfig} import overflowdb.Config -object CpgConverter { +object CpgConverter: - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = { - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close - } - -} + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + val 
odbConfig = Config.withDefaults.withStorageLocation(dstFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + CpgLoader.load(srcFilename, config).close diff --git a/console/src/main/scala/io/appthreat/console/Help.scala b/console/src/main/scala/io/appthreat/console/Help.scala index 777b52be..de8290d3 100644 --- a/console/src/main/scala/io/appthreat/console/Help.scala +++ b/console/src/main/scala/io/appthreat/console/Help.scala @@ -1,23 +1,23 @@ package io.appthreat.console import org.apache.commons.lang.WordUtils -import overflowdb.traversal.help.DocFinder._ +import overflowdb.traversal.help.DocFinder.* import overflowdb.traversal.help.{Table, DocFinder} -object Help { +object Help: - private val width = 80 + private val width = 80 - def overview(clazz: Class[_]): String = { - val columnNames = List("command", "description", "example") - val rows = DocFinder - .findDocumentedMethodsOf(clazz) - .map { case StepDoc(_, funcName, doc) => - List(funcName, doc.info, doc.example) - } - .toList ++ List(runRow) + def overview(clazz: Class[?]): String = + val columnNames = List("command", "description", "example") + val rows = DocFinder + .findDocumentedMethodsOf(clazz) + .map { case StepDoc(_, funcName, doc) => + List(funcName, doc.info, doc.example) + } + .toList ++ List(runRow) - val header = formatNoQuotes(""" + val header = formatNoQuotes(""" | |Welcome to the interactive help system. Below you find |a table of all available top-level commands. 
To get @@ -29,42 +29,40 @@ object Help { | | |""".stripMargin) - header + "\n" + Table(columnNames, rows.sortBy(_.head)).render - } + header + "\n" + Table(columnNames, rows.sortBy(_.head)).render + end overview - def format(text: String): String = { - "\"\"\"" + "\n" + formatNoQuotes(text) + "\"\"\"" - } + def format(text: String): String = + "\"\"\"" + "\n" + formatNoQuotes(text) + "\"\"\"" - def formatNoQuotes(text: String): String = { - text.stripMargin - .split("\n\n") - .map(x => WordUtils.wrap(x.replace("\n", " "), width)) - .mkString("\n\n") - .trim - } + def formatNoQuotes(text: String): String = + text.stripMargin + .split("\n\n") + .map(x => WordUtils.wrap(x.replace("\n", " "), width)) + .mkString("\n\n") + .trim - private def runRow: List[String] = - List("run", "Run analyzer on active CPG", "run.securityprofile") + private def runRow: List[String] = + List("run", "Run analyzer on active CPG", "run.securityprofile") - // Since `run` is generated dynamically, it's not picked up when looking - // through methods via reflection, and therefore, we are adding - // it manually. - def runLongHelp: String = - Help.format(""" + // Since `run` is generated dynamically, it's not picked up when looking + // through methods via reflection, and therefore, we are adding + // it manually. 
+ def runLongHelp: String = + Help.format(""" | |""".stripMargin) - def codeForHelpCommand(clazz: Class[_]): String = { - val membersCode = DocFinder - .findDocumentedMethodsOf(clazz) - .map { case StepDoc(_, funcName, doc) => - s" val $funcName: String = ${Help.format(doc.longInfo)}" - } - .mkString("\n") + def codeForHelpCommand(clazz: Class[?]): String = + val membersCode = DocFinder + .findDocumentedMethodsOf(clazz) + .map { case StepDoc(_, funcName, doc) => + s" val $funcName: String = ${Help.format(doc.longInfo)}" + } + .mkString("\n") - val overview = Help.overview(clazz) - s""" + val overview = Help.overview(clazz) + s""" | class Helper() { | def run: String = Help.runLongHelp | override def toString: String = \"\"\"$overview\"\"\" @@ -74,6 +72,5 @@ object Help { | | val help = new Helper |""".stripMargin - } - -} + end codeForHelpCommand +end Help diff --git a/console/src/main/scala/io/appthreat/console/JProduct.scala b/console/src/main/scala/io/appthreat/console/JProduct.scala index fd63242c..ab5765d2 100644 --- a/console/src/main/scala/io/appthreat/console/JProduct.scala +++ b/console/src/main/scala/io/appthreat/console/JProduct.scala @@ -1,4 +1,6 @@ package io.appthreat.console -sealed trait JProduct { def name: String } -case object ChenProduct extends JProduct { val name: String = "chen" } +sealed trait JProduct: + def name: String +case object ChenProduct extends JProduct: + val name: String = "chen" diff --git a/console/src/main/scala/io/appthreat/console/PluginManager.scala b/console/src/main/scala/io/appthreat/console/PluginManager.scala index f661b447..1b475c6a 100644 --- a/console/src/main/scala/io/appthreat/console/PluginManager.scala +++ b/console/src/main/scala/io/appthreat/console/PluginManager.scala @@ -1,6 +1,6 @@ package io.appthreat.console -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File import better.files.File.apply @@ -9,101 +9,88 @@ import scala.util.{Failure, Success, Try} /** Plugin management 
component * - * Joern allows plugins to be installed. A plugin at the very least consists of a class that inherits from - * `LayerCreator`, bundled in a jar file, packaged in a zip file. The zip file may furthermore contain any dependency - * jars that the plugin requires and that are not included on the joern class path by default. + * Joern allows plugins to be installed. A plugin at the very least consists of a class that + * inherits from `LayerCreator`, bundled in a jar file, packaged in a zip file. The zip file may + * furthermore contain any dependency jars that the plugin requires and that are not included on + * the joern class path by default. * * @param installDir * the Joern/Ocular installation dir */ -class PluginManager(val installDir: File) { +class PluginManager(val installDir: File): - /** Generate a sorted list of all installed plugins by examining the plugin directory. - */ - def listPlugins(): List[String] = { - val installedPluginNames = pluginDir.toList - .flatMap { dir => - File(dir).list.toList.flatMap { f => - "^joernext-(.*?)-.*$".r.findAllIn(f.name).matchData.map { m => - m.group(1) - } - } - } - .distinct - .sorted - installedPluginNames - } + /** Generate a sorted list of all installed plugins by examining the plugin directory. + */ + def listPlugins(): List[String] = + val installedPluginNames = pluginDir.toList + .flatMap { dir => + File(dir).list.toList.flatMap { f => + "^joernext-(.*?)-.*$".r.findAllIn(f.name).matchData.map { m => + m.group(1) + } + } + } + .distinct + .sorted + installedPluginNames - /** Install the plugin stored at `filename`. The plugin is expected to be a zip file containing Java archives (.jar - * files). - */ - def add(filename: String): Unit = { - if (pluginDir.isEmpty) { - println("Plugin directory does not exist") - return - } - val file = File(filename) - if (!file.exists) { - println(s"The file $filename does not exist") - } else { - addExisting(file) - } - } + /** Install the plugin stored at `filename`. 
The plugin is expected to be a zip file containing + * Java archives (.jar files). + */ + def add(filename: String): Unit = + if pluginDir.isEmpty then + println("Plugin directory does not exist") + return + val file = File(filename) + if !file.exists then + println(s"The file $filename does not exist") + else + addExisting(file) - private def addExisting(file: File): Unit = { - val pluginName = file.name.replace(".zip", "") - val tmpDir = extractToTemporaryDir(file) - tmpDir.foreach(dir => addExistingUnzipped(dir, pluginName)) - } + private def addExisting(file: File): Unit = + val pluginName = file.name.replace(".zip", "") + val tmpDir = extractToTemporaryDir(file) + tmpDir.foreach(dir => addExistingUnzipped(dir, pluginName)) - private def addExistingUnzipped(file: File, pluginName: String): Unit = { - file.listRecursively.filter(_.name.endsWith(".jar")).foreach { jar => - pluginDir.foreach { pDir => - if (!(pDir / jar.name).exists) { - val dstFileName = s"joernext-$pluginName-${jar.name}" - val dstFile = pDir / dstFileName - cp(jar, dstFile) + private def addExistingUnzipped(file: File, pluginName: String): Unit = + file.listRecursively.filter(_.name.endsWith(".jar")).foreach { jar => + pluginDir.foreach { pDir => + if !(pDir / jar.name).exists then + val dstFileName = s"joernext-$pluginName-${jar.name}" + val dstFile = pDir / dstFileName + cp(jar, dstFile) + } } - } - } - } - - private def extractToTemporaryDir(file: File) = { - Try { file.unzip() } match { - case Success(dir) => - Some(dir) - case Failure(exc) => - println("Error reading zip: " + exc.getMessage) - None - } - } - /** Delete plugin with given `name` from the plugin directory. 
- */ - def rm(name: String): List[String] = { - if (!listPlugins().contains(name)) { - List() - } else { - val filesToRemove = pluginDir.toList.flatMap { dir => - dir.list.filter { f => - f.name.startsWith(s"joernext-$name") - } - } - filesToRemove.foreach(f => f.delete()) - filesToRemove.map(_.pathAsString) - } - } + private def extractToTemporaryDir(file: File) = + Try { file.unzip() } match + case Success(dir) => + Some(dir) + case Failure(exc) => + println("Error reading zip: " + exc.getMessage) + None - /** Return the path to the plugin directory or None if the plugin directory does not exist. - */ - def pluginDir: Option[Path] = { - val pathToPluginDir = installDir.path.resolve("lib") - if (pathToPluginDir.toFile.exists()) { - Some(pathToPluginDir) - } else { - println(s"Plugin directory at $pathToPluginDir does not exist") - None - } - } + /** Delete plugin with given `name` from the plugin directory. + */ + def rm(name: String): List[String] = + if !listPlugins().contains(name) then + List() + else + val filesToRemove = pluginDir.toList.flatMap { dir => + dir.list.filter { f => + f.name.startsWith(s"joernext-$name") + } + } + filesToRemove.foreach(f => f.delete()) + filesToRemove.map(_.pathAsString) -} + /** Return the path to the plugin directory or None if the plugin directory does not exist. 
+ */ + def pluginDir: Option[Path] = + val pathToPluginDir = installDir.path.resolve("lib") + if pathToPluginDir.toFile.exists() then + Some(pathToPluginDir) + else + println(s"Plugin directory at $pathToPluginDir does not exist") + None +end PluginManager diff --git a/console/src/main/scala/io/appthreat/console/Reporting.scala b/console/src/main/scala/io/appthreat/console/Reporting.scala index ec379cba..08a15d2a 100644 --- a/console/src/main/scala/io/appthreat/console/Reporting.scala +++ b/console/src/main/scala/io/appthreat/console/Reporting.scala @@ -3,43 +3,40 @@ package io.appthreat.console import java.io.OutputStream import scala.collection.mutable -trait Reporting { +trait Reporting: - def reportOutStream: OutputStream = System.err + def reportOutStream: OutputStream = System.err - def report(string: String): Unit = { - reportOutStream.write((string + "\n").getBytes("UTF-8")) - GlobalReporting.appendToGlobalStdOut(string) - } -} + def report(string: String): Unit = + reportOutStream.write((string + "\n").getBytes("UTF-8")) + GlobalReporting.appendToGlobalStdOut(string) -/** A dirty hack to capture the reported output for the server-mode. Context: server mode is a bit tricky, because the - * reporting happens inside the repl, but we want to retrieve it from the context _outside_ the repl, and the two have - * separate classloaders. There's probably a cleaner way, but for now this serves our needs. +/** A dirty hack to capture the reported output for the server-mode. Context: server mode is a bit + * tricky, because the reporting happens inside the repl, but we want to retrieve it from the + * context _outside_ the repl, and the two have separate classloaders. There's probably a cleaner + * way, but for now this serves our needs. * - * Note that this convolutes the output from concurrently-running jobs - so we should not run UserRunnables - * concurrently. 
+ * Note that this convolutes the output from concurrently-running jobs - so we should not run + * UserRunnables concurrently. */ -object GlobalReporting { - private var enabled = false +object GlobalReporting: + private var enabled = false - def enable(): Unit = - enabled = true + def enable(): Unit = + enabled = true - def isEnabled(): Boolean = enabled + def isEnabled(): Boolean = enabled - def disable(): Unit = - enabled = false + def disable(): Unit = + enabled = false - private val globalStdOut = new mutable.StringBuilder + private val globalStdOut = new mutable.StringBuilder - def appendToGlobalStdOut(s: String): Unit = { - if (enabled) globalStdOut.append(s + System.lineSeparator()) - } + def appendToGlobalStdOut(s: String): Unit = + if enabled then globalStdOut.append(s + System.lineSeparator()) - def getAndClearGlobalStdOut(): String = { - val result = globalStdOut.result() - globalStdOut.clear() - result - } -} + def getAndClearGlobalStdOut(): String = + val result = globalStdOut.result() + globalStdOut.clear() + result +end GlobalReporting diff --git a/console/src/main/scala/io/appthreat/console/Run.scala b/console/src/main/scala/io/appthreat/console/Run.scala index d10a2533..f3df8b5c 100644 --- a/console/src/main/scala/io/appthreat/console/Run.scala +++ b/console/src/main/scala/io/appthreat/console/Run.scala @@ -6,58 +6,58 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext} import org.reflections8.Reflections import org.reflections8.util.{ClasspathHelper, ConfigurationBuilder} -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -object Run { +object Run: - def runCustomQuery(console: Console[_], query: HasStoreMethod): Unit = { - console._runAnalyzer(new LayerCreator { - override val overlayName: String = "custom" - override val description: String = "A custom pass" + def runCustomQuery(console: Console[?], query: HasStoreMethod): Unit = + console._runAnalyzer(new LayerCreator: + override val 
overlayName: String = "custom" + override val description: String = "A custom pass" - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val pass: CpgPass = new CpgPass(console.cpg) { - override val name = "custom" - override def run(builder: DiffGraphBuilder): Unit = { - query.store()(builder) - } - } - runPass(pass, context, storeUndoInfo) - } - }) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val pass: CpgPass = new CpgPass(console.cpg): + override val name = "custom" + override def run(builder: DiffGraphBuilder): Unit = + query.store()(builder) + runPass(pass, context, storeUndoInfo) + ) - /** Generate code for the run command - * @param exclude - * list of analyzers to exclude (by full class name) - */ - def codeForRunCommand(exclude: List[String] = List()): String = { - val ioSlLayerCreators = creatorsFor("io.shiftleft", exclude) - val ioJoernLayerCreators = creatorsFor("io.appthreat", exclude) - codeForLayerCreators((ioSlLayerCreators ++ ioJoernLayerCreators).distinct) - } + /** Generate code for the run command + * @param exclude + * list of analyzers to exclude (by full class name) + */ + def codeForRunCommand(exclude: List[String] = List()): String = + val ioSlLayerCreators = creatorsFor("io.shiftleft", exclude) + val ioJoernLayerCreators = creatorsFor("io.appthreat", exclude) + codeForLayerCreators((ioSlLayerCreators ++ ioJoernLayerCreators).distinct) - private def creatorsFor(namespace: String, exclude: List[String]) = { - new Reflections( - new ConfigurationBuilder().setUrls( - ClasspathHelper.forPackage(namespace, ClasspathHelper.contextClassLoader(), ClasspathHelper.staticClassLoader()) - ) - ).getSubTypesOf(classOf[LayerCreator]) - .asScala - .filterNot(t => t.isAnonymousClass || t.isLocalClass || t.isMemberClass || t.isSynthetic) - .filterNot(t => t.getName.startsWith("io.appthreat.console.Run")) - .toList - .map(t => (t.getSimpleName.toLowerCase, 
s"_root_.${t.getName}")) - .filter(t => !exclude.contains(t._2)) - } + private def creatorsFor(namespace: String, exclude: List[String]) = + new Reflections( + new ConfigurationBuilder().setUrls( + ClasspathHelper.forPackage( + namespace, + ClasspathHelper.contextClassLoader(), + ClasspathHelper.staticClassLoader() + ) + ) + ).getSubTypesOf(classOf[LayerCreator]) + .asScala + .filterNot(t => + t.isAnonymousClass || t.isLocalClass || t.isMemberClass || t.isSynthetic + ) + .filterNot(t => t.getName.startsWith("io.appthreat.console.Run")) + .toList + .map(t => (t.getSimpleName.toLowerCase, s"_root_.${t.getName}")) + .filter(t => !exclude.contains(t._2)) - private def codeForLayerCreators(layerCreatorTypeNames: List[(String, String)]): String = { - val optsMembersCode = layerCreatorTypeNames - .map { case (varName, typeName) => s"val $varName = $typeName.defaultOpts" } - .mkString("\n") + private def codeForLayerCreators(layerCreatorTypeNames: List[(String, String)]): String = + val optsMembersCode = layerCreatorTypeNames + .map { case (varName, typeName) => s"val $varName = $typeName.defaultOpts" } + .mkString("\n") - val optsCode = - s""" + val optsCode = + s""" |class OptsDynamic { |$optsMembersCode |} @@ -69,25 +69,27 @@ object Run { | def diffGraph = _diffGraph |""".stripMargin - val membersCode = layerCreatorTypeNames - .map { case (varName, typeName) => s" def $varName: Cpg = _runAnalyzer(new $typeName(opts.$varName))" } - .mkString("\n") + val membersCode = layerCreatorTypeNames + .map { case (varName, typeName) => + s" def $varName: Cpg = _runAnalyzer(new $typeName(opts.$varName))" + } + .mkString("\n") - val toStringCode = - s""" + val toStringCode = + s""" | import overflowdb.traversal.help.Table | override def toString() : String = { | val columnNames = List("name", "description") | val rows = | ${layerCreatorTypeNames.map { case (varName, typeName) => - s"""List("$varName",$typeName.description.trim)""" - }} + 
s"""List("$varName",$typeName.description.trim)""" + }} | "\\n" + Table(columnNames, rows).render | } |""".stripMargin - optsCode + - s""" + optsCode + + s""" |class OverlaysDynamic { | | def apply(query: _root_.io.shiftleft.semanticcpg.language.HasStoreMethod) = @@ -99,6 +101,5 @@ object Run { |} |val run = new OverlaysDynamic() |""".stripMargin - } - -} + end codeForLayerCreators +end Run diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala index 655b6338..9cb71bce 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala @@ -12,30 +12,29 @@ case class AtomGenerator( language: String, sliceMode: String = "usages", slicesFile: String = "usages.json" -) extends CpgGenerator { - private lazy val command: String = "atom" +) extends CpgGenerator: + private lazy val command: String = "atom" - /** Generate an atom for the given input path. Returns the output path, or None, if no CPG was generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = { - val arguments = Seq( - sliceMode, - "-s", - slicesFile, - "--output", - outputPath, - "--language", - language, - inputPath - ) ++ config.cmdLineParams - runShellCommand(command, arguments).map(_ => outputPath) - } + /** Generate an atom for the given input path. Returns the output path, or None, if no CPG was + * generated. 
+ */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq( + sliceMode, + "-s", + slicesFile, + "--output", + outputPath, + "--language", + language, + inputPath + ) ++ config.cmdLineParams + runShellCommand(command, arguments).map(_ => outputPath) - override def isAvailable: Boolean = true + override def isAvailable: Boolean = true - override def applyPostProcessingPasses(atom: Cpg): Cpg = { - atom - } + override def applyPostProcessingPasses(atom: Cpg): Cpg = + atom - override def isJvmBased = false -} + override def isJvmBased = false +end AtomGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala index 275463ee..10de20c9 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala @@ -5,21 +5,21 @@ import io.appthreat.console.FrontendConfig import java.nio.file.Path import scala.util.Try -/** C/C++ language frontend that translates C/C++ source files into code property graphs using Eclipse CDT parsing / - * preprocessing. +/** C/C++ language frontend that translates C/C++ source files into code property graphs using + * Eclipse CDT parsing / preprocessing. */ -case class CCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("c2cpg.bat") else rootPath.resolve("c2cpg.sh") +case class CCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: + private lazy val command: Path = + if isWin then rootPath.resolve("c2cpg.bat") else rootPath.resolve("c2cpg.sh") - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. 
- */ - override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = { - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - runShellCommand(command.toString, arguments).map(_ => outputPath) - } + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true -} + override def isJvmBased = true diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala index 92467177..c1b4994f 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala @@ -8,24 +8,26 @@ import scala.util.Try import better.files.File import better.files.File.LinkOptions -case class CdxGenerator(config: FrontendConfig, rootPath: Path, language: String) extends CpgGenerator { - private lazy val command: String = "cdxgen" +case class CdxGenerator(config: FrontendConfig, rootPath: Path, language: String) + extends CpgGenerator: + private lazy val command: String = "cdxgen" - /** Generate a CycloneDX BoM for the given input path. Returns the output path, or None, if no cdx was generated. 
- */ - override def generate(inputPath: String, outputPath: String = "bom.json"): Try[String] = { - var outFile = - if (File(outputPath).isDirectory(linkOptions = LinkOptions.noFollow)) (File(outputPath) / "bom.json").pathAsString - else outputPath - val arguments = Seq("-o", outFile, "-t", language, "--deep", inputPath) ++ config.cmdLineParams - runShellCommand(command, arguments).map(_ => outFile) - } + /** Generate a CycloneDX BoM for the given input path. Returns the output path, or None, if no + * cdx was generated. + */ + override def generate(inputPath: String, outputPath: String = "bom.json"): Try[String] = + var outFile = + if File(outputPath).isDirectory(linkOptions = LinkOptions.noFollow) then + (File(outputPath) / "bom.json").pathAsString + else outputPath + val arguments = + Seq("-o", outFile, "-t", language, "--deep", inputPath) ++ config.cmdLineParams + runShellCommand(command, arguments).map(_ => outFile) - override def isAvailable: Boolean = true + override def isAvailable: Boolean = true - override def applyPostProcessingPasses(atom: Cpg): Cpg = { - atom - } + override def applyPostProcessingPasses(atom: Cpg): Cpg = + atom - override def isJvmBased = false -} + override def isJvmBased = false +end CdxGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala index 4168ae72..caaef721 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala @@ -3,44 +3,45 @@ package io.appthreat.console.cpgcreation import better.files.File import io.shiftleft.codepropertygraph.Cpg -import scala.sys.process._ +import scala.sys.process.* import scala.util.Try -/** A CpgGenerator generates Code Property Graphs from code. 
Each supported language implements a Generator, e.g., - * [[JavaCpgGenerator]] implements Java Archive to CPG conversion, while [[CSharpCpgGenerator]] translates C# projects - * into code property graphs. +/** A CpgGenerator generates Code Property Graphs from code. Each supported language implements a + * Generator, e.g., [[JavaCpgGenerator]] implements Java Archive to CPG conversion, while + * [[CSharpCpgGenerator]] translates C# projects into code property graphs. */ -abstract class CpgGenerator() { - - def isWin: Boolean = scala.util.Properties.isWin - - def isAvailable: Boolean - - /** is this a JVM based frontend? if so, we'll invoke it with -Xmx for max heap settings */ - def isJvmBased: Boolean - - /** Generate a CPG for the given input path. Returns the output path, or a Failure, if no CPG was generated. - * - * This method appends command line options in config.frontend.cmdLineParams to the shell command. - */ - def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] - - protected def runShellCommand(program: String, arguments: Seq[String]): Try[Unit] = - Try { - val cmd = Seq(program) ++ performanceParameter ++ arguments - val exitValue = cmd.run().exitValue() - assert(exitValue == 0, s"Error running shell command: exitValue=$exitValue; $cmd") - } - - protected lazy val performanceParameter = { - if (isJvmBased) { - val maxValueInGigabytes = Math.floor(Runtime.getRuntime.maxMemory.toDouble / 1024 / 1024 / 1024).toInt - Seq(s"-J-Xmx${maxValueInGigabytes}G") - } else Nil - } - - /** override in specific cpg generators to make them apply post processing passes */ - def applyPostProcessingPasses(cpg: Cpg): Cpg = - cpg - -} +abstract class CpgGenerator(): + + def isWin: Boolean = scala.util.Properties.isWin + + def isAvailable: Boolean + + /** is this a JVM based frontend? if so, we'll invoke it with -Xmx for max heap settings */ + def isJvmBased: Boolean + + /** Generate a CPG for the given input path. 
Returns the output path, or a Failure, if no CPG + * was generated. + * + * This method appends command line options in config.frontend.cmdLineParams to the shell + * command. + */ + def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] + + protected def runShellCommand(program: String, arguments: Seq[String]): Try[Unit] = + Try { + val cmd = Seq(program) ++ performanceParameter ++ arguments + val exitValue = cmd.run().exitValue() + assert(exitValue == 0, s"Error running shell command: exitValue=$exitValue; $cmd") + } + + protected lazy val performanceParameter = + if isJvmBased then + val maxValueInGigabytes = + Math.floor(Runtime.getRuntime.maxMemory.toDouble / 1024 / 1024 / 1024).toInt + Seq(s"-J-Xmx${maxValueInGigabytes}G") + else Nil + + /** override in specific cpg generators to make them apply post processing passes */ + def applyPostProcessingPasses(cpg: Cpg): Cpg = + cpg +end CpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala index a31408d5..86b9866a 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala @@ -10,83 +10,85 @@ import overflowdb.Config import java.nio.file.Path import scala.util.Try -object CpgGeneratorFactory { - private val KNOWN_LANGUAGES = Set( - Languages.C, - Languages.CSHARP, - Languages.GOLANG, - Languages.GHIDRA, - Languages.JAVA, - Languages.JAVASCRIPT, - Languages.JSSRC, - Languages.PYTHON, - Languages.PYTHONSRC, - Languages.LLVM, - Languages.PHP, - Languages.KOTLIN, - Languages.NEWC, - Languages.JAVASRC - ) -} +object CpgGeneratorFactory: + private val KNOWN_LANGUAGES = Set( + Languages.C, + Languages.CSHARP, + Languages.GOLANG, + Languages.GHIDRA, + Languages.JAVA, + Languages.JAVASCRIPT, + Languages.JSSRC, + Languages.PYTHON, + 
Languages.PYTHONSRC, + Languages.LLVM, + Languages.PHP, + Languages.KOTLIN, + Languages.NEWC, + Languages.JAVASRC + ) -class CpgGeneratorFactory(config: ConsoleConfig) { +class CpgGeneratorFactory(config: ConsoleConfig): - /** For a given input path, try to guess a suitable generator and return it - */ - def forCodeAt(inputPath: String): Option[CpgGenerator] = - for { - language <- guessLanguage(inputPath) - cpgGenerator <- cpgGeneratorForLanguage(language, config.frontend, config.install.rootPath.path, args = Nil) - } yield { - cpgGenerator - } + /** For a given input path, try to guess a suitable generator and return it + */ + def forCodeAt(inputPath: String): Option[CpgGenerator] = + for + language <- guessLanguage(inputPath) + cpgGenerator <- cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args = Nil + ) + yield cpgGenerator - /** For a language, return the generator - */ - def forLanguage(language: String): Option[CpgGenerator] = { - Option(language.toUpperCase()) - .filter(languageIsKnown) - .flatMap { lang => - cpgGeneratorForLanguage(lang, config.frontend, config.install.rootPath.path, args = Nil) - } - } + /** For a language, return the generator + */ + def forLanguage(language: String): Option[CpgGenerator] = + Option(language.toUpperCase()) + .filter(languageIsKnown) + .flatMap { lang => + cpgGeneratorForLanguage( + lang, + config.frontend, + config.install.rootPath.path, + args = Nil + ) + } - def languageIsKnown(language: String): Boolean = CpgGeneratorFactory.KNOWN_LANGUAGES.contains(language) + def languageIsKnown(language: String): Boolean = + CpgGeneratorFactory.KNOWN_LANGUAGES.contains(language) - def runGenerator(generator: CpgGenerator, inputPath: String, outputPath: String): Try[Path] = { - val outputFileOpt: Try[File] = - generator.generate(inputPath, outputPath).map(File(_)) - outputFileOpt.map { outFile => - val parentPath = outFile.parent.path.toAbsolutePath - if (isZipFile(outFile)) { - val 
srcFilename = outFile.path.toAbsolutePath.toString - val dstFilename = parentPath.resolve("cpg.bin").toAbsolutePath.toString - // MemoryHelper.hintForInsufficientMemory(srcFilename).map(report) - convertProtoCpgToOverflowDb(srcFilename, dstFilename) - } else { - val srcPath = parentPath.resolve("app.atom") - if (srcPath.toFile.exists()) { - mv(srcPath, parentPath.resolve("cpg.bin")) + def runGenerator(generator: CpgGenerator, inputPath: String, outputPath: String): Try[Path] = + val outputFileOpt: Try[File] = + generator.generate(inputPath, outputPath).map(File(_)) + outputFileOpt.map { outFile => + val parentPath = outFile.parent.path.toAbsolutePath + if isZipFile(outFile) then + val srcFilename = outFile.path.toAbsolutePath.toString + val dstFilename = parentPath.resolve("cpg.bin").toAbsolutePath.toString + // MemoryHelper.hintForInsufficientMemory(srcFilename).map(report) + convertProtoCpgToOverflowDb(srcFilename, dstFilename) + else + val srcPath = parentPath.resolve("app.atom") + if srcPath.toFile.exists() then + mv(srcPath, parentPath.resolve("cpg.bin")) + parentPath } - } - parentPath - } - } - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = { - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close - File(srcFilename).delete() - } + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + CpgLoader.load(srcFilename, config).close + File(srcFilename).delete() - def isZipFile(file: File): Boolean = { - val bytes = file.bytes - Try { - bytes.next() == 'P' && bytes.next() == 'K' - }.getOrElse(false) - } + def isZipFile(file: File): Boolean = + val bytes = file.bytes + Try { + 
bytes.next() == 'P' && bytes.next() == 'K' + }.getOrElse(false) - private def report(str: String): Unit = System.err.println(str) - -} + private def report(str: String): Unit = System.err.println(str) +end CpgGeneratorFactory diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala index 93e5e8ea..1526d2d2 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala @@ -15,149 +15,215 @@ import me.shadaj.scalapy.interpreter.CPythonInterpreter import java.nio.file.Path import scala.util.{Failure, Success, Try} -class ImportCode[T <: Project](console: Console[T]) extends Reporting { - import Console._ - - private val config = console.config - private val workspace = console.workspace - protected val generatorFactory = new CpgGeneratorFactory(config) - val chenPyUtils = py.module("chenpy.utils") - - private def checkInputPath(inputPath: String): Unit = { - if (!File(inputPath).exists) { - throw new ConsoleException(s"Input path does not exist: '$inputPath'") - } - } - - def importUrl(inputPath: String): String = chenPyUtils.import_url(inputPath).as[String] - - /** This is the `importCode(...)` method exposed on the console. It attempts to find a suitable CPG generator first by - * looking at the `language` parameter and if no generator is found for the language, looking the contents at - * `inputPath` to determine heuristically which generator to use. 
- */ - def apply(inputPath: String, projectName: String = "", language: String = ""): Cpg = { - var srcPath = - if ( - inputPath.startsWith("http") || inputPath.startsWith("git://") || inputPath.startsWith("CVE-") || inputPath - .startsWith("GHSA-") - ) importUrl(inputPath) - else inputPath - checkInputPath(srcPath) - if (language != "") { - generatorFactory.forLanguage(language.toUpperCase()) match { - case None => throw new ConsoleException(s"No Atom generator exists for language: $language") - case Some(frontend) => apply(frontend, srcPath, projectName) - } - } else { - generatorFactory.forCodeAt(srcPath) match { - case None => throw new ConsoleException(s"No suitable Atom generator found for: $srcPath") - case Some(frontend) => apply(frontend, srcPath, projectName) - } - } - } - - def c: SourceBasedFrontend = new CFrontend("c") - def cpp: SourceBasedFrontend = new CFrontend("cpp", extension = "cpp") - def java: SourceBasedFrontend = new SourceBasedFrontend("java", Languages.JAVASRC, "Java Source Frontend", "java") - def jvm: Frontend = - new BinaryFrontend("jvm", Languages.JAVA, "Java/Dalvik Bytecode Frontend (based on SOOT's jimple)") - def ghidra: Frontend = new BinaryFrontend("ghidra", Languages.GHIDRA, "ghidra reverse engineering frontend") - def kotlin: SourceBasedFrontend = - new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kotlin") - def python: SourceBasedFrontend = - new SourceBasedFrontend("python", Languages.PYTHONSRC, "Python Source Frontend", "py") - def golang: SourceBasedFrontend = new SourceBasedFrontend("golang", Languages.GOLANG, "Golang Source Frontend", "go") - def javascript: SourceBasedFrontend = - new SourceBasedFrontend("javascript", Languages.JAVASCRIPT, "Javascript Source Frontend", "js") - def jssrc: SourceBasedFrontend = - new SourceBasedFrontend("jssrc", Languages.JSSRC, "Javascript/Typescript Source Frontend based on astgen", "js") - def csharp: Frontend = new BinaryFrontend("csharp", Languages.CSHARP, 
"C# Source Frontend (Roslyn)") - def llvm: Frontend = new BinaryFrontend("llvm", Languages.LLVM, "LLVM Bitcode Frontend") - def php: SourceBasedFrontend = new SourceBasedFrontend("php", Languages.PHP, "PHP source frontend", "php") - def ruby: SourceBasedFrontend = new SourceBasedFrontend("ruby", Languages.RUBYSRC, "Ruby source frontend", "rb") - - private def allFrontends: List[Frontend] = - List(c, cpp, ghidra, kotlin, java, jvm, javascript, jssrc, golang, llvm, php, python, csharp, ruby) - - // this is only abstract to force people adding frontends to make a decision whether the frontend consumes binaries or source - abstract class Frontend(val name: String, val language: String, val description: String = "") { - def cpgGeneratorForLanguage( +class ImportCode[T <: Project](console: Console[T]) extends Reporting: + import Console.* + + private val config = console.config + private val workspace = console.workspace + protected val generatorFactory = new CpgGeneratorFactory(config) + val chenPyUtils = py.module("chenpy.utils") + + private def checkInputPath(inputPath: String): Unit = + if !File(inputPath).exists then + throw new ConsoleException(s"Input path does not exist: '$inputPath'") + + def importUrl(inputPath: String): String = chenPyUtils.import_url(inputPath).as[String] + + /** This is the `importCode(...)` method exposed on the console. It attempts to find a suitable + * CPG generator first by looking at the `language` parameter and if no generator is found for + * the language, looking the contents at `inputPath` to determine heuristically which generator + * to use. 
+ */ + def apply(inputPath: String, projectName: String = "", language: String = ""): Cpg = + var srcPath = + if + inputPath.startsWith("http") || inputPath.startsWith( + "git://" + ) || inputPath.startsWith("CVE-") || inputPath + .startsWith("GHSA-") + then importUrl(inputPath) + else inputPath + checkInputPath(srcPath) + if language != "" then + generatorFactory.forLanguage(language.toUpperCase()) match + case None => + throw new ConsoleException(s"No Atom generator exists for language: $language") + case Some(frontend) => apply(frontend, srcPath, projectName) + else + generatorFactory.forCodeAt(srcPath) match + case None => + throw new ConsoleException(s"No suitable Atom generator found for: $srcPath") + case Some(frontend) => apply(frontend, srcPath, projectName) + end apply + + def c: SourceBasedFrontend = new CFrontend("c") + def cpp: SourceBasedFrontend = new CFrontend("cpp", extension = "cpp") + def java: SourceBasedFrontend = + new SourceBasedFrontend("java", Languages.JAVASRC, "Java Source Frontend", "java") + def jvm: Frontend = + new BinaryFrontend( + "jvm", + Languages.JAVA, + "Java/Dalvik Bytecode Frontend (based on SOOT's jimple)" + ) + def ghidra: Frontend = + new BinaryFrontend("ghidra", Languages.GHIDRA, "ghidra reverse engineering frontend") + def kotlin: SourceBasedFrontend = + new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kotlin") + def python: SourceBasedFrontend = + new SourceBasedFrontend("python", Languages.PYTHONSRC, "Python Source Frontend", "py") + def golang: SourceBasedFrontend = + new SourceBasedFrontend("golang", Languages.GOLANG, "Golang Source Frontend", "go") + def javascript: SourceBasedFrontend = + new SourceBasedFrontend( + "javascript", + Languages.JAVASCRIPT, + "Javascript Source Frontend", + "js" + ) + def jssrc: SourceBasedFrontend = + new SourceBasedFrontend( + "jssrc", + Languages.JSSRC, + "Javascript/Typescript Source Frontend based on astgen", + "js" + ) + def csharp: Frontend = + new 
BinaryFrontend("csharp", Languages.CSHARP, "C# Source Frontend (Roslyn)") + def llvm: Frontend = new BinaryFrontend("llvm", Languages.LLVM, "LLVM Bitcode Frontend") + def php: SourceBasedFrontend = + new SourceBasedFrontend("php", Languages.PHP, "PHP source frontend", "php") + def ruby: SourceBasedFrontend = + new SourceBasedFrontend("ruby", Languages.RUBYSRC, "Ruby source frontend", "rb") + + private def allFrontends: List[Frontend] = + List( + c, + cpp, + ghidra, + kotlin, + java, + jvm, + javascript, + jssrc, + golang, + llvm, + php, + python, + csharp, + ruby + ) + + // this is only abstract to force people adding frontends to make a decision whether the frontend consumes binaries or source + abstract class Frontend(val name: String, val language: String, val description: String = ""): + def cpgGeneratorForLanguage( + language: String, + config: FrontendConfig, + rootPath: Path, + args: List[String] + ): Option[CpgGenerator] = + io.appthreat.console.cpgcreation.cpgGeneratorForLanguage( + language, + config, + rootPath, + args + ) + + def isAvailable: Boolean = + cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args = Nil + ).exists(_.isAvailable) + + def apply(inputPath: String, projectName: String = "", args: List[String] = List()): Cpg = + val frontend = cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args + ) + .getOrElse(throw new ConsoleException( + s"no atom generator for language=$language available!" 
+ )) + new ImportCode(console)(frontend, inputPath, projectName) + end Frontend + + private class BinaryFrontend(name: String, language: String, description: String = "") + extends Frontend(name, language, description) + + class SourceBasedFrontend( + name: String, language: String, - config: FrontendConfig, - rootPath: Path, - args: List[String] - ): Option[CpgGenerator] = - io.appthreat.console.cpgcreation.cpgGeneratorForLanguage(language, config, rootPath, args) - - def isAvailable: Boolean = - cpgGeneratorForLanguage(language, config.frontend, config.install.rootPath.path, args = Nil).exists(_.isAvailable) - - def apply(inputPath: String, projectName: String = "", args: List[String] = List()): Cpg = { - val frontend = cpgGeneratorForLanguage(language, config.frontend, config.install.rootPath.path, args) - .getOrElse(throw new ConsoleException(s"no atom generator for language=$language available!")) - new ImportCode(console)(frontend, inputPath, projectName) - } - } - - private class BinaryFrontend(name: String, language: String, description: String = "") - extends Frontend(name, language, description) - - class SourceBasedFrontend(name: String, language: String, description: String, extension: String) - extends Frontend(name, language, description) { - - def fromString(str: String, args: List[String] = List()): Cpg = { - withCodeInTmpFile(str, "tmp." 
+ extension) { dir => - super.apply(dir.path.toString, args = args) - } match { - case Failure(exception) => throw new ConsoleException(s"unable to generate atom from given String", exception) - case Success(value) => value - } - } - } - class CFrontend(name: String, extension: String = "c") - extends SourceBasedFrontend(name, Languages.NEWC, "Eclipse CDT Based Frontend for C/C++", extension) - - private def withCodeInTmpFile(str: String, filename: String)(f: File => Cpg): Try[Cpg] = { - val dir = File.newTemporaryDirectory("console") - val result = Try { - (dir / filename).write(str) - f(dir) - } - dir.deleteOnExit(swallowIOExceptions = true) - result - } - - /** Provide an overview of the available CPG generators (frontends) - */ - override def toString: String = { - val cols = List("name", "description", "available") - val rows = allFrontends.map { frontend => - List(frontend.name, frontend.description, frontend.isAvailable.toString) - } - "Type `importCode.` to run a specific language frontend\n" + - "\n" + Table(cols, rows).render - } - - private def apply(generator: CpgGenerator, inputPath: String, projectName: String): Cpg = { - checkInputPath(inputPath) - - val name = Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath(inputPath, workspace)) - report(s"Creating project `$name` for code at `$inputPath`") - - val cpgMaybe = workspace.createProject(inputPath, name).flatMap { pathToProject => - val frontendCpgOutFile = pathToProject.resolve(nameOfLegacyCpgInProject) - val frontendAtomPath = generatorFactory.runGenerator(generator, inputPath, frontendCpgOutFile.toString) - frontendAtomPath match { - case Success(_) => - console.open(name).flatMap(_.cpg) - case Failure(exception) => - throw new ConsoleException(s"Error creating project for input path: `$inputPath`", exception) - } - } - - cpgMaybe - .map(cpg => console.summary) - .getOrElse(throw new ConsoleException(s"Error creating project for input path: `$inputPath`")) - cpgMaybe.get - } 
-} + description: String, + extension: String + ) extends Frontend(name, language, description): + + def fromString(str: String, args: List[String] = List()): Cpg = + withCodeInTmpFile(str, "tmp." + extension) { dir => + super.apply(dir.path.toString, args = args) + } match + case Failure(exception) => throw new ConsoleException( + s"unable to generate atom from given String", + exception + ) + case Success(value) => value + class CFrontend(name: String, extension: String = "c") + extends SourceBasedFrontend( + name, + Languages.NEWC, + "Eclipse CDT Based Frontend for C/C++", + extension + ) + + private def withCodeInTmpFile(str: String, filename: String)(f: File => Cpg): Try[Cpg] = + val dir = File.newTemporaryDirectory("console") + val result = Try { + (dir / filename).write(str) + f(dir) + } + dir.deleteOnExit(swallowIOExceptions = true) + result + + /** Provide an overview of the available CPG generators (frontends) + */ + override def toString: String = + val cols = List("name", "description", "available") + val rows = allFrontends.map { frontend => + List(frontend.name, frontend.description, frontend.isAvailable.toString) + } + "Type `importCode.` to run a specific language frontend\n" + + "\n" + Table(cols, rows).render + + private def apply(generator: CpgGenerator, inputPath: String, projectName: String): Cpg = + checkInputPath(inputPath) + + val name = Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( + inputPath, + workspace + )) + report(s"Creating project `$name` for code at `$inputPath`") + + val cpgMaybe = workspace.createProject(inputPath, name).flatMap { pathToProject => + val frontendCpgOutFile = pathToProject.resolve(nameOfLegacyCpgInProject) + val frontendAtomPath = + generatorFactory.runGenerator(generator, inputPath, frontendCpgOutFile.toString) + frontendAtomPath match + case Success(_) => + console.open(name).flatMap(_.cpg) + case Failure(exception) => + throw new ConsoleException( + s"Error creating project for input 
path: `$inputPath`", + exception + ) + } + + cpgMaybe + .map(cpg => console.summary) + .getOrElse( + throw new ConsoleException(s"Error creating project for input path: `$inputPath`") + ) + cpgMaybe.get + end apply +end ImportCode diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala index 00185766..afe4b098 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala @@ -11,26 +11,26 @@ import scala.util.Try /** Source-based front-end for Java */ -case class JavaSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("javasrc2cpg.bat") else rootPath.resolve("javasrc2cpg") - private var javaConfig: Option[Config] = None +case class JavaSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: + private lazy val command: Path = + if isWin then rootPath.resolve("javasrc2cpg.bat") else rootPath.resolve("javasrc2cpg") + private var javaConfig: Option[Config] = None - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. - */ - override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = { - val arguments = config.cmdLineParams.toSeq ++ Seq(inputPath, "--output", outputPath) - javaConfig = X2Cpg.parseCommandLine(arguments.toArray, Main.getCmdLineParser, Config()) - runShellCommand(command.toString, arguments).map(_ => outputPath) - } + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. 
+ */ + override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = + val arguments = config.cmdLineParams.toSeq ++ Seq(inputPath, "--output", outputPath) + javaConfig = X2Cpg.parseCommandLine(arguments.toArray, Main.getCmdLineParser, Config()) + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def applyPostProcessingPasses(cpg: Cpg): Cpg = { - if (javaConfig.forall(_.enableTypeRecovery)) - JavaSrc2Cpg.typeRecoveryPasses(cpg, javaConfig).foreach(_.createAndApply()) - super.applyPostProcessingPasses(cpg) - } + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + if javaConfig.forall(_.enableTypeRecovery) then + JavaSrc2Cpg.typeRecoveryPasses(cpg, javaConfig).foreach(_.createAndApply()) + super.applyPostProcessingPasses(cpg) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true -} + override def isJvmBased = true +end JavaSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala index 5e84784c..3ddc73f2 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala @@ -6,18 +6,18 @@ import io.appthreat.console.FrontendConfig import java.nio.file.Path import scala.util.Try -case class JsCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("js2cpg.bat") else rootPath.resolve("js2cpg.sh") +case class JsCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: + private lazy val command: Path = + if isWin then rootPath.resolve("js2cpg.bat") else rootPath.resolve("js2cpg.sh") - /** Generate a CPG for the given input path. 
Returns the output path, or None, if no CPG was generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = { - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - runShellCommand(command.toString, arguments).map(_ => outputPath) - } + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true -} + override def isJvmBased = true diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala index b25e6ba6..fab6b930 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala @@ -9,25 +9,25 @@ import io.shiftleft.codepropertygraph.Cpg import java.nio.file.Path import scala.util.Try -case class JsSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("jssrc2cpg.bat") else rootPath.resolve("jssrc2cpg.sh") - private var jsConfig: Option[Config] = None +case class JsSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: + private lazy val command: Path = + if isWin then rootPath.resolve("jssrc2cpg.bat") else rootPath.resolve("jssrc2cpg.sh") + private var jsConfig: Option[Config] = None - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. 
- */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = { - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - jsConfig = X2Cpg.parseCommandLine(arguments.toArray, Frontend.cmdLineParser, Config()) - runShellCommand(command.toString, arguments).map(_ => outputPath) - } + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + jsConfig = X2Cpg.parseCommandLine(arguments.toArray, Frontend.cmdLineParser, Config()) + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def applyPostProcessingPasses(cpg: Cpg): Cpg = { - JsSrc2Cpg.postProcessingPasses(cpg, jsConfig).foreach(_.createAndApply()) - cpg - } + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + JsSrc2Cpg.postProcessingPasses(cpg, jsConfig).foreach(_.createAndApply()) + cpg - override def isJvmBased = true -} + override def isJvmBased = true +end JsSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala index db69737d..05579d7a 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala @@ -10,38 +10,43 @@ import io.shiftleft.codepropertygraph.Cpg import java.nio.file.Path import scala.util.Try -case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator { - private lazy val command: Path = if (isWin) rootPath.resolve("pysrc2cpg.bat") else 
rootPath.resolve("pysrc2cpg") - private var pyConfig: Option[Py2CpgOnFileSystemConfig] = None - - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = { - val arguments = Seq(inputPath, "-o", outputPath) ++ config.cmdLineParams - pyConfig = X2Cpg.parseCommandLine(arguments.toArray, NewMain.getCmdLineParser, Py2CpgOnFileSystemConfig()) - runShellCommand(command.toString, arguments).map(_ => outputPath) - } - - override def isAvailable: Boolean = - command.toFile.exists - - override def applyPostProcessingPasses(cpg: Cpg): Cpg = { - new ImportsPass(cpg).createAndApply() - new ImportResolverPass(cpg).createAndApply() - new DynamicTypeHintFullNamePass(cpg).createAndApply() - new PythonInheritanceNamePass(cpg).createAndApply() - val typeRecoveryConfig = pyConfig match - case Some(config) => XTypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) - case None => XTypeRecoveryConfig() - new PythonTypeRecoveryPass(cpg, typeRecoveryConfig).createAndApply() - new PythonTypeHintCallLinker(cpg).createAndApply() - - // Some of passes above create new methods, so, we - // need to run the ASTLinkerPass one more time - new AstLinkerPass(cpg).createAndApply() - - cpg - } - - override def isJvmBased = true -} +case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: + private lazy val command: Path = + if isWin then rootPath.resolve("pysrc2cpg.bat") else rootPath.resolve("pysrc2cpg") + private var pyConfig: Option[Py2CpgOnFileSystemConfig] = None + + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. 
+ */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "-o", outputPath) ++ config.cmdLineParams + pyConfig = X2Cpg.parseCommandLine( + arguments.toArray, + NewMain.getCmdLineParser, + Py2CpgOnFileSystemConfig() + ) + runShellCommand(command.toString, arguments).map(_ => outputPath) + + override def isAvailable: Boolean = + command.toFile.exists + + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + new ImportsPass(cpg).createAndApply() + new ImportResolverPass(cpg).createAndApply() + new DynamicTypeHintFullNamePass(cpg).createAndApply() + new PythonInheritanceNamePass(cpg).createAndApply() + val typeRecoveryConfig = pyConfig match + case Some(config) => + XTypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) + case None => XTypeRecoveryConfig() + new PythonTypeRecoveryPass(cpg, typeRecoveryConfig).createAndApply() + new PythonTypeHintCallLinker(cpg).createAndApply() + + // Some of passes above create new methods, so, we + // need to run the ASTLinkerPass one more time + new AstLinkerPass(cpg).createAndApply() + + cpg + + override def isJvmBased = true +end PythonSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala index 75145626..28791221 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala @@ -6,102 +6,97 @@ import better.files.File import java.nio.file.Path import scala.collection.mutable -package object cpgcreation { - - /** For a given language, return CPG generator script - */ - def cpgGeneratorForLanguage( - language: String, - config: FrontendConfig, - rootPath: Path, - args: List[String] - ): Option[CpgGenerator] = { - lazy val conf = config.withArgs(args) - language.toUpperCase() match { - case Languages.C | Languages.NEWC | 
Languages.JAVA | Languages.JAVASRC | Languages.JSSRC | Languages.JAVASCRIPT | - Languages.PYTHON | Languages.PYTHONSRC => - Some(AtomGenerator(conf, rootPath, language)) - case _ => None - } - } - - /** Heuristically determines language by inspecting file/dir at path. - */ - def guessLanguage(path: String): Option[String] = { - val file = File(path) - if (file.isDirectory) { - guessMajorityLanguageInDir(file) - } else { - guessLanguageForRegularFile(file) - } - } - - /** Guess the main language for an entire directory (e.g. a whole project), based on a group count of all individual - * files. Rationale: many projects contain files from different languages, but most often one language is standing - * out in numbers. - */ - private def guessMajorityLanguageInDir(directory: File): Option[String] = { - assert(directory.isDirectory, s"$directory must be a directory, but wasn't") - val groupCount = mutable.Map.empty[String, Int].withDefaultValue(0) - - for { - file <- directory.listRecursively - if file.isRegularFile - guessedLanguage <- guessLanguageForRegularFile(file) - } { - val oldValue = groupCount(guessedLanguage) - groupCount.update(guessedLanguage, oldValue + 1) - } - - groupCount.toSeq.sortBy(_._2).lastOption.map(_._1) - } - - private def isJavaBinary(filename: String): Boolean = - Seq(".jar", ".war", ".ear", ".apk").exists(filename.endsWith) - - private def isCsharpFile(filename: String): Boolean = - Seq(".csproj", ".cs").exists(filename.endsWith) - - private def isGoFile(filename: String): Boolean = - filename.endsWith(".go") || Set("gopkg.lock", "gopkg.toml", "go.mod", "go.sum").contains(filename) - - private def isLlvmFile(filename: String): Boolean = - Seq(".bc", ".ll").exists(filename.endsWith) - - private def isJsFile(filename: String): Boolean = - Seq(".js", ".ts", ".jsx", ".tsx").exists(filename.endsWith) || filename == "package.json" - - /** check if given filename looks like it might be a C/CPP source or header file mostly copied from - * 
io.appthreat.c2cpg.parser.FileDefaults - */ - private def isCFile(filename: String): Boolean = - Seq(".c", ".cc", ".cpp", ".h", ".hpp", ".hh", ".ccm", ".cxxm", ".c++m").exists(filename.endsWith) - - private def isYamlFile(filename: String): Boolean = - Seq(".yml", ".yaml").exists(filename.endsWith) - - private def isBomFile(filename: String): Boolean = - Seq("bom.json", ".cdx.json").exists(filename.endsWith) - - private def guessLanguageForRegularFile(file: File): Option[String] = { - file.name.toLowerCase match { - case f if isJavaBinary(f) => Some(Languages.JAVA) - case f if isCsharpFile(f) => Some(Languages.CSHARP) - case f if isGoFile(f) => Some(Languages.GOLANG) - case f if isJsFile(f) => Some(Languages.JSSRC) - case f if f.endsWith(".java") => Some(Languages.JAVASRC) - case f if f.endsWith(".class") => Some(Languages.JAVA) - case f if f.endsWith(".kt") => Some(Languages.KOTLIN) - case f if f.endsWith(".php") => Some(Languages.PHP) - case f if f.endsWith(".py") => Some(Languages.PYTHONSRC) - case f if f.endsWith(".rb") => Some(Languages.RUBYSRC) - case f if isLlvmFile(f) => Some(Languages.LLVM) - case f if isCFile(f) => Some(Languages.NEWC) - case f if isBomFile(f) => Option("BOM") - case f if f.endsWith(".json") => Option("JSON") - case f if isYamlFile(f) => Option("YAML") - case _ => None - } - } - -} +package object cpgcreation: + + /** For a given language, return CPG generator script + */ + def cpgGeneratorForLanguage( + language: String, + config: FrontendConfig, + rootPath: Path, + args: List[String] + ): Option[CpgGenerator] = + lazy val conf = config.withArgs(args) + language.toUpperCase() match + case Languages.C | Languages.NEWC | Languages.JAVA | Languages.JAVASRC | Languages.JSSRC | Languages.JAVASCRIPT | + Languages.PYTHON | Languages.PYTHONSRC => + Some(AtomGenerator(conf, rootPath, language)) + case _ => None + + /** Heuristically determines language by inspecting file/dir at path. 
+ */ + def guessLanguage(path: String): Option[String] = + val file = File(path) + if file.isDirectory then + guessMajorityLanguageInDir(file) + else + guessLanguageForRegularFile(file) + + /** Guess the main language for an entire directory (e.g. a whole project), based on a group + * count of all individual files. Rationale: many projects contain files from different + * languages, but most often one language is standing out in numbers. + */ + private def guessMajorityLanguageInDir(directory: File): Option[String] = + assert(directory.isDirectory, s"$directory must be a directory, but wasn't") + val groupCount = mutable.Map.empty[String, Int].withDefaultValue(0) + + for + file <- directory.listRecursively + if file.isRegularFile + guessedLanguage <- guessLanguageForRegularFile(file) + do + val oldValue = groupCount(guessedLanguage) + groupCount.update(guessedLanguage, oldValue + 1) + + groupCount.toSeq.sortBy(_._2).lastOption.map(_._1) + + private def isJavaBinary(filename: String): Boolean = + Seq(".jar", ".war", ".ear", ".apk").exists(filename.endsWith) + + private def isCsharpFile(filename: String): Boolean = + Seq(".csproj", ".cs").exists(filename.endsWith) + + private def isGoFile(filename: String): Boolean = + filename.endsWith(".go") || Set("gopkg.lock", "gopkg.toml", "go.mod", "go.sum").contains( + filename + ) + + private def isLlvmFile(filename: String): Boolean = + Seq(".bc", ".ll").exists(filename.endsWith) + + private def isJsFile(filename: String): Boolean = + Seq(".js", ".ts", ".jsx", ".tsx").exists(filename.endsWith) || filename == "package.json" + + /** check if given filename looks like it might be a C/CPP source or header file mostly copied + * from io.appthreat.c2cpg.parser.FileDefaults + */ + private def isCFile(filename: String): Boolean = + Seq(".c", ".cc", ".cpp", ".h", ".hpp", ".hh", ".ccm", ".cxxm", ".c++m").exists( + filename.endsWith + ) + + private def isYamlFile(filename: String): Boolean = + Seq(".yml", 
".yaml").exists(filename.endsWith) + + private def isBomFile(filename: String): Boolean = + Seq("bom.json", ".cdx.json").exists(filename.endsWith) + + private def guessLanguageForRegularFile(file: File): Option[String] = + file.name.toLowerCase match + case f if isJavaBinary(f) => Some(Languages.JAVA) + case f if isCsharpFile(f) => Some(Languages.CSHARP) + case f if isGoFile(f) => Some(Languages.GOLANG) + case f if isJsFile(f) => Some(Languages.JSSRC) + case f if f.endsWith(".java") => Some(Languages.JAVASRC) + case f if f.endsWith(".class") => Some(Languages.JAVA) + case f if f.endsWith(".kt") => Some(Languages.KOTLIN) + case f if f.endsWith(".php") => Some(Languages.PHP) + case f if f.endsWith(".py") => Some(Languages.PYTHONSRC) + case f if f.endsWith(".rb") => Some(Languages.RUBYSRC) + case f if isLlvmFile(f) => Some(Languages.LLVM) + case f if isCFile(f) => Some(Languages.NEWC) + case f if isBomFile(f) => Option("BOM") + case f if f.endsWith(".json") => Option("JSON") + case f if isYamlFile(f) => Option("YAML") + case _ => None +end cpgcreation diff --git a/console/src/main/scala/io/appthreat/console/package.scala b/console/src/main/scala/io/appthreat/console/package.scala index 31bf79a9..e1d4a775 100644 --- a/console/src/main/scala/io/appthreat/console/package.scala +++ b/console/src/main/scala/io/appthreat/console/package.scala @@ -4,40 +4,39 @@ import replpp.Operators.* import replpp.Colors // TODO remove any time after the end of 2023 - this is completely deprecated -package object console { - - implicit class UnixUtils[A](content: Iterable[A]) { - given Colors = Colors.Default - - /** Iterate over left hand side operand and write to file. Think of it as the Ocular version of the Unix `>` shell - * redirection. - */ - @deprecated("please use `#>` instead", "2.0.45 (August 2023)") - def |>(outfile: String): Unit = - content #> outfile - - /** Iterate over left hand side operand and append to file. 
Think of it as the Ocular version of the Unix `>>` shell - * redirection. - */ - @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") - def |>>(outfile: String): Unit = - content #>> outfile - } - - implicit class StringOps(value: String) { - given Colors = Colors.Default - - /** Pipe string to file. Think of it as the Ocular version of the Unix `>` shell redirection. - */ - @deprecated("please use `#>` instead", "2.0.45 (August 2023)") - def |>(outfile: String): Unit = - value #> outfile - - /** Append string to file. Think of it as the Ocular version of the Unix `>>` shell redirection. - */ - @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") - def |>>(outfile: String): Unit = - value #>> outfile - } - -} +package object console: + + implicit class UnixUtils[A](content: Iterable[A]): + given Colors = Colors.Default + + /** Iterate over left hand side operand and write to file. Think of it as the Ocular version + * of the Unix `>` shell redirection. + */ + @deprecated("please use `#>` instead", "2.0.45 (August 2023)") + def |>(outfile: String): Unit = + content #> outfile + + /** Iterate over left hand side operand and append to file. Think of it as the Ocular + * version of the Unix `>>` shell redirection. + */ + @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") + def |>>(outfile: String): Unit = + content #>> outfile + + implicit class StringOps(value: String): + given Colors = Colors.Default + + /** Pipe string to file. Think of it as the Ocular version of the Unix `>` shell + * redirection. + */ + @deprecated("please use `#>` instead", "2.0.45 (August 2023)") + def |>(outfile: String): Unit = + value #> outfile + + /** Append string to file. Think of it as the Ocular version of the Unix `>>` shell + * redirection. 
+ */ + @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") + def |>>(outfile: String): Unit = + value #>> outfile +end console diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala index 72600286..86a21893 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala @@ -1,16 +1,15 @@ package io.appthreat.console.workspacehandling -import better.files.Dsl._ +import better.files.Dsl.* import better.files.File import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.semanticcpg.Overlays import java.nio.file.Path -object Project { - val workCpgFileName = "cpg.bin.tmp" - val persistentCpgFileName = "cpg.bin" -} +object Project: + val workCpgFileName = "cpg.bin.tmp" + val persistentCpgFileName = "cpg.bin" case class ProjectFile(inputPath: String, name: String) @@ -19,50 +18,44 @@ case class ProjectFile(inputPath: String, name: String) * @param cpg * reference to loaded CPG or None, if the CPG is not loaded */ -case class Project(projectFile: ProjectFile, var path: Path, var cpg: Option[Cpg] = None) { +case class Project(projectFile: ProjectFile, var path: Path, var cpg: Option[Cpg] = None): - import Project._ + import Project.* - def name: String = projectFile.name + def name: String = projectFile.name - def inputPath: String = projectFile.inputPath + def inputPath: String = projectFile.inputPath - def isOpen: Boolean = cpg.isDefined + def isOpen: Boolean = cpg.isDefined - def appliedOverlays: Seq[String] = { - cpg.map(Overlays.appliedOverlays).getOrElse(Nil) - } + def appliedOverlays: Seq[String] = + cpg.map(Overlays.appliedOverlays).getOrElse(Nil) - def availableOverlays: List[String] = { - File(path.resolve("overlays")).list.map(_.name).toList - } + def availableOverlays: List[String] = + 
File(path.resolve("overlays")).list.map(_.name).toList - def overlayDirs: Seq[File] = { - val overlayDir = File(path.resolve("overlays")) - appliedOverlays.map(o => overlayDir / o) - } + def overlayDirs: Seq[File] = + val overlayDir = File(path.resolve("overlays")) + appliedOverlays.map(o => overlayDir / o) - override def toString: String = - toTableRow.mkString("\t") + override def toString: String = + toTableRow.mkString("\t") - def toTableRow: List[String] = { - val cpgLoaded = cpg.isDefined - val overlays = availableOverlays.mkString(",") - val inputPath = projectFile.inputPath - List(name, overlays, inputPath, cpgLoaded.toString) - } + def toTableRow: List[String] = + val cpgLoaded = cpg.isDefined + val overlays = availableOverlays.mkString(",") + val inputPath = projectFile.inputPath + List(name, overlays, inputPath, cpgLoaded.toString) - /** Close project if it is open and do nothing otherwise. - */ - def close: Project = { - cpg.foreach { c => - c.close() - val workingCopy = path.resolve(workCpgFileName) - val persistent = path.resolve(persistentCpgFileName) - cp(workingCopy, persistent) - } - cpg = None - this - } - -} + /** Close project if it is open and do nothing otherwise. + */ + def close: Project = + cpg.foreach { c => + c.close() + val workingCopy = path.resolve(workCpgFileName) + val persistent = path.resolve(persistentCpgFileName) + cp(workingCopy, persistent) + } + cpg = None + this +end Project diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala index cabde4bb..1d96fd80 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala @@ -4,24 +4,27 @@ import overflowdb.traversal.help.Table import scala.collection.mutable.ListBuffer -/** Create a workspace from a list of projects. 
Workspace is a passive object that is managed by WorkspaceManager +/** Create a workspace from a list of projects. Workspace is a passive object that is managed by + * WorkspaceManager * @param projects * list of projects present in this workspace */ -class Workspace[ProjectType <: Project](var projects: ListBuffer[ProjectType]) { +class Workspace[ProjectType <: Project](var projects: ListBuffer[ProjectType]): - /** Returns total number of projects in this workspace - */ - def numberOfProjects: Int = projects.size + /** Returns total number of projects in this workspace + */ + def numberOfProjects: Int = projects.size - /** Provide a human-readable overview of the workspace - */ - override def toString: String = { - if (projects.isEmpty) { - System.err.println("The workpace is empty. Use `importCode` or `importAtom` to populate it") - "empty" - } else { - """ + /** Provide a human-readable overview of the workspace + */ + override def toString: String = + if projects.isEmpty then + System.err.println( + "The workpace is empty. Use `importCode` or `importAtom` to populate it" + ) + "empty" + else + """ |Overview of all projects present in your workspace. You can use `open` and `close` |to load and unload projects respectively. `cpgs` allows you to query all projects |at once. 
`cpg` points to the Code Property Graph of the *selected* project, which is @@ -30,12 +33,8 @@ class Workspace[ProjectType <: Project](var projects: ListBuffer[ProjectType]) { | | Type `run` to add additional overlays to code property graphs |""".stripMargin - "\n" + Table( - columnNames = List("name", "overlays", "inputPath", "open"), - rows = projects.map(_.toTableRow).toList - ).render - } - - } - -} + "\n" + Table( + columnNames = List("name", "overlays", "inputPath", "open"), + rows = projects.map(_.toTableRow).toList + ).render +end Workspace diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala index 55dce6e7..36ea4f8a 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala @@ -9,48 +9,42 @@ import scala.util.{Failure, Success, Try} /** This component loads a workspace from disk and creates a corresponding `Workspace` object. 
*/ -abstract class WorkspaceLoader[ProjectType <: Project] { - - /** Initialize workspace from a directory - * @param path - * path to the directory - */ - def load(path: String): Workspace[ProjectType] = { - val dirFile = File(path) - val dirPath = dirFile.path.toAbsolutePath - - mkdirs(dirFile) - new Workspace(ListBuffer.from(loadProjectsFromFs(dirPath))) - } - - private def loadProjectsFromFs(cpgsPath: Path): LazyList[ProjectType] = { - cpgsPath.toFile.listFiles - .filter(_.isDirectory) - .to(LazyList) - .flatMap(f => loadProject(f.toPath)) - } - - def loadProject(path: Path): Option[ProjectType] = { - Try { - val projectFile = readProjectFile(path) - createProject(projectFile, path) - } match { - case Success(v) => Some(v) - case Failure(e) => - System.err.println(s"Error loading project at $path - skipping: ") - e.printStackTrace - None - } - } - - def createProject(projectFile: ProjectFile, path: Path): ProjectType - - private val PROJECTFILE_NAME = "project.json" - - private def readProjectFile(projectDirName: Path): ProjectFile = { - // TODO see `writeProjectFile` - val data = ujson.read(projectDirName.resolve(PROJECTFILE_NAME)) - ProjectFile(data("inputPath").str, data("name").str) - } - -} +abstract class WorkspaceLoader[ProjectType <: Project]: + + /** Initialize workspace from a directory + * @param path + * path to the directory + */ + def load(path: String): Workspace[ProjectType] = + val dirFile = File(path) + val dirPath = dirFile.path.toAbsolutePath + + mkdirs(dirFile) + new Workspace(ListBuffer.from(loadProjectsFromFs(dirPath))) + + private def loadProjectsFromFs(cpgsPath: Path): LazyList[ProjectType] = + cpgsPath.toFile.listFiles + .filter(_.isDirectory) + .to(LazyList) + .flatMap(f => loadProject(f.toPath)) + + def loadProject(path: Path): Option[ProjectType] = + Try { + val projectFile = readProjectFile(path) + createProject(projectFile, path) + } match + case Success(v) => Some(v) + case Failure(e) => + System.err.println(s"Error loading 
project at $path - skipping: ") + e.printStackTrace + None + + def createProject(projectFile: ProjectFile, path: Path): ProjectType + + private val PROJECTFILE_NAME = "project.json" + + private def readProjectFile(projectDirName: Path): ProjectFile = + // TODO see `writeProjectFile` + val data = ujson.read(projectDirName.resolve(PROJECTFILE_NAME)) + ProjectFile(data("inputPath").str, data("name").str) +end WorkspaceLoader diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala index a750fdc6..bd55731a 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala @@ -15,406 +15,371 @@ import java.nio.file.Path import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} -object DefaultLoader extends WorkspaceLoader[Project] { - override def createProject(projectFile: ProjectFile, path: Path): Project = { - Project(projectFile, path) - } -} +object DefaultLoader extends WorkspaceLoader[Project]: + override def createProject(projectFile: ProjectFile, path: Path): Project = + Project(projectFile, path) -/** WorkspaceManager: a component, which loads and maintains the list of projects made accessible via Ocular/Joern. +/** WorkspaceManager: a component, which loads and maintains the list of projects made accessible + * via Ocular/Joern. * * @param path * path to to workspace. 
*/ -class WorkspaceManager[ProjectType <: Project](path: String, loader: WorkspaceLoader[ProjectType] = DefaultLoader) - extends Reporting { - - def getPath: String = path - - import WorkspaceManager._ - - /** The workspace managed by this WorkspaceManager - */ - private var workspace: Workspace[ProjectType] = loader.load(path) - private val dirPath = File(path).path.toAbsolutePath - - private val LEGACY_BASE_CPG_FILENAME = "app.atom" - private val OVERLAY_DIR_NAME = "overlays" - - /** Create project for code stored at `inputPath` with the project name `name`. If `name` is empty, the project name - * is derived from `inputPath`. If a project for this `name` already exists, it is deleted from the workspace first. - * If no file or directory exists at `inputPath`, then no project is created. Returns the path to the project - * directory as an optional String, and None if there was an error. - */ - def createProject(inputPath: String, projectName: String): Option[Path] = { - Some(File(inputPath)).filter(_.exists).map { _ => - val pathToProject = projectNameToDir(projectName) +class WorkspaceManager[ProjectType <: Project]( + path: String, + loader: WorkspaceLoader[ProjectType] = DefaultLoader +) extends Reporting: + + def getPath: String = path + + import WorkspaceManager.* + + /** The workspace managed by this WorkspaceManager + */ + private var workspace: Workspace[ProjectType] = loader.load(path) + private val dirPath = File(path).path.toAbsolutePath + + private val LEGACY_BASE_CPG_FILENAME = "app.atom" + private val OVERLAY_DIR_NAME = "overlays" + + /** Create project for code stored at `inputPath` with the project name `name`. If `name` is + * empty, the project name is derived from `inputPath`. If a project for this `name` already + * exists, it is deleted from the workspace first. If no file or directory exists at + * `inputPath`, then no project is created. Returns the path to the project directory as an + * optional String, and None if there was an error. 
+ */ + def createProject(inputPath: String, projectName: String): Option[Path] = + Some(File(inputPath)).filter(_.exists).map { _ => + val pathToProject = projectNameToDir(projectName) + + if project(projectName).isDefined then + removeProject(projectName) + + createProjectDirectory(inputPath, name = projectName) + loader.loadProject(pathToProject).foreach(addProjectToProjectsList) + pathToProject + } - if (project(projectName).isDefined) { - removeProject(projectName) + /** Remove project named `name` from disk + * @param name + * name of the project + */ + def removeProject(name: String): Unit = + closeProject(name) + removeProjectFromList(name) + File(projectNameToDir(name)).delete() + + /** Create new project directory containing project file. + */ + private def createProjectDirectory(inputPath: String, name: String): Unit = + val dirPath = projectNameToDir(name) + mkdirs(dirPath) + val absoluteInputPath = File(inputPath).path.toAbsolutePath.toString + mkdirs(File(overlayDir(name))) + val projectFile = ProjectFile(absoluteInputPath, name) + writeProjectFile(projectFile, dirPath) + touch(dirPath.toString / "cpg.bin") + + /** Write the project's `project.json`, a JSON file that holds meta information. + */ + private def writeProjectFile(projectFile: ProjectFile, dirPath: Path): File = + // TODO proguard and json4s don't play along. We actually want to + // serialize the case class ProjectFile here, but it comes out + // empty. This code will be moved to `codepropertgraph` at + // which point serialization should work. + // val content = jsonWrite(projectFile) + implicit val formats: DefaultFormats.type = DefaultFormats + val PROJECTFILE_NAME = "project.json" + val content = + jsonWrite(Map("inputPath" -> projectFile.inputPath, "name" -> projectFile.name)) + File(dirPath.resolve(PROJECTFILE_NAME)).write(content) + + /** Delete the workspace from disk, then initialize it again. 
+ */ + def reset: Unit = + Try(cpg.close()) + deleteWorkspace() + workspace = loader.load(path) + + private def deleteWorkspace(): Unit = + if dirPath == null || dirPath.toString == "" then + throw new RuntimeException("dirPath is not set") + val dirFile = better.files.File(dirPath.toAbsolutePath.toString) + if !dirFile.exists then + throw new RuntimeException(s"Directory ${dirFile.toString} does not exist") + + dirFile.delete() + + /** Return the number of projects currently present in this workspace. + */ + def numberOfProjects: Int = workspace.projects.size + + def projects: List[Project] = workspace.projects.toList + + def project(name: String): Option[Project] = workspace.projects.find(_.name == name) + + override def toString: String = workspace.toString + + private def getProjectByPath(projectPath: Path): Option[Project] = + workspace.projects.find(_.path.toAbsolutePath == projectPath.toAbsolutePath) + + /** A sorted list of all loaded CPGs + */ + def loadedCpgs: List[Cpg] = workspace.projects.flatMap(_.cpg).toList + + /** Indicates whether a workspace record exists for @inputPath. + */ + def projectExists(inputPath: String): Boolean = projectDir(inputPath).toFile.exists + + /** Indicates whether a base CPG exists for @inputPath. 
+ */ + def cpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = + val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME + projectExists(inputPath) && + projectDir(inputPath).resolve(baseFileName).toFile.exists + + /** Overlay directory for CPG with given @inputPath + */ + def overlayDir(inputPath: String): String = + projectDir(inputPath).resolve(OVERLAY_DIR_NAME).toString + + def overlayDirByProjectName(name: String): String = + projectNameToDir(name).resolve(OVERLAY_DIR_NAME).toString + + /** Filename for the base CPG for @inputPath + */ + private def baseCpgFilename(inputPath: String, isLegacy: Boolean = false): String = + val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME + projectDir(inputPath).resolve(baseFileName).toString + + /** The safe directory name for a given input file that can be used to store its base CPG, along + * with all overlays. This method returns a directory name regardless of whether the directory + * exists or not. + */ + private def projectDir(inputPath: String): Path = + val filename = File(inputPath).path.getFileName + if filename == null then + throw new RuntimeException("invalid input path: " + inputPath) + dirPath + .resolve(URLEncoder.encode(filename.toString, DefaultCharset.toString)) + .toAbsolutePath + + private def projectNameToDir(name: String): Path = + dirPath + .resolve(URLEncoder.encode(name, DefaultCharset.toString)) + .toAbsolutePath + + /** Record for the given name. None if it is not in the workspace. 
+ */ + private def projectByName(name: String): Option[ProjectType] = + workspace.projects.find(r => r.name == name) + + /** Workspace record for the CPG, or none, if the CPG is not in the workspace + */ + def projectByCpg(baseCpg: Cpg): Option[ProjectType] = + workspace.projects.find(_.cpg.contains(baseCpg)) + + def projectExistsForCpg(baseCpg: Cpg): Boolean = projectByCpg(baseCpg).isDefined + + def getNextOverlayDirName(baseCpg: Cpg, overlayName: String): String = + val project = projectByCpg(baseCpg).get + val overlayDirectory = File(overlayDirByProjectName(project.name)) + + val overlayFile = overlayDirectory.path + .resolve(overlayName) + .toFile + + overlayFile.getAbsolutePath + + /** Obtain the cpg that was last loaded. Throws a runtime exception if no CPG has been loaded. + */ + def cpg: Cpg = + val project = workspace.projects.lastOption + project match + case Some(p) => + p.cpg match + case Some(value) => value + case None => + throw new RuntimeException( + s"No Atom loaded for exploration - try importing one using `importCode|importAtom`" + ) + case None => throw new RuntimeException("No projects loaded") + + /** Set active project to project with name `name`. If a project with this name does not exist, + * does nothing. + */ + def setActiveProject(name: String): Option[ProjectType] = + val project = projectByName(name) + if project.isEmpty then + System.err.println(s"Error: project with name $name does not exist") + None + else + removeProjectFromList(name).map { p => + addProjectToProjectsList(p) + p + } + + /** Retrieve the currently active project. If no project is active, None is returned. + */ + def getActiveProject: Option[Project] = + workspace.projects.lastOption + + /** Open project by name and return it. If a project with this name does not exist, None is + * returned. If the CPG of this project is loaded, it is unloaded first and then reloaded. + * Returns project or None on error. 
+ * + * @param name + * of the project to load + * @param loader + * function to perform CPG loading. This parameter only exists for testing purposes. + */ + def openProject( + name: String, + loader: String => Option[Cpg] = { x => + loadCpgRaw(x) } - - createProjectDirectory(inputPath, name = projectName) - loader.loadProject(pathToProject).foreach(addProjectToProjectsList) - pathToProject - } - } - - /** Remove project named `name` from disk - * @param name - * name of the project - */ - def removeProject(name: String): Unit = { - closeProject(name) - removeProjectFromList(name) - File(projectNameToDir(name)).delete() - } - - /** Create new project directory containing project file. - */ - private def createProjectDirectory(inputPath: String, name: String): Unit = { - val dirPath = projectNameToDir(name) - mkdirs(dirPath) - val absoluteInputPath = File(inputPath).path.toAbsolutePath.toString - mkdirs(File(overlayDir(name))) - val projectFile = ProjectFile(absoluteInputPath, name) - writeProjectFile(projectFile, dirPath) - touch(dirPath.toString / "cpg.bin") - } - - /** Write the project's `project.json`, a JSON file that holds meta information. - */ - private def writeProjectFile(projectFile: ProjectFile, dirPath: Path): File = { - // TODO proguard and json4s don't play along. We actually want to - // serialize the case class ProjectFile here, but it comes out - // empty. This code will be moved to `codepropertgraph` at - // which point serialization should work. - // val content = jsonWrite(projectFile) - implicit val formats: DefaultFormats.type = DefaultFormats - val PROJECTFILE_NAME = "project.json" - val content = jsonWrite(Map("inputPath" -> projectFile.inputPath, "name" -> projectFile.name)) - File(dirPath.resolve(PROJECTFILE_NAME)).write(content) - } - - /** Delete the workspace from disk, then initialize it again. 
- */ - def reset: Unit = { - Try(cpg.close()) - deleteWorkspace() - workspace = loader.load(path) - } - - private def deleteWorkspace(): Unit = { - if (dirPath == null || dirPath.toString == "") { - throw new RuntimeException("dirPath is not set") - } - val dirFile = better.files.File(dirPath.toAbsolutePath.toString) - if (!dirFile.exists) { - throw new RuntimeException(s"Directory ${dirFile.toString} does not exist") - } - - dirFile.delete() - } - - /** Return the number of projects currently present in this workspace. - */ - def numberOfProjects: Int = workspace.projects.size - - def projects: List[Project] = workspace.projects.toList - - def project(name: String): Option[Project] = workspace.projects.find(_.name == name) - - override def toString: String = workspace.toString - - private def getProjectByPath(projectPath: Path): Option[Project] = - workspace.projects.find(_.path.toAbsolutePath == projectPath.toAbsolutePath) - - /** A sorted list of all loaded CPGs - */ - def loadedCpgs: List[Cpg] = workspace.projects.flatMap(_.cpg).toList - - /** Indicates whether a workspace record exists for @inputPath. - */ - def projectExists(inputPath: String): Boolean = projectDir(inputPath).toFile.exists - - /** Indicates whether a base CPG exists for @inputPath. 
- */ - def cpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = { - val baseFileName = if (isLegacy) LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME - projectExists(inputPath) && - projectDir(inputPath).resolve(baseFileName).toFile.exists - } - - /** Overlay directory for CPG with given @inputPath - */ - def overlayDir(inputPath: String): String = { - projectDir(inputPath).resolve(OVERLAY_DIR_NAME).toString - } - - def overlayDirByProjectName(name: String): String = { - projectNameToDir(name).resolve(OVERLAY_DIR_NAME).toString - } - - /** Filename for the base CPG for @inputPath - */ - private def baseCpgFilename(inputPath: String, isLegacy: Boolean = false): String = { - val baseFileName = if (isLegacy) LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME - projectDir(inputPath).resolve(baseFileName).toString - } - - /** The safe directory name for a given input file that can be used to store its base CPG, along with all overlays. - * This method returns a directory name regardless of whether the directory exists or not. - */ - private def projectDir(inputPath: String): Path = { - val filename = File(inputPath).path.getFileName - if (filename == null) { - throw new RuntimeException("invalid input path: " + inputPath) - } - dirPath - .resolve(URLEncoder.encode(filename.toString, DefaultCharset.toString)) - .toAbsolutePath - } - - private def projectNameToDir(name: String): Path = { - dirPath - .resolve(URLEncoder.encode(name, DefaultCharset.toString)) - .toAbsolutePath - } - - /** Record for the given name. None if it is not in the workspace. 
- */ - private def projectByName(name: String): Option[ProjectType] = { - workspace.projects.find(r => r.name == name) - } - - /** Workspace record for the CPG, or none, if the CPG is not in the workspace - */ - def projectByCpg(baseCpg: Cpg): Option[ProjectType] = - workspace.projects.find(_.cpg.contains(baseCpg)) - - def projectExistsForCpg(baseCpg: Cpg): Boolean = projectByCpg(baseCpg).isDefined - - def getNextOverlayDirName(baseCpg: Cpg, overlayName: String): String = { - val project = projectByCpg(baseCpg).get - val overlayDirectory = File(overlayDirByProjectName(project.name)) - - val overlayFile = overlayDirectory.path - .resolve(overlayName) - .toFile - - overlayFile.getAbsolutePath - } - - /** Obtain the cpg that was last loaded. Throws a runtime exception if no CPG has been loaded. - */ - def cpg: Cpg = { - val project = workspace.projects.lastOption - project match { - case Some(p) => - p.cpg match { - case Some(value) => value - case None => - throw new RuntimeException( - s"No Atom loaded for exploration - try importing one using `importCode|importAtom`" + ): Option[Project] = + if !projectExists(name) then + report( + s"Project does not exist in workspace. 
Try `importCode/importAtom(inputPath)` to create it" ) + None + else if !File(baseCpgFilename(name)).exists then + report(s"CPG for project $name does not exist at ${baseCpgFilename(name)}, bailing out") + None + else if project(name).exists(_.cpg.isDefined) then + setActiveProject(name) + project(name) + else + val cpgFilename = baseCpgFilename(name) + val cpgFile = File(cpgFilename) + val workingCopyPath = projectDir(name).resolve(Project.workCpgFileName) + val workingCopyName = workingCopyPath.toAbsolutePath.toString + cp(cpgFile, workingCopyPath) + + val result = + val newCpg = loader(workingCopyName) + val projectPath = File(workingCopyName).parent.path + newCpg.flatMap { c => + unloadCpgIfExists(name) + setCpgForProject(c, projectPath) + projectByCpg(c) + } + result + + /** Free up resources occupied by this project but do not remove project from disk. + */ + def closeProject(name: String): Option[Project] = + projectByName(name).map(_.close) + + /** Set CPG for existing project. It is assumed that the CPG is loaded. 
+ */ + private def setCpgForProject(newCpg: Cpg, projectPath: Path): Unit = + val project = getProjectByPath(projectPath) + project match + case Some(p) => + p.cpg = Some(newCpg) + setActiveProject(p.name) + case None => + System.err.println( + s"Error setting CPG for non-existing/unloaded project at $projectPath" + ) + + private def loadCpgRaw(cpgFilename: String): Option[Cpg] = + Try { + val odbConfig = Config.withDefaults.withStorageLocation(cpgFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + val newCpg = CpgLoader.loadFromOverflowDb(config) + CpgLoader.createIndexes(newCpg) + newCpg + } match + case Success(v) => Some(v) + case Failure(ex) => + System.err.println("Error loading CPG") + System.err.println(ex) + None + + private def addProjectToProjectsList(project: ProjectType): ListBuffer[ProjectType] = + workspace.projects += project + + def unloadCpgByProjectName(name: String): Unit = + projectByName(name).foreach { record => + record.cpg.foreach(_.close) + record.cpg = None } - case None => throw new RuntimeException("No projects loaded") - } - } - - /** Set active project to project with name `name`. If a project with this name does not exist, does nothing. - */ - def setActiveProject(name: String): Option[ProjectType] = { - val project = projectByName(name) - if (project.isEmpty) { - System.err.println(s"Error: project with name $name does not exist") - None - } else { - removeProjectFromList(name).map { p => - addProjectToProjectsList(p) - p - } - } - } - - /** Retrieve the currently active project. If no project is active, None is returned. - */ - def getActiveProject: Option[Project] = { - workspace.projects.lastOption - } - - /** Open project by name and return it. If a project with this name does not exist, None is returned. If the CPG of - * this project is loaded, it is unloaded first and then reloaded. Returns project or None on error. 
- * - * @param name - * of the project to load - * @param loader - * function to perform CPG loading. This parameter only exists for testing purposes. - */ - def openProject( - name: String, - loader: String => Option[Cpg] = { x => - loadCpgRaw(x) - } - ): Option[Project] = { - if (!projectExists(name)) { - report(s"Project does not exist in workspace. Try `importCode/importAtom(inputPath)` to create it") - None - } else if (!File(baseCpgFilename(name)).exists) { - report(s"CPG for project $name does not exist at ${baseCpgFilename(name)}, bailing out") - None - } else if (project(name).exists(_.cpg.isDefined)) { - setActiveProject(name) - project(name) - } else { - val cpgFilename = baseCpgFilename(name) - val cpgFile = File(cpgFilename) - val workingCopyPath = projectDir(name).resolve(Project.workCpgFileName) - val workingCopyName = workingCopyPath.toAbsolutePath.toString - cp(cpgFile, workingCopyPath) - - val result = { - val newCpg = loader(workingCopyName) - val projectPath = File(workingCopyName).parent.path - newCpg.flatMap { c => - unloadCpgIfExists(name) - setCpgForProject(c, projectPath) - projectByCpg(c) - } - } - result - } - } - - /** Free up resources occupied by this project but do not remove project from disk. - */ - def closeProject(name: String): Option[Project] = - projectByName(name).map(_.close) - - /** Set CPG for existing project. It is assumed that the CPG is loaded. 
- */ - private def setCpgForProject(newCpg: Cpg, projectPath: Path): Unit = { - val project = getProjectByPath(projectPath) - project match { - case Some(p) => - p.cpg = Some(newCpg) - setActiveProject(p.name) - case None => - System.err.println(s"Error setting CPG for non-existing/unloaded project at $projectPath") - } - } - - private def loadCpgRaw(cpgFilename: String): Option[Cpg] = { - Try { - val odbConfig = Config.withDefaults.withStorageLocation(cpgFilename) - val config = - CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - val newCpg = CpgLoader.loadFromOverflowDb(config) - CpgLoader.createIndexes(newCpg) - newCpg - } match { - case Success(v) => Some(v) - case Failure(ex) => - System.err.println("Error loading CPG") - System.err.println(ex) - None - } - } - - private def addProjectToProjectsList(project: ProjectType): ListBuffer[ProjectType] = { - workspace.projects += project - } - - def unloadCpgByProjectName(name: String): Unit = { - projectByName(name).foreach { record => - record.cpg.foreach(_.close) - record.cpg = None - } - } - - def reloadCpgByName(name: String, loadCpg: String => Option[Cpg]): Option[Cpg] = { - projectByName(name).flatMap { record => - record.cpg = loadCpg(record.name) - record.cpg - } - } - - private def unloadCpgIfExists(name: String): Unit = { - projectByName(File(projectDir(name)).name) - .flatMap(_.cpg) - .foreach { c => - try { - c.close - } catch { - case _: IllegalStateException => // Store is already closed - } - } - } - - /** Remove currently active project from workspace and delete all associated workspace files from disk. - */ - def deleteCurrentProject(): Unit = { - val project = projectByCpg(cpg) - project match { - case Some(p) => deleteProject(p) - case None => - report(s"Project for active CPG does not exist") - } - } - - /** Remove project with name `name` from workspace and delete all associated workspace files from disk. 
- * @param name - * the name of the project that should be removed - */ - def deleteProject(name: String): Option[Unit] = { - val project = projectByName(name) - project match { - case Some(p) => - deleteProject(p) - Option[Unit](()) - case None => - report(s"Project with name $name does not exist") - None - } - } - - private def deleteProject(project: Project): Unit = { - removeProjectFromList(project.name) - if (project.path.toString != "") { - File(project.path).delete() - } - } - - private def removeProjectFromList(name: String): Option[ProjectType] = { - workspace.projects.zipWithIndex - .find { case (record, _) => - record.name == name - } - .map(_._2) - .map { index => - workspace.projects.remove(index) - } - } - - // Kept for backward compatibility - @deprecated("", "") - def recordExists(inputPath: String): Boolean = projectExists(inputPath) - @deprecated("", "") - def baseCpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = - cpgExists(inputPath, isLegacy) - -} - -object WorkspaceManager { - - private val BASE_CPG_FILENAME = "cpg.bin" - - def overlayFilesForDir(dirName: String): List[File] = { - File(dirName).list - .filter(f => f.isRegularFile && f.name != BASE_CPG_FILENAME) - .toList - .sortBy(_.name) - } + def reloadCpgByName(name: String, loadCpg: String => Option[Cpg]): Option[Cpg] = + projectByName(name).flatMap { record => + record.cpg = loadCpg(record.name) + record.cpg + } -} + private def unloadCpgIfExists(name: String): Unit = + projectByName(File(projectDir(name)).name) + .flatMap(_.cpg) + .foreach { c => + try + c.close + catch + case _: IllegalStateException => // Store is already closed + } + + /** Remove currently active project from workspace and delete all associated workspace files + * from disk. 
+ */ + def deleteCurrentProject(): Unit = + val project = projectByCpg(cpg) + project match + case Some(p) => deleteProject(p) + case None => + report(s"Project for active CPG does not exist") + + /** Remove project with name `name` from workspace and delete all associated workspace files + * from disk. + * @param name + * the name of the project that should be removed + */ + def deleteProject(name: String): Option[Unit] = + val project = projectByName(name) + project match + case Some(p) => + deleteProject(p) + Option[Unit](()) + case None => + report(s"Project with name $name does not exist") + None + + private def deleteProject(project: Project): Unit = + removeProjectFromList(project.name) + if project.path.toString != "" then + File(project.path).delete() + + private def removeProjectFromList(name: String): Option[ProjectType] = + workspace.projects.zipWithIndex + .find { case (record, _) => + record.name == name + } + .map(_._2) + .map { index => + workspace.projects.remove(index) + } + + // Kept for backward compatibility + @deprecated("", "") + def recordExists(inputPath: String): Boolean = projectExists(inputPath) + + @deprecated("", "") + def baseCpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = + cpgExists(inputPath, isLegacy) +end WorkspaceManager + +object WorkspaceManager: + + private val BASE_CPG_FILENAME = "cpg.bin" + + def overlayFilesForDir(dirName: String): List[File] = + File(dirName).list + .filter(f => f.isRegularFile && f.name != BASE_CPG_FILENAME) + .toList + .sortBy(_.name) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala index f5f0cb97..8265e558 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala @@ -5,158 +5,178 @@ import 
io.shiftleft.codepropertygraph.generated.Operators import scala.annotation.unused -object DefaultSemantics { +object DefaultSemantics: - /** @return - * a default set of common external procedure calls for all languages. - */ - def apply(): Semantics = { - val list = operatorFlows ++ cFlows ++ javaFlows - Semantics.fromList(list) - } + /** @return + * a default set of common external procedure calls for all languages. + */ + def apply(): Semantics = + val list = operatorFlows ++ cFlows ++ javaFlows + Semantics.fromList(list) - private def F = (x: String, y: List[(Int, Int)]) => FlowSemantic.from(x, y) + private def F = (x: String, y: List[(Int, Int)]) => FlowSemantic.from(x, y) - private def PTF(x: String, ys: List[(Int, Int)] = List.empty): FlowSemantic = - FlowSemantic(x).copy(mappings = FlowSemantic.from(x, ys).mappings :+ PassThroughMapping) + private def PTF(x: String, ys: List[(Int, Int)] = List.empty): FlowSemantic = + FlowSemantic(x).copy(mappings = FlowSemantic.from(x, ys).mappings :+ PassThroughMapping) - def operatorFlows: List[FlowSemantic] = List( - F(Operators.addition, List((1, -1), (2, -1))), - F(Operators.addressOf, List((1, -1))), - F(Operators.assignment, List((2, 1), (2, -1))), - F(Operators.assignmentAnd, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentArithmeticShiftRight, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentDivision, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentExponentiation, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentLogicalShiftRight, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentMinus, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentModulo, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentMultiplication, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentOr, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentPlus, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentShiftLeft, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentXor, List((2, 1), 
(1, 1), (2, -1))), - F(Operators.cast, List((1, -1), (2, -1))), - F(Operators.computedMemberAccess, List((1, -1))), - F(Operators.conditional, List((2, -1), (3, -1))), - F(Operators.elvis, List((1, -1), (2, -1))), - F(Operators.notNullAssert, List((1, -1))), - F(Operators.fieldAccess, List((1, -1))), - F(Operators.getElementPtr, List((1, -1))), + def operatorFlows: List[FlowSemantic] = List( + F(Operators.addition, List((1, -1), (2, -1))), + F(Operators.addressOf, List((1, -1))), + F(Operators.assignment, List((2, 1), (2, -1))), + F(Operators.assignmentAnd, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentArithmeticShiftRight, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentDivision, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentExponentiation, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentLogicalShiftRight, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentMinus, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentModulo, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentMultiplication, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentOr, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentPlus, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentShiftLeft, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentXor, List((2, 1), (1, 1), (2, -1))), + F(Operators.cast, List((1, -1), (2, -1))), + F(Operators.computedMemberAccess, List((1, -1))), + F(Operators.conditional, List((2, -1), (3, -1))), + F(Operators.elvis, List((1, -1), (2, -1))), + F(Operators.notNullAssert, List((1, -1))), + F(Operators.fieldAccess, List((1, -1))), + F(Operators.getElementPtr, List((1, -1))), - // TODO does this still exist? 
- F(".incBy", List((1, 1), (2, 1), (3, 1), (4, 1))), - F(Operators.indexAccess, List((1, -1))), - F(Operators.indirectComputedMemberAccess, List((1, -1))), - F(Operators.indirectFieldAccess, List((1, -1))), - F(Operators.indirectIndexAccess, List((1, -1), (2, 1))), - F(Operators.indirectMemberAccess, List((1, -1))), - F(Operators.indirection, List((1, -1))), - F(Operators.memberAccess, List((1, -1))), - F(Operators.pointerShift, List((1, -1))), - F(Operators.postDecrement, List((1, 1), (1, -1))), - F(Operators.postIncrement, List((1, 1), (1, -1))), - F(Operators.preDecrement, List((1, 1), (1, -1))), - F(Operators.preIncrement, List((1, 1), (1, -1))), - F(Operators.sizeOf, List.empty[(Int, Int)]), + // TODO does this still exist? + F(".incBy", List((1, 1), (2, 1), (3, 1), (4, 1))), + F(Operators.indexAccess, List((1, -1))), + F(Operators.indirectComputedMemberAccess, List((1, -1))), + F(Operators.indirectFieldAccess, List((1, -1))), + F(Operators.indirectIndexAccess, List((1, -1), (2, 1))), + F(Operators.indirectMemberAccess, List((1, -1))), + F(Operators.indirection, List((1, -1))), + F(Operators.memberAccess, List((1, -1))), + F(Operators.pointerShift, List((1, -1))), + F(Operators.postDecrement, List((1, 1), (1, -1))), + F(Operators.postIncrement, List((1, 1), (1, -1))), + F(Operators.preDecrement, List((1, 1), (1, -1))), + F(Operators.preIncrement, List((1, 1), (1, -1))), + F(Operators.sizeOf, List.empty[(Int, Int)]), - // some of those operators have duplicate mappings due to a typo - // - see https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1630 + // some of those operators have duplicate mappings due to a typo + // - see https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1630 - F(".assignmentExponentiation", List((2, 1), (1, 1))), - F(".assignmentModulo", List((2, 1), (1, 1))), - F(".assignmentShiftLeft", List((2, 1), (1, 1))), - F(".assignmentLogicalShiftRight", List((2, 1), (1, 1))), - F(".assignmentArithmeticShiftRight", List((2, 1), 
(1, 1))), - F(".assignmentAnd", List((2, 1), (1, 1))), - F(".assignmentOr", List((2, 1), (1, 1))), - F(".assignmentXor", List((2, 1), (1, 1))), + F(".assignmentExponentiation", List((2, 1), (1, 1))), + F(".assignmentModulo", List((2, 1), (1, 1))), + F(".assignmentShiftLeft", List((2, 1), (1, 1))), + F(".assignmentLogicalShiftRight", List((2, 1), (1, 1))), + F(".assignmentArithmeticShiftRight", List((2, 1), (1, 1))), + F(".assignmentAnd", List((2, 1), (1, 1))), + F(".assignmentOr", List((2, 1), (1, 1))), + F(".assignmentXor", List((2, 1), (1, 1))), - // Language specific operators - PTF(".tupleLiteral"), - PTF(".dictLiteral"), - PTF(".setLiteral"), - PTF(".listLiteral") - ) + // Language specific operators + PTF(".tupleLiteral"), + PTF(".dictLiteral"), + PTF(".setLiteral"), + PTF(".listLiteral") + ) - /** Semantic summaries for common external C/C++ calls. - * - * @see - * Standard - * C Library Functions - */ - def cFlows: List[FlowSemantic] = List( - F("abs", List((1, 1), (1, -1))), - F("abort", List.empty[(Int, Int)]), - F("asctime", List((1, 1), (1, -1))), - F("asctime_r", List((1, 1), (1, -1))), - F("atof", List((1, 1), (1, -1))), - F("atoi", List((1, 1), (1, -1))), - F("atol", List((1, 1), (1, -1))), - F("calloc", List((1, -1), (2, -1))), - F("ceil", List((1, 1), (1, 1))), - F("clock", List.empty[(Int, Int)]), - F("ctime", List((1, -1))), - F("ctime64", List((1, -1))), - F("ctime_r", List((1, -1))), - F("ctime64_r", List((1, -1))), - F("difftime", List((1, -1), (2, -1))), - F("difftime64", List((1, -1), (2, -1))), - PTF("div"), - F("exit", List((1, 1))), - F("exp", List((1, -1))), - F("fabs", List((1, -1))), - F("fclose", List((1, 1), (1, -1))), - F("fdopen", List((1, -1), (2, -1))), - F("feof", List((1, 1), (1, -1))), - F("ferror", List((1, 1), (1, -1))), - F("fflush", List((1, 1), (1, -1))), - F("fgetc", List((1, 1), (1, -1))), - F("fwrite", List((1, 1), (1, -1), (2, -1), (3, -1), (4, -1))), - F("free", List((1, 1))), - F("getc", List((1, 1))), - F("scanf", 
List((2, 2))), - F("strcmp", List((1, 1), (1, -1), (2, 2), (2, -1))), - F("strlen", List((1, 1), (1, -1))), - F("strncpy", List((1, 1), (2, 2), (3, 3), (1, -1), (2, -1))), - F("strncat", List((1, 1), (1, -1), (2, 2), (2, -1))) - ) + /** Semantic summaries for common external C/C++ calls. + * + * @see + * Standard + * C Library Functions + */ + def cFlows: List[FlowSemantic] = List( + F("abs", List((1, 1), (1, -1))), + F("abort", List.empty[(Int, Int)]), + F("asctime", List((1, 1), (1, -1))), + F("asctime_r", List((1, 1), (1, -1))), + F("atof", List((1, 1), (1, -1))), + F("atoi", List((1, 1), (1, -1))), + F("atol", List((1, 1), (1, -1))), + F("calloc", List((1, -1), (2, -1))), + F("ceil", List((1, 1), (1, 1))), + F("clock", List.empty[(Int, Int)]), + F("ctime", List((1, -1))), + F("ctime64", List((1, -1))), + F("ctime_r", List((1, -1))), + F("ctime64_r", List((1, -1))), + F("difftime", List((1, -1), (2, -1))), + F("difftime64", List((1, -1), (2, -1))), + PTF("div"), + F("exit", List((1, 1))), + F("exp", List((1, -1))), + F("fabs", List((1, -1))), + F("fclose", List((1, 1), (1, -1))), + F("fdopen", List((1, -1), (2, -1))), + F("feof", List((1, 1), (1, -1))), + F("ferror", List((1, 1), (1, -1))), + F("fflush", List((1, 1), (1, -1))), + F("fgetc", List((1, 1), (1, -1))), + F("fwrite", List((1, 1), (1, -1), (2, -1), (3, -1), (4, -1))), + F("free", List((1, 1))), + F("getc", List((1, 1))), + F("scanf", List((2, 2))), + F("strcmp", List((1, 1), (1, -1), (2, 2), (2, -1))), + F("strlen", List((1, 1), (1, -1))), + F("strncpy", List((1, 1), (2, 2), (3, 3), (1, -1), (2, -1))), + F("strncat", List((1, 1), (1, -1), (2, 2), (2, -1))) + ) - /** Semantic summaries for common external Java calls. 
- */ - def javaFlows: List[FlowSemantic] = List( - PTF("java.lang.String.split:java.lang.String[](java.lang.String)", List((0, 0))), - PTF("java.lang.String.split:java.lang.String[](java.lang.String,int)", List((0, 0))), - PTF("java.lang.String.compareTo:int(java.lang.String)", List((0, 0))), - F("java.io.PrintWriter.print:void(java.lang.String)", List((0, 0), (1, 1))), - F("java.io.PrintWriter.println:void(java.lang.String)", List((0, 0), (1, 1))), - F("java.io.PrintStream.println:void(java.lang.String)", List((0, 0), (1, 1))), - PTF("java.io.PrintStream.print:void(java.lang.String)", List((0, 0))), - F("android.text.TextUtils.isEmpty:boolean(java.lang.String)", List((0, -1), (1, -1))), - F("java.sql.PreparedStatement.prepareStatement:java.sql.PreparedStatement(java.lang.String)", List((1, -1))), - F("java.sql.PreparedStatement.prepareStatement:setDouble(int,double)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setFloat(int,float)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setInt(int,int)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setLong(int,long)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setShort(int,short)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setString(int,java.lang.String)", List((1, 1), (2, 2))), - F("org.apache.http.HttpRequest.:void(org.apache.http.RequestLine)", List((1, 1), (1, 0))), - F("org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String)", List((1, 1), (1, 0), (2, 0))), - F( - "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String,org.apache.http.ProtocolVersion)", - List((1, 1), (1, 0), (2, 2), (2, 0), (3, 3), (3, 0)) - ), - F("org.apache.http.HttpResponse.getStatusLine:org.apache.http.StatusLine()", List((0, -1))), - F("org.apache.http.HttpResponse.setStatusLine:void(org.apache.http.StatusLine)", List((1, 0), (1, 1), (0, -1))), - 
F("org.apache.http.HttpResponse.setReasonPhrase:void(java.lang.String)", List((1, 0), (1, 1), (0, -1))), - F("org.apache.http.HttpResponse.getEntity:org.apache.http.HttpEntity()", List((0, -1))), - F("org.apache.http.HttpResponse.setEntity:void(org.apache.http.HttpEntity)", List((1, 0), (1, 1), (1, 0))) - ) + /** Semantic summaries for common external Java calls. + */ + def javaFlows: List[FlowSemantic] = List( + PTF("java.lang.String.split:java.lang.String[](java.lang.String)", List((0, 0))), + PTF("java.lang.String.split:java.lang.String[](java.lang.String,int)", List((0, 0))), + PTF("java.lang.String.compareTo:int(java.lang.String)", List((0, 0))), + F("java.io.PrintWriter.print:void(java.lang.String)", List((0, 0), (1, 1))), + F("java.io.PrintWriter.println:void(java.lang.String)", List((0, 0), (1, 1))), + F("java.io.PrintStream.println:void(java.lang.String)", List((0, 0), (1, 1))), + PTF("java.io.PrintStream.print:void(java.lang.String)", List((0, 0))), + F("android.text.TextUtils.isEmpty:boolean(java.lang.String)", List((0, -1), (1, -1))), + F( + "java.sql.PreparedStatement.prepareStatement:java.sql.PreparedStatement(java.lang.String)", + List((1, -1)) + ), + F("java.sql.PreparedStatement.prepareStatement:setDouble(int,double)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setFloat(int,float)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setInt(int,int)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setLong(int,long)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setShort(int,short)", List((1, 1), (2, 2))), + F( + "java.sql.PreparedStatement.prepareStatement:setString(int,java.lang.String)", + List((1, 1), (2, 2)) + ), + F( + "org.apache.http.HttpRequest.:void(org.apache.http.RequestLine)", + List((1, 1), (1, 0)) + ), + F( + "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String)", + List((1, 1), (1, 0), (2, 0)) + ), + F( + 
"org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String,org.apache.http.ProtocolVersion)", + List((1, 1), (1, 0), (2, 2), (2, 0), (3, 3), (3, 0)) + ), + F("org.apache.http.HttpResponse.getStatusLine:org.apache.http.StatusLine()", List((0, -1))), + F( + "org.apache.http.HttpResponse.setStatusLine:void(org.apache.http.StatusLine)", + List((1, 0), (1, 1), (0, -1)) + ), + F( + "org.apache.http.HttpResponse.setReasonPhrase:void(java.lang.String)", + List((1, 0), (1, 1), (0, -1)) + ), + F("org.apache.http.HttpResponse.getEntity:org.apache.http.HttpEntity()", List((0, -1))), + F( + "org.apache.http.HttpResponse.setEntity:void(org.apache.http.HttpEntity)", + List((1, 0), (1, 1), (1, 0)) + ) + ) - /** @return - * procedure semantics for operators and common external Java calls only. - */ - @unused - def javaSemantics(): Semantics = Semantics.fromList(operatorFlows ++ javaFlows) - -} + /** @return + * procedure semantics for operators and common external Java calls only. + */ + @unused + def javaSemantics(): Semantics = Semantics.fromList(operatorFlows ++ javaFlows) +end DefaultSemantics diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala index 14608f34..2ae58538 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala @@ -13,97 +13,113 @@ import overflowdb.traversal.jIteratortoTraversal import scala.collection.mutable -class DdgGenerator { - - val edgeType = "DDG" - private val edgeCache = mutable.Map[StoredNode, List[Edge]]() - - def generate(methodNode: Method)(implicit semantics: Semantics = DefaultSemantics()): Graph = { - val entryNode = methodNode - val paramNodes = methodNode.parameter.l - val allOtherNodes = methodNode.cfgNode.l - val exitNode = 
methodNode.methodReturn - val allNodes: List[StoredNode] = List(entryNode, exitNode) ++ paramNodes ++ allOtherNodes - val visibleNodes = allNodes.filter(shouldBeDisplayed) - - val edges = visibleNodes.map { dstNode => - inEdgesToDisplay(dstNode) - } - - val allIdsReferencedByEdges = edges.flatten.flatMap { edge => - Set(edge.src.id, edge.dst.id) - } - - val ddgNodes = visibleNodes - .filter(node => allIdsReferencedByEdges.contains(node.id)) - .map(surroundingCall) - .filterNot(node => node.isInstanceOf[Call] && isGenericMemberAccessName(node.asInstanceOf[Call].name)) - - val ddgEdges = edges.flatten - .map { edge => - edge.copy(src = surroundingCall(edge.src), dst = surroundingCall(edge.dst)) - } - .filter(e => e.src != e.dst) - .filterNot(e => e.dst.isInstanceOf[Call] && isGenericMemberAccessName(e.dst.asInstanceOf[Call].name)) - .filterNot(e => e.src.isInstanceOf[Call] && isGenericMemberAccessName(e.src.asInstanceOf[Call].name)) - .distinct - - edgeCache.clear() - Graph(ddgNodes, ddgEdges) - } - - private def surroundingCall(node: StoredNode): StoredNode = { - node match { - case arg: Expression => arg.inCall.headOption.getOrElse(node) - case _ => node - } - } - - private def shouldBeDisplayed(v: Node): Boolean = !( - v.isInstanceOf[ControlStructure] || - v.isInstanceOf[JumpTarget] - ) - - private def inEdgesToDisplay(dstNode: StoredNode, visited: List[StoredNode] = List())(implicit - semantics: Semantics - ): List[Edge] = { - - if (edgeCache.contains(dstNode)) { - return edgeCache(dstNode) - } - - if (visited.contains(dstNode)) { - List() - } else { - val parents = expand(dstNode) - val (visible, invisible) = parents.partition(x => shouldBeDisplayed(x.src) && x.srcVisible) - val result = visible.toList ++ invisible.toList.flatMap { n => - val parentInEdgesToDisplay = inEdgesToDisplay(n.src, visited ++ List(dstNode)) - parentInEdgesToDisplay.map(y => Edge(y.src, dstNode, y.srcVisible, edgeType = edgeType, label = y.label)) - }.distinct - edgeCache.put(dstNode, 
result) - result - } - } - - private def expand(v: StoredNode)(implicit semantics: Semantics): Iterator[Edge] = { - - val allInEdges = v - .inE(EdgeTypes.REACHING_DEF) - .map(x => - Edge(x.outNode.asInstanceOf[StoredNode], v, srcVisible = true, x.property(Properties.VARIABLE), edgeType) - ) - - v match { - case cfgNode: CfgNode => - cfgNode - .ddgInPathElem(withInvisible = true) - .map(x => Edge(x.node.asInstanceOf[StoredNode], v, x.visible, x.outEdgeLabel, edgeType)) - .iterator ++ allInEdges.filter(_.src.isInstanceOf[Method]).iterator - case _ => - allInEdges.iterator - } - - } - -} +class DdgGenerator: + + val edgeType = "DDG" + private val edgeCache = mutable.Map[StoredNode, List[Edge]]() + + def generate(methodNode: Method)(implicit semantics: Semantics = DefaultSemantics()): Graph = + val entryNode = methodNode + val paramNodes = methodNode.parameter.l + val allOtherNodes = methodNode.cfgNode.l + val exitNode = methodNode.methodReturn + val allNodes: List[StoredNode] = List(entryNode, exitNode) ++ paramNodes ++ allOtherNodes + val visibleNodes = allNodes.filter(shouldBeDisplayed) + + val edges = visibleNodes.map { dstNode => + inEdgesToDisplay(dstNode) + } + + val allIdsReferencedByEdges = edges.flatten.flatMap { edge => + Set(edge.src.id, edge.dst.id) + } + + val ddgNodes = visibleNodes + .filter(node => allIdsReferencedByEdges.contains(node.id)) + .map(surroundingCall) + .filterNot(node => + node.isInstanceOf[Call] && isGenericMemberAccessName(node.asInstanceOf[Call].name) + ) + + val ddgEdges = edges.flatten + .map { edge => + edge.copy(src = surroundingCall(edge.src), dst = surroundingCall(edge.dst)) + } + .filter(e => e.src != e.dst) + .filterNot(e => + e.dst.isInstanceOf[Call] && isGenericMemberAccessName(e.dst.asInstanceOf[Call].name) + ) + .filterNot(e => + e.src.isInstanceOf[Call] && isGenericMemberAccessName(e.src.asInstanceOf[Call].name) + ) + .distinct + + edgeCache.clear() + Graph(ddgNodes, ddgEdges) + end generate + + private def 
surroundingCall(node: StoredNode): StoredNode = + node match + case arg: Expression => arg.inCall.headOption.getOrElse(node) + case _ => node + + private def shouldBeDisplayed(v: Node): Boolean = !( + v.isInstanceOf[ControlStructure] || + v.isInstanceOf[JumpTarget] + ) + + private def inEdgesToDisplay(dstNode: StoredNode, visited: List[StoredNode] = List())(implicit + semantics: Semantics + ): List[Edge] = + + if edgeCache.contains(dstNode) then + return edgeCache(dstNode) + + if visited.contains(dstNode) then + List() + else + val parents = expand(dstNode) + val (visible, invisible) = + parents.partition(x => shouldBeDisplayed(x.src) && x.srcVisible) + val result = visible.toList ++ invisible.toList.flatMap { n => + val parentInEdgesToDisplay = inEdgesToDisplay(n.src, visited ++ List(dstNode)) + parentInEdgesToDisplay.map(y => + Edge(y.src, dstNode, y.srcVisible, edgeType = edgeType, label = y.label) + ) + }.distinct + edgeCache.put(dstNode, result) + result + end inEdgesToDisplay + + private def expand(v: StoredNode)(implicit semantics: Semantics): Iterator[Edge] = + + val allInEdges = v + .inE(EdgeTypes.REACHING_DEF) + .map(x => + Edge( + x.outNode.asInstanceOf[StoredNode], + v, + srcVisible = true, + x.property(Properties.VARIABLE), + edgeType + ) + ) + + v match + case cfgNode: CfgNode => + cfgNode + .ddgInPathElem(withInvisible = true) + .map(x => + Edge( + x.node.asInstanceOf[StoredNode], + v, + x.visible, + x.outEdgeLabel, + edgeType + ) + ) + .iterator ++ allInEdges.filter(_.src.isInstanceOf[Method]).iterator + case _ => + allInEdges.iterator + end expand +end DdgGenerator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala index dcc15aa0..685e3dbd 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala +++ 
b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala @@ -3,19 +3,23 @@ package io.appthreat.dataflowengineoss.dotgenerator import io.appthreat.dataflowengineoss.DefaultSemantics import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes.Method -import io.shiftleft.semanticcpg.dotgenerator.{AstGenerator, CdgGenerator, CfgGenerator, DotSerializer} - -object DotCpg14Generator { +import io.shiftleft.semanticcpg.dotgenerator.{ + AstGenerator, + CdgGenerator, + CfgGenerator, + DotSerializer +} - def toDotCpg14(traversal: Iterator[Method])(implicit semantics: Semantics = DefaultSemantics()): Iterator[String] = - traversal.map(dotGraphForMethod) +object DotCpg14Generator: - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = { - val ast = new AstGenerator().generate(method) - val cfg = new CfgGenerator().generate(method) - val ddg = new DdgGenerator().generate(method) - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), ast ++ cfg ++ ddg ++ cdg, withEdgeTypes = true) - } + def toDotCpg14(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) -} + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ast = new AstGenerator().generate(method) + val cfg = new CfgGenerator().generate(method) + val ddg = new DdgGenerator().generate(method) + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), ast ++ cfg ++ ddg ++ cdg, withEdgeTypes = true) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala index 4a587b40..2c5ffdef 100644 --- 
a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala @@ -5,15 +5,14 @@ import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.semanticcpg.dotgenerator.DotSerializer -object DotDdgGenerator { +object DotDdgGenerator: - def toDotDdg(traversal: Iterator[Method])(implicit semantics: Semantics = DefaultSemantics()): Iterator[String] = - traversal.map(dotGraphForMethod) + def toDotDdg(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = { - val ddgGenerator = new DdgGenerator() - val ddg = ddgGenerator.generate(method) - DotSerializer.dotGraph(Option(method), ddg) - } - -} + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ddgGenerator = new DdgGenerator() + val ddg = ddgGenerator.generate(method) + DotSerializer.dotGraph(Option(method), ddg) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala index 34f1bab9..78818e10 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala @@ -5,15 +5,14 @@ import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.semanticcpg.dotgenerator.{CdgGenerator, DotSerializer} -object DotPdgGenerator { +object DotPdgGenerator: - def toDotPdg(traversal: Iterator[Method])(implicit semantics: 
Semantics = DefaultSemantics()): Iterator[String] = - traversal.map(dotGraphForMethod) + def toDotPdg(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = { - val ddg = new DdgGenerator().generate(method) - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), ddg.++(cdg), withEdgeTypes = true) - } - -} + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ddg = new DdgGenerator().generate(method) + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), ddg.++(cdg), withEdgeTypes = true) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala index 25084fbc..67263a84 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala @@ -1,10 +1,10 @@ package io.appthreat.dataflowengineoss.language import io.appthreat.dataflowengineoss.queryengine.{ - Engine, - EngineContext, - SourcesToStartingPoints, - StartingPointWithSource + Engine, + EngineContext, + SourcesToStartingPoints, + StartingPointWithSource } import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.appthreat.dataflowengineoss.queryengine.* @@ -18,92 +18,101 @@ import scala.collection.parallel.CollectionConverters.* /** Base class for nodes that can occur in data flows */ -class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal { +class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal: - def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = { - val cache = 
mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache)) - cache.clear() - result - } + def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = + traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache)) + cache.clear() + result - def ddgInPathElem(implicit semantics: Semantics = DefaultSemantics()): Iterator[PathElement] = { - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = traversal.flatMap(x => x.ddgInPathElem(Vector(PathElement(x)), withInvisible = false, cache)) - cache.clear() - result - } + def ddgInPathElem(implicit semantics: Semantics = DefaultSemantics()): Iterator[PathElement] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = traversal.flatMap(x => + x.ddgInPathElem(Vector(PathElement(x)), withInvisible = false, cache) + ) + cache.clear() + result - def reachableBy[NodeType](sourceTrav: IterableOnce[NodeType], sourceTravs: IterableOnce[NodeType]*)(implicit - context: EngineContext - ): Iterator[NodeType] = { - val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) - val reachedSources = - reachableByInternal(sources).map(_.path.head.node) - reachedSources.cast[NodeType] - } + def reachableBy[NodeType]( + sourceTrav: IterableOnce[NodeType], + sourceTravs: IterableOnce[NodeType]* + )(implicit context: EngineContext): Iterator[NodeType] = + val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) + val reachedSources = + reachableByInternal(sources).map(_.path.head.node) + reachedSources.cast[NodeType] - def df[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit - context: EngineContext - ): Iterator[Path] = reachableByFlows(sourceTrav, sourceTravs) + def df[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit 
+ context: EngineContext + ): Iterator[Path] = reachableByFlows(sourceTrav, sourceTravs) - def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit - context: EngineContext - ): Iterator[Path] = { - val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) - val startingPoints = sources.map(_.startingPoint) - val paths = reachableByInternal(sources).par - .map { result => - // We can get back results that start in nodes that are invisible - // according to the semantic, e.g., arguments that are only used - // but not defined. We filter these results here prior to returning - val first = result.path.headOption - if (first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node)) { - None - } else { - val visiblePathElements = result.path.filter(x => startingPoints.contains(x.node) || x.visible) - Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node)))) - } - } - .filter(_.isDefined) - .dedup - .flatten - .toVector - paths.iterator - } - - def reachableByDetailed[NodeType](sourceTrav: Iterator[NodeType], sourceTravs: Iterator[NodeType]*)(implicit - context: EngineContext - ): Vector[TableEntry] = { - val sources = SourcesToStartingPoints.sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) - reachableByInternal(sources) - } + def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit + context: EngineContext + ): Iterator[Path] = + val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) + val startingPoints = sources.map(_.startingPoint) + val paths = reachableByInternal(sources).par + .map { result => + // We can get back results that start in nodes that are invisible + // according to the semantic, e.g., arguments that are only used + // but not defined. 
We filter these results here prior to returning + val first = result.path.headOption + if first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node) + then + None + else + val visiblePathElements = + result.path.filter(x => startingPoints.contains(x.node) || x.visible) + Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node)))) + } + .filter(_.isDefined) + .dedup + .flatten + .toVector + paths.iterator + end reachableByFlows - private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = { - l.headOption.map(x => x :: l.sliding(2).collect { case Seq(a, b) if a != b => b }.toList).getOrElse(List()) - } + def reachableByDetailed[NodeType]( + sourceTrav: Iterator[NodeType], + sourceTravs: Iterator[NodeType]* + )(implicit context: EngineContext): Vector[TableEntry] = + val sources = + SourcesToStartingPoints.sourceTravsToStartingPoints(sourceTrav +: sourceTravs: _*) + reachableByInternal(sources) - private def reachableByInternal( - startingPointsWithSources: List[StartingPointWithSource] - )(implicit context: EngineContext): Vector[TableEntry] = { - val sinks = traversal.dedup.toList.sortBy(_.id) - val engine = new Engine(context) - val result = engine.backwards(sinks, startingPointsWithSources.map(_.startingPoint)) + private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = + l.headOption.map(x => + x :: l.sliding(2).collect { case Seq(a, b) if a != b => b }.toList + ).getOrElse(List()) - engine.shutdown() - val sources = startingPointsWithSources.map(_.source) - val startingPointToSource = startingPointsWithSources.map { x => - x.startingPoint.asInstanceOf[AstNode] -> x.source - }.toMap - val res = result.par.map { r => - val startingPoint = r.path.head.node - if (sources.contains(startingPoint) || !startingPointToSource(startingPoint).isInstanceOf[AstNode]) { - r - } else { - r.copy(path = PathElement(startingPointToSource(startingPoint).asInstanceOf[AstNode]) +: r.path) - } - } - res.toVector - } + private def 
reachableByInternal( + startingPointsWithSources: List[StartingPointWithSource] + )(implicit context: EngineContext): Vector[TableEntry] = + val sinks = traversal.dedup.toList.sortBy(_.id) + val engine = new Engine(context) + val result = engine.backwards(sinks, startingPointsWithSources.map(_.startingPoint)) -} + engine.shutdown() + val sources = startingPointsWithSources.map(_.source) + val startingPointToSource = startingPointsWithSources.map { x => + x.startingPoint.asInstanceOf[AstNode] -> x.source + }.toMap + val res = result.par.map { r => + val startingPoint = r.path.head.node + if sources.contains(startingPoint) || !startingPointToSource( + startingPoint + ).isInstanceOf[AstNode] + then + r + else + r.copy(path = + PathElement( + startingPointToSource(startingPoint).asInstanceOf[AstNode] + ) +: r.path + ) + } + res.toVector + end reachableByInternal +end ExtendedCfgNode diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala index 9efdd62b..dbb0fe6b 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala @@ -7,194 +7,225 @@ import org.apache.commons.lang.StringUtils import scala.collection.mutable.{ArrayBuffer, Set} -case class Path(elements: List[AstNode]) { - def resultPairs(): List[(String, Option[Integer])] = { - val pairs = elements.map { - case point: MethodParameterIn => - val method = point.method - val method_name = method.name - val code = s"$method_name(${method.parameter.l.sortBy(_.order).map(_.code).mkString(", ")})" - (code, point.lineNumber) - case point => (point.statement.repr, point.lineNumber) - } - pairs.headOption - .map(x => x :: pairs.sliding(2).collect { case Seq(a, b) if a._1 != b._1 && a._2 != b._2 => b }.toList) - .getOrElse(List()) - } -} +case class Path(elements: 
List[AstNode]): + def resultPairs(): List[(String, Option[Integer])] = + val pairs = elements.map { + case point: MethodParameterIn => + val method = point.method + val method_name = method.name + val code = + s"$method_name(${method.parameter.l.sortBy(_.order).map(_.code).mkString(", ")})" + (code, point.lineNumber) + case point => (point.statement.repr, point.lineNumber) + } + pairs.headOption + .map(x => + x :: pairs.sliding(2).collect { + case Seq(a, b) if a._1 != b._1 && a._2 != b._2 => b + }.toList + ) + .getOrElse(List()) -object Path { +object Path: - private val DefaultMaxTrackedWidth = 128 - // TODO replace with dynamic rendering based on the terminal's width, e.g. in scala-repl-pp - private lazy val maxTrackedWidth = - sys.env.get("CHEN_DATAFLOW_TRACKED_WIDTH").map(_.toInt).getOrElse(DefaultMaxTrackedWidth) + private val DefaultMaxTrackedWidth = 128 + // TODO replace with dynamic rendering based on the terminal's width, e.g. in scala-repl-pp + private lazy val maxTrackedWidth = + sys.env.get("CHEN_DATAFLOW_TRACKED_WIDTH").map(_.toInt).getOrElse(DefaultMaxTrackedWidth) - private def tagAsString(tag: Iterator[Tag]) = - if (tag.nonEmpty) tag.name.mkString(", ") else "" + private def tagAsString(tag: Iterator[Tag]) = + if tag.nonEmpty then tag.name.mkString(", ") else "" - implicit val show: Show[Path] = { path => - var caption = "" - if (path.elements.size > 2) { - val srcNode = path.elements.head - val srcTags = tagAsString(srcNode.tag) - val sinkNode = path.elements.last - var sinkCode = sinkNode.code - val sinkTags = tagAsString(sinkNode.tag) - sinkNode match { - case cfgNode: CfgNode => - val method = cfgNode.method - sinkCode = method.fullName - } - caption = if (srcNode.code != "this") s"Source: ${srcNode.code}" else "" - if (srcTags.nonEmpty) caption += s"\nSource Tags: ${srcTags}" - caption += s"\nSink: ${sinkCode}\n" - if (sinkTags.nonEmpty) caption += s"Sink Tags: ${sinkTags}\n" - } - var hasCheckLike: Boolean = false; - val tableRows = 
ArrayBuffer[Array[String]]() - val addedPaths = Set[String]() - path.elements.foreach { astNode => - val nodeType = astNode.getClass.getSimpleName - val lineNumber = astNode.lineNumber.getOrElse("").toString - val fileName = astNode.file.name.headOption.getOrElse("").replace("", "") - var fileLocation = s"${fileName}#${lineNumber}" - var tags: String = tagAsString(astNode.tag) - if (fileLocation == "#") fileLocation = "N/A" - astNode match { - case _: MethodReturn => - case methodParameterIn: MethodParameterIn => - val methodName = methodParameterIn.method.name - if (tags.isEmpty && methodParameterIn.method.tag.nonEmpty) { - tags = tagAsString(methodParameterIn.method.tag) - } - if (tags.isEmpty && methodParameterIn.tag.nonEmpty) { - tags = tagAsString(methodParameterIn.tag) - } - tableRows += Array[String]( - "methodParameterIn", - fileLocation, - methodName, - s"[bold red]${methodParameterIn.name}[/bold red]", - methodParameterIn.method.fullName + (if (methodParameterIn.method.isExternal) " :right_arrow_curving_up:" - else ""), - tags - ) - case ret: Return => - val methodName = ret.method.name - tableRows += Array[String]("return", fileLocation, methodName, ret.argumentName.getOrElse(""), ret.code, tags) - case identifier: Identifier => - val methodName = identifier.method.name - if (tags.isEmpty && identifier.inCall.nonEmpty && identifier.inCall.head.tag.nonEmpty) { - tags = tagAsString(identifier.inCall.head.tag) - } - if (!addedPaths.contains(s"${fileName}#${lineNumber}") && identifier.inCall.nonEmpty) { - tableRows += Array[String]( - "identifier", - fileLocation, - methodName, - identifier.name, - if (identifier.inCall.nonEmpty) - identifier.inCall.head.code - else identifier.code, - tags - ) - } - case member: Member => - val methodName = "" - tableRows += Array[String]("member", fileLocation, methodName, nodeType, member.name, member.code, tags) - case call: Call => - if (!call.code.startsWith(" - val method = cfgNode.method - if (tags.isEmpty && 
method.tag.nonEmpty) { - tags = tagAsString(method.tag) - } - val methodName = method.name - val statement = cfgNode match { - case _: MethodParameterIn => - if (tags.isEmpty && method.parameter.tag.nonEmpty) { - tags = tagAsString(method.parameter.tag) - } - val paramsPretty = method.parameter.toList.sortBy(_.index).map(_.code).mkString(", ") - s"$methodName($paramsPretty)" - case _ => - if (tags.isEmpty && cfgNode.statement.tag.nonEmpty) { - tags = tagAsString(cfgNode.statement.tag) - } - cfgNode.statement.repr - } - val tracked = StringUtils.normalizeSpace(StringUtils.abbreviate(statement, maxTrackedWidth)) - tableRows += Array[String]("cfgNode", fileLocation, methodName, "", tracked, tags) - } - if (isCheckLike(tags)) hasCheckLike = true - addedPaths += s"${fileName}#${lineNumber}" - } - try { - if (hasCheckLike) caption = s"This flow is safe with mitigation in place.\n$caption" - printFlows(tableRows, caption) - } catch { - case exc: Exception => - } - caption - } + implicit val show: Show[Path] = path => + var caption = "" + if path.elements.size > 2 then + val srcNode = path.elements.head + val srcTags = tagAsString(srcNode.tag) + val sinkNode = path.elements.last + var sinkCode = sinkNode.code + val sinkTags = tagAsString(sinkNode.tag) + sinkNode match + case cfgNode: CfgNode => + val method = cfgNode.method + sinkCode = method.fullName + caption = if srcNode.code != "this" then s"Source: ${srcNode.code}" else "" + if srcTags.nonEmpty then caption += s"\nSource Tags: ${srcTags}" + caption += s"\nSink: ${sinkCode}\n" + if sinkTags.nonEmpty then caption += s"Sink Tags: ${sinkTags}\n" + var hasCheckLike: Boolean = false; + val tableRows = ArrayBuffer[Array[String]]() + val addedPaths = Set[String]() + path.elements.foreach { astNode => + val nodeType = astNode.getClass.getSimpleName + val lineNumber = astNode.lineNumber.getOrElse("").toString + val fileName = astNode.file.name.headOption.getOrElse("").replace("", "") + var fileLocation = 
s"${fileName}#${lineNumber}" + var tags: String = tagAsString(astNode.tag) + if fileLocation == "#" then fileLocation = "N/A" + astNode match + case _: MethodReturn => + case methodParameterIn: MethodParameterIn => + val methodName = methodParameterIn.method.name + if tags.isEmpty && methodParameterIn.method.tag.nonEmpty then + tags = tagAsString(methodParameterIn.method.tag) + if tags.isEmpty && methodParameterIn.tag.nonEmpty then + tags = tagAsString(methodParameterIn.tag) + tableRows += Array[String]( + "methodParameterIn", + fileLocation, + methodName, + s"[bold red]${methodParameterIn.name}[/bold red]", + methodParameterIn.method.fullName + (if methodParameterIn.method.isExternal + then " :right_arrow_curving_up:" + else ""), + tags + ) + case ret: Return => + val methodName = ret.method.name + tableRows += Array[String]( + "return", + fileLocation, + methodName, + ret.argumentName.getOrElse(""), + ret.code, + tags + ) + case identifier: Identifier => + val methodName = identifier.method.name + if tags.isEmpty && identifier.inCall.nonEmpty && identifier.inCall.head.tag.nonEmpty + then + tags = tagAsString(identifier.inCall.head.tag) + if !addedPaths.contains( + s"${fileName}#${lineNumber}" + ) && identifier.inCall.nonEmpty + then + tableRows += Array[String]( + "identifier", + fileLocation, + methodName, + identifier.name, + if identifier.inCall.nonEmpty then + identifier.inCall.head.code + else identifier.code, + tags + ) + case member: Member => + val methodName = "" + tableRows += Array[String]( + "member", + fileLocation, + methodName, + nodeType, + member.name, + member.code, + tags + ) + case call: Call => + if !call.code.startsWith(" + val method = cfgNode.method + if tags.isEmpty && method.tag.nonEmpty then + tags = tagAsString(method.tag) + val methodName = method.name + val statement = cfgNode match + case _: MethodParameterIn => + if tags.isEmpty && method.parameter.tag.nonEmpty then + tags = tagAsString(method.parameter.tag) + val paramsPretty = + 
method.parameter.toList.sortBy(_.index).map(_.code).mkString(", ") + s"$methodName($paramsPretty)" + case _ => + if tags.isEmpty && cfgNode.statement.tag.nonEmpty then + tags = tagAsString(cfgNode.statement.tag) + cfgNode.statement.repr + val tracked = StringUtils.normalizeSpace(StringUtils.abbreviate( + statement, + maxTrackedWidth + )) + tableRows += Array[String]( + "cfgNode", + fileLocation, + methodName, + "", + tracked, + tags + ) + end match + if isCheckLike(tags) then hasCheckLike = true + addedPaths += s"${fileName}#${lineNumber}" + } + try + if hasCheckLike then caption = s"This flow is safe with mitigation in place.\n$caption" + printFlows(tableRows, caption) + catch + case exc: Exception => + caption - private def addEmphasis(str: String, isCheckLike: Boolean): String = if (isCheckLike) s"[green]$str[/green]" else str + private def addEmphasis(str: String, isCheckLike: Boolean): String = + if isCheckLike then s"[green]$str[/green]" else str - private def simplifyFilePath(str: String): String = str.replace("src/main/java/", "").replace("src/main/scala/", "") + private def simplifyFilePath(str: String): String = + str.replace("src/main/java/", "").replace("src/main/scala/", "") - private def isCheckLike(tagsStr: String): Boolean = - tagsStr.contains("valid") || tagsStr.contains("encrypt") || tagsStr.contains("encode") || tagsStr.contains( - "transform" - ) || tagsStr.contains("check") + private def isCheckLike(tagsStr: String): Boolean = + tagsStr.contains("valid") || tagsStr.contains("encrypt") || tagsStr.contains( + "encode" + ) || tagsStr.contains( + "transform" + ) || tagsStr.contains("check") - private def printFlows(tableRows: ArrayBuffer[Array[String]], caption: String): Unit = { - val richTableLib = py.module("rich.table") - val richConsole = py.module("chenpy.logger").console - val table = richTableLib.Table(highlight = true, expand = true, caption = caption) - Array("Location", "Method", "Parameter", "Tracked").foreach(c => table.add_column(c)) 
- tableRows.foreach { row => - { - val end_section = row.head == "call" - val trow: Array[String] = row.tail - if (!trow(3).startsWith(" table.add_column(c)) + tableRows.foreach { row => + val end_section = row.head == "call" + val trow: Array[String] = row.tail + if !trow(3).startsWith(" - srcName == node.argumentName.get - case FlowMapping(ParameterNode(srcIndex, _), _) => srcIndex == node.argumentIndex - case PassThroughMapping if node.argumentIndex != 0 => true - case _ => false - }) - } + /** Determine whether evaluation of the call this argument is a part of results in usage of this + * argument. + */ + def isUsed(implicit semantics: Semantics): Boolean = + val s = semanticsForCallByArg + s.isEmpty || s.exists(_.mappings.exists { + case FlowMapping(ParameterNode(_, Some(srcName)), _) if node.argumentName.isDefined => + srcName == node.argumentName.get + case FlowMapping(ParameterNode(srcIndex, _), _) => srcIndex == node.argumentIndex + case PassThroughMapping if node.argumentIndex != 0 => true + case _ => false + }) - /** Determine whether evaluation of the call this argument is a part of results in definition of this argument. - */ - def isDefined(implicit semantics: Semantics): Boolean = { - val s = semanticsForCallByArg.l - s.isEmpty || s.exists { semantic => - semantic.mappings.exists { - case FlowMapping(_, ParameterNode(_, Some(dstName))) if node.argumentName.isDefined => - dstName == node.argumentName.get - case FlowMapping(_, ParameterNode(dstIndex, _)) => dstIndex == node.argumentIndex - case PassThroughMapping if node.argumentIndex != 0 => true - case _ => false - } - } - } - - /** Determines if this node and the given target node are arguments to the same call. - * @param other - * the node to compare - * @return - * true if these nodes are arguments to the same call, false if otherwise. 
- */ - def isArgToSameCallWith(other: Expression): Boolean = - node.astParent.start.collectAll[Call].headOption.equals(other.astParent.start.collectAll[Call].headOption) - - /** Determines if this node has a flow to the given target node in the defined semantics. - * @param tgt - * the target node to check. - * @param semantics - * the pre-defined flow semantics. - * @return - * true if there is flow defined between the two nodes, false if otherwise. - */ - def hasDefinedFlowTo(tgt: Expression)(implicit semantics: Semantics): Boolean = { - if (tgt.evalType(".*(?i)(int|float|double|boolean).*").nonEmpty) { - false - } else { - val s = semanticsForCallByArg.l - s.isEmpty || s.exists { semantic => - semantic.mappings.exists { - case FlowMapping(ParameterNode(_, Some(srcName)), ParameterNode(_, Some(dstName))) - if node.argumentName.isDefined && tgt.argumentName.isDefined => - srcName == node.argumentName.get && dstName == tgt.argumentName.get - case FlowMapping(ParameterNode(_, Some(srcName)), ParameterNode(dstIndex, _)) - if node.argumentName.isDefined => - srcName == node.argumentName.get && dstIndex == tgt.argumentIndex - case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(_, Some(dstName))) if tgt.argumentName.isDefined => - srcIndex == node.argumentIndex && dstName == tgt.argumentName.get - case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(dstIndex, _)) => - srcIndex == node.argumentIndex && dstIndex == tgt.argumentIndex - case PassThroughMapping if tgt.argumentIndex == node.argumentIndex || tgt.argumentIndex == -1 => true - case _ => false + /** Determine whether evaluation of the call this argument is a part of results in definition of + * this argument. 
+ */ + def isDefined(implicit semantics: Semantics): Boolean = + val s = semanticsForCallByArg.l + s.isEmpty || s.exists { semantic => + semantic.mappings.exists { + case FlowMapping(_, ParameterNode(_, Some(dstName))) + if node.argumentName.isDefined => + dstName == node.argumentName.get + case FlowMapping(_, ParameterNode(dstIndex, _)) => dstIndex == node.argumentIndex + case PassThroughMapping if node.argumentIndex != 0 => true + case _ => false + } } - } - } - } - /** Retrieve flow semantic for the call this argument is a part of. - */ - def semanticsForCallByArg(implicit semantics: Semantics): Iterator[FlowSemantic] = { - argToMethods(node).flatMap { method => - semantics.forMethod(method.fullName) - } - } + /** Determines if this node and the given target node are arguments to the same call. + * @param other + * the node to compare + * @return + * true if these nodes are arguments to the same call, false if otherwise. + */ + def isArgToSameCallWith(other: Expression): Boolean = + node.astParent.start.collectAll[Call].headOption.equals( + other.astParent.start.collectAll[Call].headOption + ) - private def argToMethods(arg: Expression): Iterator[Method] = { - arg.inCall.flatMap { call => - if (call.nonEmpty) NoResolve.getCalledMethods(call) else mutable.ArrayBuffer.empty[Method] - } - } + /** Determines if this node has a flow to the given target node in the defined semantics. + * @param tgt + * the target node to check. + * @param semantics + * the pre-defined flow semantics. + * @return + * true if there is flow defined between the two nodes, false if otherwise. 
+ */ + def hasDefinedFlowTo(tgt: Expression)(implicit semantics: Semantics): Boolean = + if tgt.evalType(".*(?i)(int|float|double|boolean).*").nonEmpty then + false + else + val s = semanticsForCallByArg.l + s.isEmpty || s.exists { semantic => + semantic.mappings.exists { + case FlowMapping( + ParameterNode(_, Some(srcName)), + ParameterNode(_, Some(dstName)) + ) + if node.argumentName.isDefined && tgt.argumentName.isDefined => + srcName == node.argumentName.get && dstName == tgt.argumentName.get + case FlowMapping(ParameterNode(_, Some(srcName)), ParameterNode(dstIndex, _)) + if node.argumentName.isDefined => + srcName == node.argumentName.get && dstIndex == tgt.argumentIndex + case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(_, Some(dstName))) + if tgt.argumentName.isDefined => + srcIndex == node.argumentIndex && dstName == tgt.argumentName.get + case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(dstIndex, _)) => + srcIndex == node.argumentIndex && dstIndex == tgt.argumentIndex + case PassThroughMapping + if tgt.argumentIndex == node.argumentIndex || tgt.argumentIndex == -1 => + true + case _ => false + } + } -} + /** Retrieve flow semantic for the call this argument is a part of. 
+ */ + def semanticsForCallByArg(implicit semantics: Semantics): Iterator[FlowSemantic] = + argToMethods(node).flatMap { method => + semantics.forMethod(method.fullName) + } + + private def argToMethods(arg: Expression): Iterator[Method] = + arg.inCall.flatMap { call => + if call.nonEmpty then NoResolve.getCalledMethods(call) + else mutable.ArrayBuffer.empty[Method] + } +end ExpressionMethods diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala index 3a5475e7..e4d87428 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala @@ -11,81 +11,86 @@ import io.shiftleft.semanticcpg.language.* import scala.collection.mutable import scala.jdk.CollectionConverters.* -class ExtendedCfgNodeMethods[NodeType <: CfgNode](val node: NodeType) extends AnyVal { +class ExtendedCfgNodeMethods[NodeType <: CfgNode](val node: NodeType) extends AnyVal: - /** Convert to nearest AST node - */ - def astNode: AstNode = node + /** Convert to nearest AST node + */ + def astNode: AstNode = node - def reachableBy[NodeType](sourceTrav: Iterator[NodeType], sourceTravs: IterableOnce[NodeType]*)(implicit - context: EngineContext - ): Iterator[NodeType] = - node.start.reachableBy(sourceTrav, sourceTravs: _*) + def reachableBy[NodeType]( + sourceTrav: Iterator[NodeType], + sourceTravs: IterableOnce[NodeType]* + )(implicit context: EngineContext): Iterator[NodeType] = + node.start.reachableBy(sourceTrav, sourceTravs*) - def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = { - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = ddgIn(Vector(PathElement(node)), withInvisible = 
false, cache) - cache.clear() - result - } + def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = ddgIn(Vector(PathElement(node)), withInvisible = false, cache) + cache.clear() + result - def ddgInPathElem( - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] = mutable.HashMap[CfgNode, Vector[PathElement]]() - )(implicit semantics: Semantics): Iterator[PathElement] = - ddgInPathElem(Vector(PathElement(node)), withInvisible, cache) + def ddgInPathElem( + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] = + mutable.HashMap[CfgNode, Vector[PathElement]]() + )(implicit semantics: Semantics): Iterator[PathElement] = + ddgInPathElem(Vector(PathElement(node)), withInvisible, cache) - def ddgInPathElem(implicit semantics: Semantics): Iterator[PathElement] = { - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = ddgInPathElem(Vector(PathElement(node)), withInvisible = false, cache) - cache.clear() - result - } + def ddgInPathElem(implicit semantics: Semantics): Iterator[PathElement] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = ddgInPathElem(Vector(PathElement(node)), withInvisible = false, cache) + cache.clear() + result - /** Traverse back in the data dependence graph by one step, taking into account semantics - * @param path - * optional list of path elements that have been expanded already - */ - def ddgIn(path: Vector[PathElement], withInvisible: Boolean, cache: mutable.HashMap[CfgNode, Vector[PathElement]])( - implicit semantics: Semantics - ): Iterator[CfgNode] = { - ddgInPathElem(path, withInvisible, cache).map(_.node.asInstanceOf[CfgNode]) - } + /** Traverse back in the data dependence graph by one step, taking into account semantics + * @param path + * optional list of path elements that have been expanded already + */ + def ddgIn( + path: 
Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )( + implicit semantics: Semantics + ): Iterator[CfgNode] = + ddgInPathElem(path, withInvisible, cache).map(_.node.asInstanceOf[CfgNode]) - /** Traverse back in the data dependence graph by one step and generate corresponding PathElement, taking into account - * semantics - * @param path - * optional list of path elements that have been expanded already - */ - def ddgInPathElem( - path: Vector[PathElement], - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] - )(implicit semantics: Semantics): Iterator[PathElement] = { - val result = ddgInPathElemInternal(path, withInvisible, cache).iterator - result - } + /** Traverse back in the data dependence graph by one step and generate corresponding + * PathElement, taking into account semantics + * @param path + * optional list of path elements that have been expanded already + */ + def ddgInPathElem( + path: Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )(implicit semantics: Semantics): Iterator[PathElement] = + val result = ddgInPathElemInternal(path, withInvisible, cache).iterator + result - private def ddgInPathElemInternal( - path: Vector[PathElement], - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] - )(implicit semantics: Semantics): Vector[PathElement] = { + private def ddgInPathElemInternal( + path: Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )(implicit semantics: Semantics): Vector[PathElement] = - if (cache.contains(node)) { - return cache(node) - } + if cache.contains(node) then + return cache(node) - val elems = Engine.expandIn(node, path) - val result = if (withInvisible) { - elems - } else { - (elems.filter(_.visible) ++ elems - .filterNot(_.visible) - .flatMap(x => x.node.asInstanceOf[CfgNode].ddgInPathElem(x +: path, 
withInvisible = false, cache))).distinct - } - cache.put(node, result) - result - } - -} + val elems = Engine.expandIn(node, path) + val result = if withInvisible then + elems + else + (elems.filter(_.visible) ++ elems + .filterNot(_.visible) + .flatMap(x => + x.node.asInstanceOf[CfgNode].ddgInPathElem( + x +: path, + withInvisible = false, + cache + ) + )).distinct + cache.put(node, result) + result + end ddgInPathElemInternal +end ExtendedCfgNodeMethods diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala index e0cfea1b..399ac644 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala @@ -1,24 +1,29 @@ package io.appthreat.dataflowengineoss import io.appthreat.dataflowengineoss.language.dotextension.DdgNodeDot -import io.appthreat.dataflowengineoss.language.nodemethods.{ExpressionMethods, ExtendedCfgNodeMethods} +import io.appthreat.dataflowengineoss.language.nodemethods.{ + ExpressionMethods, + ExtendedCfgNodeMethods +} import io.shiftleft.codepropertygraph.generated.nodes.* -package object language { - - implicit def cfgNodeToMethodsQp[NodeType <: CfgNode](node: NodeType): ExtendedCfgNodeMethods[NodeType] = - new ExtendedCfgNodeMethods(node) +package object language: - implicit def expressionMethods[NodeType <: Expression](node: NodeType): ExpressionMethods[NodeType] = - new ExpressionMethods(node) + implicit def cfgNodeToMethodsQp[NodeType <: CfgNode](node: NodeType) + : ExtendedCfgNodeMethods[NodeType] = + new ExtendedCfgNodeMethods(node) - implicit def toExtendedCfgNode[NodeType <: CfgNode](traversal: IterableOnce[NodeType]): ExtendedCfgNode = - new ExtendedCfgNode(traversal.iterator) + implicit def expressionMethods[NodeType <: Expression](node: NodeType) + : ExpressionMethods[NodeType] = 
+ new ExpressionMethods(node) - implicit def toDdgNodeDot(traversal: IterableOnce[Method]): DdgNodeDot = - new DdgNodeDot(traversal.iterator) + implicit def toExtendedCfgNode[NodeType <: CfgNode](traversal: IterableOnce[NodeType]) + : ExtendedCfgNode = + new ExtendedCfgNode(traversal.iterator) - implicit def toDdgNodeDotSingle(method: Method): DdgNodeDot = - new DdgNodeDot(Iterator.single(method)) + implicit def toDdgNodeDot(traversal: IterableOnce[Method]): DdgNodeDot = + new DdgNodeDot(traversal.iterator) -} + implicit def toDdgNodeDotSingle(method: Method): DdgNodeDot = + new DdgNodeDot(Iterator.single(method)) +end language diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala index 9d9ad959..a167d36b 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala @@ -9,25 +9,23 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, Layer case class Cpg14DumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpCpg14 { +object DumpCpg14: - val overlayName = "dumpCpg14" + val overlayName = "dumpCpg14" - val description = "Dump Code Property Graph (2014) to out/" + val description = "Dump Code Property Graph (2014) to out/" - def defaultOpts: Cpg14DumpOptions = Cpg14DumpOptions("out") -} + def defaultOpts: Cpg14DumpOptions = Cpg14DumpOptions("out") -class DumpCpg14(options: Cpg14DumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator { - override val overlayName: String = DumpDdg.overlayName - override val description: String = DumpDdg.description - override val storeOverlayName: Boolean = false +class DumpCpg14(options: Cpg14DumpOptions)(implicit semantics: Semantics = DefaultSemantics()) + 
extends LayerCreator: + override val overlayName: String = DumpDdg.overlayName + override val description: String = DumpDdg.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCpg14.head - (File(options.outDir) / s"$i-cpg.dot").write(str) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCpg14.head + (File(options.outDir) / s"$i-cpg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala index 70c38ce5..2da5800a 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala @@ -9,25 +9,23 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, Layer case class DdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpDdg { +object DumpDdg: - val overlayName = "dumpDdg" + val overlayName = "dumpDdg" - val description = "Dump data dependence graphs to out/" + val description = "Dump data dependence graphs to out/" - def defaultOpts: DdgDumpOptions = DdgDumpOptions("out") -} + def defaultOpts: DdgDumpOptions = DdgDumpOptions("out") -class DumpDdg(options: DdgDumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator { - override val overlayName: String = DumpDdg.overlayName - override val description: String = DumpDdg.description - override val storeOverlayName: Boolean = false +class DumpDdg(options: DdgDumpOptions)(implicit semantics: Semantics = 
DefaultSemantics()) + extends LayerCreator: + override val overlayName: String = DumpDdg.overlayName + override val description: String = DumpDdg.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotDdg.head - (File(options.outDir) / s"$i-ddg.dot").write(str) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotDdg.head + (File(options.outDir) / s"$i-ddg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala index f115b568..7b8f3aed 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala @@ -9,25 +9,23 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, Layer case class PdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpPdg { +object DumpPdg: - val overlayName = "dumpPdg" + val overlayName = "dumpPdg" - val description = "Dump program dependence graph to out/" + val description = "Dump program dependence graph to out/" - def defaultOpts: PdgDumpOptions = PdgDumpOptions("out") -} + def defaultOpts: PdgDumpOptions = PdgDumpOptions("out") -class DumpPdg(options: PdgDumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator { - override val overlayName: String = DumpPdg.overlayName - override val description: String = DumpPdg.description - override val storeOverlayName: Boolean = false +class DumpPdg(options: PdgDumpOptions)(implicit 
semantics: Semantics = DefaultSemantics()) + extends LayerCreator: + override val overlayName: String = DumpPdg.overlayName + override val description: String = DumpPdg.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotPdg.head - (File(options.outDir) / s"$i-pdg.dot").write(str) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotPdg.head + (File(options.outDir) / s"$i-pdg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala index ee9fd737..5a5db4be 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala @@ -5,12 +5,11 @@ import io.appthreat.dataflowengineoss.passes.reachingdef.ReachingDefPass import io.appthreat.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -object OssDataFlow { - val overlayName: String = "dataflowOss" - val description: String = "Layer to support the OSS lightweight data flow tracker" +object OssDataFlow: + val overlayName: String = "dataflowOss" + val description: String = "Layer to support the OSS lightweight data flow tracker" - def defaultOpts = new OssDataFlowOptions() -} + def defaultOpts = new OssDataFlowOptions() class OssDataFlowOptions( var maxNumberOfDefinitions: Int = 4000, @@ -19,16 +18,14 @@ class OssDataFlowOptions( class OssDataFlow(opts: 
OssDataFlowOptions)(implicit s: Semantics = Semantics.fromList(DefaultSemantics().elements ++ opts.extraFlows) -) extends LayerCreator { +) extends LayerCreator: - override val overlayName: String = OssDataFlow.overlayName - override val description: String = OssDataFlow.description + override val overlayName: String = OssDataFlow.overlayName + override val description: String = OssDataFlow.description - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - val enhancementExecList = Iterator(new ReachingDefPass(cpg, opts.maxNumberOfDefinitions)) - enhancementExecList.zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + val enhancementExecList = Iterator(new ReachingDefPass(cpg, opts.maxNumberOfDefinitions)) + enhancementExecList.zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala index 67c42184..aeb69881 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala @@ -3,19 +3,18 @@ package io.appthreat import io.shiftleft.codepropertygraph.generated.nodes.{Declaration, Expression, Identifier, Literal} import io.shiftleft.semanticcpg.language.* -package object dataflowengineoss { +package object dataflowengineoss: - def globalFromLiteral(lit: Literal): Iterator[Expression] = lit.start - .where(_.inAssignment.method.nameExact("", ":package")) - .inAssignment - .argument(1) + def globalFromLiteral(lit: Literal): Iterator[Expression] = lit.start + .where(_.inAssignment.method.nameExact("", ":package")) + .inAssignment + .argument(1) - 
def identifierToFirstUsages(node: Identifier): List[Identifier] = node.refsTo.flatMap(identifiersFromCapturedScopes).l + def identifierToFirstUsages(node: Identifier): List[Identifier] = + node.refsTo.flatMap(identifiersFromCapturedScopes).l - def identifiersFromCapturedScopes(i: Declaration): List[Identifier] = - i.capturedByMethodRef.referencedMethod.ast.isIdentifier - .nameExact(i.name) - .sortBy(x => (x.lineNumber, x.columnNumber)) - .l - -} + def identifiersFromCapturedScopes(i: Declaration): List[Identifier] = + i.capturedByMethodRef.referencedMethod.ast.isIdentifier + .nameExact(i.name) + .sortBy(x => (x.lineNumber, x.columnNumber)) + .l diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala index 7c0d3d86..207bf662 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala @@ -1,7 +1,8 @@ package io.appthreat.dataflowengineoss.passes.reachingdef -/** A general data flow problem, formulated as in the Dragon Book, Second Edition on page 626, with mild modifications. - * In particular, instead of allowing only for the specification of a boundary, we allow initialization of IN and OUT. +/** A general data flow problem, formulated as in the Dragon Book, Second Edition on page 626, with + * mild modifications. In particular, instead of allowing only for the specification of a boundary, + * we allow initialization of IN and OUT. 
*/ class DataFlowProblem[Node, V]( val flowGraph: FlowGraph[Node], @@ -12,38 +13,36 @@ class DataFlowProblem[Node, V]( val empty: V ) -/** In essence, the flow graph is the control flow graph, however, we can compensate for small deviations from our - * generic control flow graph to one that is better suited for solving data flow problems. In particular, method - * parameters are not part of our normal control flow graph. By defining successors and predecessors, we provide a - * wrapper that takes care of these minor discrepancies. +/** In essence, the flow graph is the control flow graph, however, we can compensate for small + * deviations from our generic control flow graph to one that is better suited for solving data + * flow problems. In particular, method parameters are not part of our normal control flow graph. + * By defining successors and predecessors, we provide a wrapper that takes care of these minor + * discrepancies. */ -trait FlowGraph[Node] { - val allNodesReversePostOrder: List[Node] - val allNodesPostOrder: List[Node] - def succ(node: Node): IterableOnce[Node] - def pred(node: Node): IterableOnce[Node] -} - -/** This is actually a function family consisting of one transfer function for each node of the flow graph. Each - * function maps from the analysis domain to the analysis domain, e.g., for reaching definitions, sets of definitions - * are mapped to sets of definitions. +trait FlowGraph[Node]: + val allNodesReversePostOrder: List[Node] + val allNodesPostOrder: List[Node] + def succ(node: Node): IterableOnce[Node] + def pred(node: Node): IterableOnce[Node] + +/** This is actually a function family consisting of one transfer function for each node of the flow + * graph. Each function maps from the analysis domain to the analysis domain, e.g., for reaching + * definitions, sets of definitions are mapped to sets of definitions. 
*/ -trait TransferFunction[Node, V] { - def apply(n: Node, x: V): V -} +trait TransferFunction[Node, V]: + def apply(n: Node, x: V): V -/** As a practical optimization, OUT[N] is often initialized to GEN[N]. Moreover, we need a way of specifying boundary - * conditions such as OUT[ENTRY] = {}. We achieve both by allowing the data flow problem to specify initializers for IN - * and OUT. +/** As a practical optimization, OUT[N] is often initialized to GEN[N]. Moreover, we need a way of + * specifying boundary conditions such as OUT[ENTRY] = {}. We achieve both by allowing the data + * flow problem to specify initializers for IN and OUT. */ -trait InOutInit[Node, V] { - - def initIn: Map[Node, V] +trait InOutInit[Node, V]: - def initOut: Map[Node, V] + def initIn: Map[Node, V] -} + def initOut: Map[Node, V] -/** The solution consists of `in` and `out` for each node of the flow graph. We also attach the problem. +/** The solution consists of `in` and `out` for each node of the flow graph. We also attach the + * problem. */ case class Solution[Node, V](in: Map[Node, V], out: Map[Node, V], problem: DataFlowProblem[Node, V]) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala index 5d31b848..4a91ab5c 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala @@ -2,74 +2,75 @@ package io.appthreat.dataflowengineoss.passes.reachingdef import scala.collection.mutable -class DataFlowSolver { +class DataFlowSolver: - /** Calculate fix point solution via a standard work list algorithm (Forwards). The result is given by two maps: `in` - * and `out`. 
These maps associate all CFG nodes with the set of definitions at node entry and node exit - * respectively. - */ - def calculateMopSolutionForwards[Node, T <: Iterable[_]](problem: DataFlowProblem[Node, T]): Solution[Node, T] = { - var out: Map[Node, T] = problem.inOutInit.initOut - var in = problem.inOutInit.initIn - val workList = mutable.ListBuffer[Node]() - workList ++= problem.flowGraph.allNodesReversePostOrder + /** Calculate fix point solution via a standard work list algorithm (Forwards). The result is + * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of + * definitions at node entry and node exit respectively. + */ + def calculateMopSolutionForwards[Node, T <: Iterable[_]](problem: DataFlowProblem[Node, T]) + : Solution[Node, T] = + var out: Map[Node, T] = problem.inOutInit.initOut + var in = problem.inOutInit.initIn + val workList = mutable.ListBuffer[Node]() + workList ++= problem.flowGraph.allNodesReversePostOrder - while (workList.nonEmpty) { - val newEntries = workList.flatMap { n => - val inSet = problem.flowGraph - .pred(n) - .iterator - .map(out) - .reduceOption((x, y) => problem.meet(x, y)) - .getOrElse(problem.empty) - in += n -> inSet - val old = out(n) - val newSet = problem.transferFunction(n, inSet) - val changed = !old.equals(newSet) - out += n -> newSet - if (changed) { - problem.flowGraph.succ(n) - } else - List() - } - workList.clear() - workList ++= newEntries.distinct - } - Solution(in, out, problem) - } + while workList.nonEmpty do + val newEntries = workList.flatMap { n => + val inSet = problem.flowGraph + .pred(n) + .iterator + .map(out) + .reduceOption((x, y) => problem.meet(x, y)) + .getOrElse(problem.empty) + in += n -> inSet + val old = out(n) + val newSet = problem.transferFunction(n, inSet) + val changed = !old.equals(newSet) + out += n -> newSet + if changed then + problem.flowGraph.succ(n) + else + List() + } + workList.clear() + workList ++= newEntries.distinct + end while + Solution(in, 
out, problem) + end calculateMopSolutionForwards - /** Calculate fix point solution via a standard work list algorithm (Backwards). The result is given by two maps: `in` - * and `out`. These maps associate all CFG nodes with the set of definitions at node entry and node exit - * respectively. - */ - def calculateMopSolutionBackwards[Node, T <: Iterable[_]](problem: DataFlowProblem[Node, T]): Solution[Node, T] = { - var out: Map[Node, T] = problem.inOutInit.initOut - var in = problem.inOutInit.initIn - val workList = mutable.ListBuffer[Node]() - workList ++= problem.flowGraph.allNodesPostOrder + /** Calculate fix point solution via a standard work list algorithm (Backwards). The result is + * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of + * definitions at node entry and node exit respectively. + */ + def calculateMopSolutionBackwards[Node, T <: Iterable[_]](problem: DataFlowProblem[Node, T]) + : Solution[Node, T] = + var out: Map[Node, T] = problem.inOutInit.initOut + var in = problem.inOutInit.initIn + val workList = mutable.ListBuffer[Node]() + workList ++= problem.flowGraph.allNodesPostOrder - while (workList.nonEmpty) { - val newEntries = workList.flatMap { n => - val outSet = problem.flowGraph - .succ(n) - .iterator - .map(in) - .reduceOption((x, y) => problem.meet(x, y)) - .getOrElse(problem.empty) - out += n -> outSet - val old = in(n) - val newSet = problem.transferFunction(n, outSet) - val changed = !old.equals(newSet) - in += n -> newSet - if (changed) - problem.flowGraph.pred(n) - else - List() - } - workList.clear() - workList ++= newEntries.distinct - } - Solution(in, out, problem) - } - -} + while workList.nonEmpty do + val newEntries = workList.flatMap { n => + val outSet = problem.flowGraph + .succ(n) + .iterator + .map(in) + .reduceOption((x, y) => problem.meet(x, y)) + .getOrElse(problem.empty) + out += n -> outSet + val old = in(n) + val newSet = problem.transferFunction(n, outSet) + val changed = 
!old.equals(newSet) + in += n -> newSet + if changed then + problem.flowGraph.pred(n) + else + List() + } + workList.clear() + workList ++= newEntries.distinct + end while + Solution(in, out, problem) + end calculateMopSolutionBackwards +end DataFlowSolver diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala index 9c17b433..081994a4 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala @@ -14,363 +14,360 @@ import scala.collection.{Set, mutable} /** Creation of data dependence edges based on solution of the ReachingDefProblem. */ -class DdgGenerator(semantics: Semantics) { - - implicit val s: Semantics = semantics - - /** Once reaching definitions have been computed, we create a data dependence graph by checking which reaching - * definitions are relevant, meaning that a symbol is propagated that is used by the target node. - * - * @param dstGraph - * the diff graph to add edges to - * @param problem - * the reaching definition problem - * @param solution - * the solution to `problem` - */ - def addReachingDefEdges( - dstGraph: DiffGraphBuilder, - method: Method, - problem: DataFlowProblem[StoredNode, mutable.BitSet], - solution: Solution[StoredNode, mutable.BitSet] - ): Unit = { - implicit val implicitDst: DiffGraphBuilder = dstGraph - - val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - val in = solution.in - val gen = solution.problem.transferFunction.asInstanceOf[ReachingDefTransferFunction].gen - - val allNodes = in.keys.toList - val usageAnalyzer = new UsageAnalyzer(problem, in) - - /** Add an edge from the entry node to each node that does not have other incoming definitions. 
+class DdgGenerator(semantics: Semantics): + + implicit val s: Semantics = semantics + + /** Once reaching definitions have been computed, we create a data dependence graph by checking + * which reaching definitions are relevant, meaning that a symbol is propagated that is used by + * the target node. + * + * @param dstGraph + * the diff graph to add edges to + * @param problem + * the reaching definition problem + * @param solution + * the solution to `problem` */ - def addEdgesFromEntryNode(): Unit = { - // Add edges from the entry node - allNodes - .filter(n => isDdgNode(n) && usageAnalyzer.usedIncomingDefs(n).isEmpty) - .foreach { node => - addEdge(method, node) - } - } - - // This handles `foo(new Bar()) or return new Bar()` - def addEdgeForBlock(block: Block, towards: StoredNode): Unit = { - block.astChildren.lastOption match { - case None => // Do nothing - case Some(node: Identifier) => - val edgesToAdd = in(node).toList - .flatMap(numberToNode.get) - .filter(inDef => usageAnalyzer.isUsing(node, inDef)) - .collect { - case identifier: Identifier => identifier - case call: Call => call + def addReachingDefEdges( + dstGraph: DiffGraphBuilder, + method: Method, + problem: DataFlowProblem[StoredNode, mutable.BitSet], + solution: Solution[StoredNode, mutable.BitSet] + ): Unit = + implicit val implicitDst: DiffGraphBuilder = dstGraph + + val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + val in = solution.in + val gen = solution.problem.transferFunction.asInstanceOf[ReachingDefTransferFunction].gen + + val allNodes = in.keys.toList + val usageAnalyzer = new UsageAnalyzer(problem, in) + + /** Add an edge from the entry node to each node that does not have other incoming + * definitions. 
+ */ + def addEdgesFromEntryNode(): Unit = + // Add edges from the entry node + allNodes + .filter(n => isDdgNode(n) && usageAnalyzer.usedIncomingDefs(n).isEmpty) + .foreach { node => + addEdge(method, node) + } + + // This handles `foo(new Bar()) or return new Bar()` + def addEdgeForBlock(block: Block, towards: StoredNode): Unit = + block.astChildren.lastOption match + case None => // Do nothing + case Some(node: Identifier) => + val edgesToAdd = in(node).toList + .flatMap(numberToNode.get) + .filter(inDef => usageAnalyzer.isUsing(node, inDef)) + .collect { + case identifier: Identifier => identifier + case call: Call => call + } + edgesToAdd.foreach { inNode => + addEdge(inNode, block, nodeToEdgeLabel(inNode)) + } + if edgesToAdd.nonEmpty then + addEdge(block, towards) + case Some(node: Call) => + addEdge(node, block, nodeToEdgeLabel(node)) + addEdge(block, towards) + case _ => // Do nothing + /** Adds incoming edges to arguments of call sites, including edges between arguments of the + * same call site. + */ + def addEdgesToCallSite(call: Call): Unit = + // Edges between arguments of call sites + usageAnalyzer.usedIncomingDefs(call).foreach { case (use, ins) => + ins.foreach { in => + val inNode = numberToNode(in) + if inNode != use then + addEdge(inNode, use, nodeToEdgeLabel(inNode)) + } } - edgesToAdd.foreach { inNode => - addEdge(inNode, block, nodeToEdgeLabel(inNode)) - } - if (edgesToAdd.nonEmpty) { - addEdge(block, towards) - } - case Some(node: Call) => - addEdge(node, block, nodeToEdgeLabel(node)) - addEdge(block, towards) - case _ => // Do nothing - } - } - - /** Adds incoming edges to arguments of call sites, including edges between arguments of the same call site. 
- */ - def addEdgesToCallSite(call: Call): Unit = { - // Edges between arguments of call sites - usageAnalyzer.usedIncomingDefs(call).foreach { case (use, ins) => - ins.foreach { in => - val inNode = numberToNode(in) - if (inNode != use) { - addEdge(inNode, use, nodeToEdgeLabel(inNode)) - } - } - } - - // For all calls, assume that input arguments - // taint corresponding output arguments - // and the return value. We filter invalid - // edges at query time (according to the given semantic). - usageAnalyzer.uses(call).foreach { use => - gen(call).foreach { g => - val genNode = numberToNode(g) - if (use != genNode && isDdgNode(use)) { - addEdge(use, genNode, nodeToEdgeLabel(use)) - } - } - } - - // This handles `foo(new Bar())`, which is lowered to - // `foo({Bar tmp = Bar.alloc(); tmp.init(); tmp})` - call.argument.isBlock.foreach { block => addEdgeForBlock(block, call) } - } - - def addEdgesToReturn(ret: Return): Unit = { - // This handles `return new Bar()`, which is lowered to - // `return {Bar tmp = Bar.alloc(); tmp.init(); tmp}` - usageAnalyzer.uses(ret).collectAll[Block].foreach(block => addEdgeForBlock(block, ret)) - usageAnalyzer.usedIncomingDefs(ret).foreach { case (use: CfgNode, inElements) => - addEdge(use, ret, use.code) - inElements - .filterNot(x => numberToNode.get(x).contains(use)) - .flatMap(numberToNode.get) - .foreach { inElemNode => - addEdge(inElemNode, use, nodeToEdgeLabel(inElemNode)) - } - if (inElements.isEmpty) { - addEdge(method, ret) - } - } - addEdge(ret, method.methodReturn, "") - } - - def addEdgesToMethodParameterOut(paramOut: MethodParameterOut): Unit = { - // There is always an edge from the method input parameter - // to the corresponding method output parameter as modifications - // of the input parameter only affect a copy. 
- paramOut.paramIn.foreach { paramIn => - addEdge(paramIn, paramOut, paramIn.name) - } - usageAnalyzer.usedIncomingDefs(paramOut).foreach { case (_, inElements) => - inElements.foreach { inElement => - val inElemNode = numberToNode(inElement) - val edgeLabel = nodeToEdgeLabel(inElemNode) - addEdge(inElemNode, paramOut, edgeLabel) - } - } - } - - def addEdgesToExitNode(exitNode: MethodReturn): Unit = { - in(exitNode).foreach { i => - val iNode = numberToNode(i) - addEdge(iNode, exitNode, nodeToEdgeLabel(iNode)) - } - } - - /** This is part of the Lone-identifier optimization: as we remove lone identifiers from `gen` sets, we must now - * retrieve them and create an edge from each lone identifier to the exit node. - */ - def addEdgesFromLoneIdentifiersToExit(method: Method): Unit = { - val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - val exitNode = method.methodReturn - val transferFunction = solution.problem.transferFunction.asInstanceOf[OptimizedReachingDefTransferFunction] - val genOnce = transferFunction.loneIdentifiers - genOnce.foreach { case (_, defs) => - defs.foreach { d => - val dNode = numberToNode(d) - addEdge(dNode, exitNode, nodeToEdgeLabel(dNode)) - } - } - } - - def addEdgesToCapturedIdentifiersAndParameters(): Unit = { - val identifierDestPairs = - method._identifierViaContainsOut.flatMap { identifier => - val firstAndLastUsageByMethod = identifierToFirstUsages(identifier).groupBy(_.method) - firstAndLastUsageByMethod.values - .filter(_.nonEmpty) - .map(x => (x.head, x.last)) - .flatMap { case (firstUsage, lastUsage) => - (identifier.lineNumber, firstUsage.lineNumber, lastUsage.lineNumber) match { - case (Some(iNo), Some(fNo), _) if iNo <= fNo => Some(identifier, firstUsage) - case (Some(iNo), _, Some(lNo)) if iNo >= lNo => Some(lastUsage, identifier) - case _ => None - } + + // For all calls, assume that input arguments + // taint corresponding output arguments + // and the return value. 
We filter invalid + // edges at query time (according to the given semantic). + usageAnalyzer.uses(call).foreach { use => + gen(call).foreach { g => + val genNode = numberToNode(g) + if use != genNode && isDdgNode(use) then + addEdge(use, genNode, nodeToEdgeLabel(use)) + } } - }.distinct - identifierDestPairs - .foreach { case (src, dst) => - addEdge(src, dst, nodeToEdgeLabel(src)) - } - method.parameter.foreach { param => - param.capturedByMethodRef.referencedMethod.ast.isIdentifier.foreach { identifier => - addEdge(param, identifier, nodeToEdgeLabel(param)) - } - } - - val globalIdentifiers = - method.ast.isLiteral.flatMap(globalFromLiteral).collectAll[Identifier].l - globalIdentifiers - .foreach { global => - identifierToFirstUsages(global).map { identifier => - addEdge(global, identifier, nodeToEdgeLabel(global)) - } + // This handles `foo(new Bar())`, which is lowered to + // `foo({Bar tmp = Bar.alloc(); tmp.init(); tmp})` + call.argument.isBlock.foreach { block => addEdgeForBlock(block, call) } + end addEdgesToCallSite + + def addEdgesToReturn(ret: Return): Unit = + // This handles `return new Bar()`, which is lowered to + // `return {Bar tmp = Bar.alloc(); tmp.init(); tmp}` + usageAnalyzer.uses(ret).collectAll[Block].foreach(block => addEdgeForBlock(block, ret)) + usageAnalyzer.usedIncomingDefs(ret).foreach { case (use: CfgNode, inElements) => + addEdge(use, ret, use.code) + inElements + .filterNot(x => numberToNode.get(x).contains(use)) + .flatMap(numberToNode.get) + .foreach { inElemNode => + addEdge(inElemNode, use, nodeToEdgeLabel(inElemNode)) + } + if inElements.isEmpty then + addEdge(method, ret) + } + addEdge(ret, method.methodReturn, "") + + def addEdgesToMethodParameterOut(paramOut: MethodParameterOut): Unit = + // There is always an edge from the method input parameter + // to the corresponding method output parameter as modifications + // of the input parameter only affect a copy. 
+ paramOut.paramIn.foreach { paramIn => + addEdge(paramIn, paramOut, paramIn.name) + } + usageAnalyzer.usedIncomingDefs(paramOut).foreach { case (_, inElements) => + inElements.foreach { inElement => + val inElemNode = numberToNode(inElement) + val edgeLabel = nodeToEdgeLabel(inElemNode) + addEdge(inElemNode, paramOut, edgeLabel) + } + } + + def addEdgesToExitNode(exitNode: MethodReturn): Unit = + in(exitNode).foreach { i => + val iNode = numberToNode(i) + addEdge(iNode, exitNode, nodeToEdgeLabel(iNode)) + } + + /** This is part of the Lone-identifier optimization: as we remove lone identifiers from + * `gen` sets, we must now retrieve them and create an edge from each lone identifier to + * the exit node. + */ + def addEdgesFromLoneIdentifiersToExit(method: Method): Unit = + val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + val exitNode = method.methodReturn + val transferFunction = + solution.problem.transferFunction.asInstanceOf[OptimizedReachingDefTransferFunction] + val genOnce = transferFunction.loneIdentifiers + genOnce.foreach { case (_, defs) => + defs.foreach { d => + val dNode = numberToNode(d) + addEdge(dNode, exitNode, nodeToEdgeLabel(dNode)) + } + } + + def addEdgesToCapturedIdentifiersAndParameters(): Unit = + val identifierDestPairs = + method._identifierViaContainsOut.flatMap { identifier => + val firstAndLastUsageByMethod = + identifierToFirstUsages(identifier).groupBy(_.method) + firstAndLastUsageByMethod.values + .filter(_.nonEmpty) + .map(x => (x.head, x.last)) + .flatMap { case (firstUsage, lastUsage) => + ( + identifier.lineNumber, + firstUsage.lineNumber, + lastUsage.lineNumber + ) match + case (Some(iNo), Some(fNo), _) if iNo <= fNo => + Some(identifier, firstUsage) + case (Some(iNo), _, Some(lNo)) if iNo >= lNo => + Some(lastUsage, identifier) + case _ => None + } + }.distinct + + identifierDestPairs + .foreach { case (src, dst) => + addEdge(src, dst, nodeToEdgeLabel(src)) + } + method.parameter.foreach 
{ param => + param.capturedByMethodRef.referencedMethod.ast.isIdentifier.foreach { identifier => + addEdge(param, identifier, nodeToEdgeLabel(param)) + } + } + + val globalIdentifiers = + method.ast.isLiteral.flatMap(globalFromLiteral).collectAll[Identifier].l + globalIdentifiers + .foreach { global => + identifierToFirstUsages(global).map { identifier => + addEdge(global, identifier, nodeToEdgeLabel(global)) + } + } + end addEdgesToCapturedIdentifiersAndParameters + + addEdgesFromEntryNode() + allNodes.foreach { + case call: Call => addEdgesToCallSite(call) + case ret: Return => addEdgesToReturn(ret) + case paramOut: MethodParameterOut => addEdgesToMethodParameterOut(paramOut) + case _ => } - } - - addEdgesFromEntryNode() - allNodes.foreach { - case call: Call => addEdgesToCallSite(call) - case ret: Return => addEdgesToReturn(ret) - case paramOut: MethodParameterOut => addEdgesToMethodParameterOut(paramOut) - case _ => - } - - addEdgesToCapturedIdentifiersAndParameters() - addEdgesToExitNode(method.methodReturn) - addEdgesFromLoneIdentifiersToExit(method) - } - - private def addEdge(fromNode: StoredNode, toNode: StoredNode, variable: String = "")(implicit - dstGraph: DiffGraphBuilder - ): Unit = { - if (fromNode.isInstanceOf[Unknown] || toNode.isInstanceOf[Unknown]) - return - - (fromNode, toNode) match { - case (parentNode: CfgNode, childNode: CfgNode) if EdgeValidator.isValidEdge(childNode, parentNode) => - dstGraph.addEdge(fromNode, toNode, EdgeTypes.REACHING_DEF, PropertyNames.VARIABLE, variable) - case _ => - - } - } - - /** There are a few node types that (a) are not to be considered in the DDG, or (b) are not standalone DDG nodes, or - * (c) have a special meaning in the DDG. This function indicates whether the given node is just a regular DDG node - * instead. 
- */ - private def isDdgNode(x: StoredNode): Boolean = { - x match { - case _: Method => false - case _: ControlStructure => false - case _: FieldIdentifier => false - case _: JumpTarget => false - case _: MethodReturn => false - case _ => true - } - } - - private def nodeToEdgeLabel(node: StoredNode): String = { - node match { - case n: MethodParameterIn => n.name - case n: CfgNode => n.code - case _ => "" - } - } -} - -/** Upon calculating reaching definitions, we find ourselves with a set of incoming definitions `in(n)` for each node - * `n` of the flow graph. This component determines those of the incoming definitions that are relevant as the value - * they define is actually used by `n`. + + addEdgesToCapturedIdentifiersAndParameters() + addEdgesToExitNode(method.methodReturn) + addEdgesFromLoneIdentifiersToExit(method) + end addReachingDefEdges + + private def addEdge(fromNode: StoredNode, toNode: StoredNode, variable: String = "")(implicit + dstGraph: DiffGraphBuilder + ): Unit = + if fromNode.isInstanceOf[Unknown] || toNode.isInstanceOf[Unknown] then + return + + (fromNode, toNode) match + case (parentNode: CfgNode, childNode: CfgNode) + if EdgeValidator.isValidEdge(childNode, parentNode) => + dstGraph.addEdge( + fromNode, + toNode, + EdgeTypes.REACHING_DEF, + PropertyNames.VARIABLE, + variable + ) + case _ => + + /** There are a few node types that (a) are not to be considered in the DDG, or (b) are not + * standalone DDG nodes, or (c) have a special meaning in the DDG. This function indicates + * whether the given node is just a regular DDG node instead. 
+ */ + private def isDdgNode(x: StoredNode): Boolean = + x match + case _: Method => false + case _: ControlStructure => false + case _: FieldIdentifier => false + case _: JumpTarget => false + case _: MethodReturn => false + case _ => true + + private def nodeToEdgeLabel(node: StoredNode): String = + node match + case n: MethodParameterIn => n.name + case n: CfgNode => n.code + case _ => "" +end DdgGenerator + +/** Upon calculating reaching definitions, we find ourselves with a set of incoming definitions + * `in(n)` for each node `n` of the flow graph. This component determines those of the incoming + * definitions that are relevant as the value they define is actually used by `n`. */ private class UsageAnalyzer( problem: DataFlowProblem[StoredNode, mutable.BitSet], in: Map[StoredNode, Set[Definition]] -) { - - val numberToNode: Map[Definition, StoredNode] = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - - private val allNodes = in.keys.toList - private val containerSet = - Set(Operators.fieldAccess, Operators.indexAccess, Operators.indirectIndexAccess, Operators.indirectFieldAccess) - private val indirectionAccessSet = Set(Operators.addressOf, Operators.indirection) - val usedIncomingDefs: Map[StoredNode, Map[StoredNode, Set[Definition]]] = initUsedIncomingDefs() - - def initUsedIncomingDefs(): Map[StoredNode, Map[StoredNode, Set[Definition]]] = { - allNodes.map { node => - node -> usedIncomingDefsForNode(node) - }.toMap - } - - private def usedIncomingDefsForNode(node: StoredNode): Map[StoredNode, Set[Definition]] = { - uses(node).map { use => - use -> in(node).filter { inElement => - val inElemNode = numberToNode(inElement) - isUsing(use, inElemNode) - } - }.toMap - } - - def isUsing(use: StoredNode, inElemNode: StoredNode): Boolean = - sameVariable(use, inElemNode) || isContainer(use, inElemNode) || isPart(use, inElemNode) || isAlias(use, inElemNode) - - /** Determine whether the node `use` describes a container for `inElement`, e.g., use 
= `ptr` while inElement = - * `ptr->foo`. - */ - private def isContainer(use: StoredNode, inElement: StoredNode): Boolean = { - inElement match { - case call: Call if containerSet.contains(call.name) => - call.argument.headOption.exists { base => - nodeToString(use) == nodeToString(base) - } - case _ => false - } - } - - /** Determine whether `use` is a part of `inElement`, e.g., use = `argv[0]` while inElement = `argv` - */ - private def isPart(use: StoredNode, inElement: StoredNode): Boolean = { - use match { - case call: Call if containerSet.contains(call.name) => - inElement match { - case param: MethodParameterIn => - call.argument.headOption.exists { base => - nodeToString(base).contains(param.name) +): + + val numberToNode: Map[Definition, StoredNode] = + problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + + private val allNodes = in.keys.toList + private val containerSet = + Set( + Operators.fieldAccess, + Operators.indexAccess, + Operators.indirectIndexAccess, + Operators.indirectFieldAccess + ) + private val indirectionAccessSet = Set(Operators.addressOf, Operators.indirection) + val usedIncomingDefs: Map[StoredNode, Map[StoredNode, Set[Definition]]] = initUsedIncomingDefs() + + def initUsedIncomingDefs(): Map[StoredNode, Map[StoredNode, Set[Definition]]] = + allNodes.map { node => + node -> usedIncomingDefsForNode(node) + }.toMap + + private def usedIncomingDefsForNode(node: StoredNode): Map[StoredNode, Set[Definition]] = + uses(node).map { use => + use -> in(node).filter { inElement => + val inElemNode = numberToNode(inElement) + isUsing(use, inElemNode) } - case identifier: Identifier => - call.argument.headOption.exists { base => - nodeToString(base).contains(identifier.name) - } - case _ => false - } - case _ => false - } - } - - private def isAlias(use: StoredNode, inElement: StoredNode): Boolean = { - use match { - case useCall: Call => - inElement match { - case inCall: Call => - val (useBase, useAccessPath) = 
toTrackedBaseAndAccessPathSimple(useCall) - val (inBase, inAccessPath) = toTrackedBaseAndAccessPathSimple(inCall) - useBase == inBase && useAccessPath.matchAndDiff(inAccessPath.elements)._1 == MatchResult.EXACT_MATCH - case _ => false - } - case _ => false - } - } - - def uses(node: StoredNode): Set[StoredNode] = { - val n: Set[StoredNode] = node match { - case ret: Return => ret.astChildren.collect { case x: Expression => x }.toSet - case call: Call => call.argument.toSet - case paramOut: MethodParameterOut => Set(paramOut) - case _ => Set() - } - n.filterNot(_.isInstanceOf[FieldIdentifier]) - } - - /** Compares arguments of calls with incoming definitions to see if they refer to the same variable - */ - private def sameVariable(use: StoredNode, inElement: StoredNode): Boolean = { - inElement match { - case param: MethodParameterIn => - nodeToString(use).contains(param.name) || nodeToString(use).contains(param.code) - case call: Call if indirectionAccessSet.contains(call.name) => - call.argumentOption(1).exists(x => nodeToString(use).contains(x.code)) - case call: Call => - nodeToString(use).contains(call.code) - case identifier: Identifier => - nodeToString(use).contains(identifier.name) || nodeToString(use).contains(identifier.code) - case _ => false - } - } - - private def nodeToString(node: StoredNode): Option[String] = { - node match { - case ident: Identifier => Some(ident.name) - case exp: Expression => Some(exp.code) - case p: MethodParameterIn => Some(p.name) - case p: MethodParameterOut => Some(p.name) - case _ => None - } - } - -} + }.toMap + + def isUsing(use: StoredNode, inElemNode: StoredNode): Boolean = + sameVariable(use, inElemNode) || isContainer(use, inElemNode) || isPart( + use, + inElemNode + ) || isAlias(use, inElemNode) + + /** Determine whether the node `use` describes a container for `inElement`, e.g., use = `ptr` + * while inElement = `ptr->foo`. 
+ */ + private def isContainer(use: StoredNode, inElement: StoredNode): Boolean = + inElement match + case call: Call if containerSet.contains(call.name) => + call.argument.headOption.exists { base => + nodeToString(use) == nodeToString(base) + } + case _ => false + + /** Determine whether `use` is a part of `inElement`, e.g., use = `argv[0]` while inElement = + * `argv` + */ + private def isPart(use: StoredNode, inElement: StoredNode): Boolean = + use match + case call: Call if containerSet.contains(call.name) => + inElement match + case param: MethodParameterIn => + call.argument.headOption.exists { base => + nodeToString(base).contains(param.name) + } + case identifier: Identifier => + call.argument.headOption.exists { base => + nodeToString(base).contains(identifier.name) + } + case _ => false + case _ => false + + private def isAlias(use: StoredNode, inElement: StoredNode): Boolean = + use match + case useCall: Call => + inElement match + case inCall: Call => + val (useBase, useAccessPath) = toTrackedBaseAndAccessPathSimple(useCall) + val (inBase, inAccessPath) = toTrackedBaseAndAccessPathSimple(inCall) + useBase == inBase && useAccessPath.matchAndDiff( + inAccessPath.elements + )._1 == MatchResult.EXACT_MATCH + case _ => false + case _ => false + + def uses(node: StoredNode): Set[StoredNode] = + val n: Set[StoredNode] = node match + case ret: Return => ret.astChildren.collect { case x: Expression => x }.toSet + case call: Call => call.argument.toSet + case paramOut: MethodParameterOut => Set(paramOut) + case _ => Set() + n.filterNot(_.isInstanceOf[FieldIdentifier]) + + /** Compares arguments of calls with incoming definitions to see if they refer to the same + * variable + */ + private def sameVariable(use: StoredNode, inElement: StoredNode): Boolean = + inElement match + case param: MethodParameterIn => + nodeToString(use).contains(param.name) || nodeToString(use).contains(param.code) + case call: Call if indirectionAccessSet.contains(call.name) => + 
call.argumentOption(1).exists(x => nodeToString(use).contains(x.code)) + case call: Call => + nodeToString(use).contains(call.code) + case identifier: Identifier => + nodeToString(use).contains(identifier.name) || nodeToString(use).contains( + identifier.code + ) + case _ => false + + private def nodeToString(node: StoredNode): Option[String] = + node match + case ident: Identifier => Some(ident.name) + case exp: Expression => Some(exp.code) + case p: MethodParameterIn => Some(p.name) + case p: MethodParameterOut => Some(p.name) + case _ => None +end UsageAnalyzer diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala index 8bc61b08..e9b863e6 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala @@ -3,49 +3,57 @@ package io.appthreat.dataflowengineoss.passes.reachingdef import io.appthreat.dataflowengineoss.semanticsloader.{PassThroughMapping, Semantics} import io.appthreat.dataflowengineoss.language.* import io.appthreat.dataflowengineoss.queryengine.Engine.isOutputArgOfInternalMethod -import io.appthreat.dataflowengineoss.semanticsloader.{FlowMapping, ParameterNode, PassThroughMapping, Semantics} +import io.appthreat.dataflowengineoss.semanticsloader.{ + FlowMapping, + ParameterNode, + PassThroughMapping, + Semantics +} import io.shiftleft.codepropertygraph.generated.nodes.{Call, CfgNode, Expression, StoredNode} import io.shiftleft.semanticcpg.language.* -object EdgeValidator { - - /** Determines whether the edge from `parentNode`to `childNode` is valid, according to the given semantics. 
- */ - def isValidEdge(childNode: CfgNode, parentNode: CfgNode)(implicit semantics: Semantics): Boolean = - (childNode, parentNode) match { - case (childNode: Expression, parentNode) - if isCallRetval(parentNode) || !isValidEdgeToExpression(parentNode, childNode) => - false - case (childNode: Expression, parentNode: Expression) - if parentNode.isArgToSameCallWith(childNode) && childNode.isDefined && parentNode.isUsed => - parentNode.hasDefinedFlowTo(childNode) - case (_: Expression, _: Expression) => true - case (childNode: Expression, _) if !childNode.isUsed => false - case (_: Expression, _) => true - case (_, parentNode) => !isCallRetval(parentNode) - } +object EdgeValidator: - private def isValidEdgeToExpression(parNode: CfgNode, curNode: Expression)(implicit semantics: Semantics): Boolean = - parNode match { - case parentNode: Expression => - val sameCallSite = parentNode.inCall.l == curNode.start.inCall.l - !(sameCallSite && isOutputArgOfInternalMethod(parentNode)) && - (sameCallSite && parentNode.isUsed && curNode.isDefined || !sameCallSite && curNode.isUsed) - case _ => - curNode.isUsed - } + /** Determines whether the edge from `parentNode`to `childNode` is valid, according to the given + * semantics. 
+ */ + def isValidEdge(childNode: CfgNode, parentNode: CfgNode)(implicit + semantics: Semantics + ): Boolean = + (childNode, parentNode) match + case (childNode: Expression, parentNode) + if isCallRetval(parentNode) || !isValidEdgeToExpression(parentNode, childNode) => + false + case (childNode: Expression, parentNode: Expression) + if parentNode.isArgToSameCallWith( + childNode + ) && childNode.isDefined && parentNode.isUsed => + parentNode.hasDefinedFlowTo(childNode) + case (_: Expression, _: Expression) => true + case (childNode: Expression, _) if !childNode.isUsed => false + case (_: Expression, _) => true + case (_, parentNode) => !isCallRetval(parentNode) - private def isCallRetval(parentNode: StoredNode)(implicit semantics: Semantics): Boolean = - parentNode match { - case call: Call => - val sem = semantics.forMethod(call.methodFullName) - sem.isDefined && !sem.get.mappings.exists { - case FlowMapping(_, ParameterNode(dst, _)) => dst == -1 - case PassThroughMapping => true - case _ => false - } - case _ => - false - } + private def isValidEdgeToExpression(parNode: CfgNode, curNode: Expression)(implicit + semantics: Semantics + ): Boolean = + parNode match + case parentNode: Expression => + val sameCallSite = parentNode.inCall.l == curNode.start.inCall.l + !(sameCallSite && isOutputArgOfInternalMethod(parentNode)) && + (sameCallSite && parentNode.isUsed && curNode.isDefined || !sameCallSite && curNode.isUsed) + case _ => + curNode.isUsed -} + private def isCallRetval(parentNode: StoredNode)(implicit semantics: Semantics): Boolean = + parentNode match + case call: Call => + val sem = semantics.forMethod(call.methodFullName) + sem.isDefined && !sem.get.mappings.exists { + case FlowMapping(_, ParameterNode(dst, _)) => dst == -1 + case PassThroughMapping => true + case _ => false + } + case _ => + false +end EdgeValidator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala 
b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala index b11e979a..83f6f75c 100755 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala @@ -13,43 +13,46 @@ import scala.collection.mutable /** A pass that calculates reaching definitions ("data dependencies"). */ class ReachingDefPass(cpg: Cpg, maxNumberOfDefinitions: Int = 4000)(implicit s: Semantics) - extends ForkJoinParallelCpgPass[Method](cpg) { - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - // If there are any regex method full names, load them early - s.loadRegexSemantics(cpg) - - override def generateParts(): Array[Method] = cpg.method.toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = { - logger.debug("Calculating reaching definitions for: {} in {}", method.fullName, method.filename) - val problem = ReachingDefProblem.create(method) - if (shouldBailOut(method, problem)) { - logger.warn("Skipping.") - return - } - - val solution = new DataFlowSolver().calculateMopSolutionForwards(problem) - val ddgGenerator = new DdgGenerator(s) - ddgGenerator.addReachingDefEdges(dstGraph, method, problem, solution) - } - - /** Before we start propagating definitions in the graph, which is the bulk of the work, we check how many definitions - * were are dealing with in total. If a threshold is reached, we bail out instead, leaving reaching definitions - * uncalculated for the method in question. Users can increase the threshold if desired. 
- */ - private def shouldBailOut(method: Method, problem: DataFlowProblem[StoredNode, mutable.BitSet]): Boolean = { - val transferFunction = problem.transferFunction.asInstanceOf[ReachingDefTransferFunction] - // For each node, the `gen` map contains the list of definitions it generates - // We add up the sizes of these lists to obtain the total number of definitions - val numberOfDefinitions = transferFunction.gen.foldLeft(0)(_ + _._2.size) - logger.debug("Number of definitions for {}: {}", method.fullName, numberOfDefinitions) - if (numberOfDefinitions > maxNumberOfDefinitions) { - logger.warn("{} has more than {} definitions", method.fullName, maxNumberOfDefinitions) - true - } else { - false - } - } - -} + extends ForkJoinParallelCpgPass[Method](cpg): + + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + // If there are any regex method full names, load them early + s.loadRegexSemantics(cpg) + + override def generateParts(): Array[Method] = cpg.method.toArray + + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + logger.debug( + "Calculating reaching definitions for: {} in {}", + method.fullName, + method.filename + ) + val problem = ReachingDefProblem.create(method) + if shouldBailOut(method, problem) then + logger.warn("Skipping.") + return + + val solution = new DataFlowSolver().calculateMopSolutionForwards(problem) + val ddgGenerator = new DdgGenerator(s) + ddgGenerator.addReachingDefEdges(dstGraph, method, problem, solution) + + /** Before we start propagating definitions in the graph, which is the bulk of the work, we + * check how many definitions were are dealing with in total. If a threshold is reached, we + * bail out instead, leaving reaching definitions uncalculated for the method in question. + * Users can increase the threshold if desired. 
+ */ + private def shouldBailOut( + method: Method, + problem: DataFlowProblem[StoredNode, mutable.BitSet] + ): Boolean = + val transferFunction = problem.transferFunction.asInstanceOf[ReachingDefTransferFunction] + // For each node, the `gen` map contains the list of definitions it generates + // We add up the sizes of these lists to obtain the total number of definitions + val numberOfDefinitions = transferFunction.gen.foldLeft(0)(_ + _._2.size) + logger.debug("Number of definitions for {}: {}", method.fullName, numberOfDefinitions) + if numberOfDefinitions > maxNumberOfDefinitions then + logger.warn("{} has more than {} definitions", method.fullName, maxNumberOfDefinitions) + true + else + false +end ReachingDefPass diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala index 3ef12f4a..10816445 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala @@ -1,360 +1,366 @@ package io.appthreat.dataflowengineoss.passes.reachingdef import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.utils.MemberAccess.{isFieldAccess, isGenericMemberAccessName} import org.slf4j.{Logger, LoggerFactory} import scala.collection.{Set, mutable} -/** The variables defined/used in the reaching def problem can all be represented via nodes in the graph, however, - * that's pretty confusing because it is then unclear that variables and nodes are actually two separate domains. 
To - * make the definition domain visible, we use the type alias `Definition`. From a computational standpoint, this is not - * necessary, but it greatly improves readability. +/** The variables defined/used in the reaching def problem can all be represented via nodes in the + * graph, however, that's pretty confusing because it is then unclear that variables and nodes are + * actually two separate domains. To make the definition domain visible, we use the type alias + * `Definition`. From a computational standpoint, this is not necessary, but it greatly improves + * readability. */ -object Definition { - def fromNode(node: StoredNode, nodeToNumber: Map[StoredNode, Int]): Definition = { - nodeToNumber(node) - } -} +object Definition: + def fromNode(node: StoredNode, nodeToNumber: Map[StoredNode, Int]): Definition = + nodeToNumber(node) + +object ReachingDefProblem: + def create(method: Method): DataFlowProblem[StoredNode, mutable.BitSet] = + val flowGraph = new ReachingDefFlowGraph(method) + val transfer = new OptimizedReachingDefTransferFunction(flowGraph) + val init = new ReachingDefInit(transfer.gen) + def meet: (mutable.BitSet, mutable.BitSet) => mutable.BitSet = + (x: mutable.BitSet, y: mutable.BitSet) => x.union(y) + + new DataFlowProblem[StoredNode, mutable.BitSet]( + flowGraph, + transfer, + meet, + init, + true, + mutable.BitSet() + ) -object ReachingDefProblem { - def create(method: Method): DataFlowProblem[StoredNode, mutable.BitSet] = { - val flowGraph = new ReachingDefFlowGraph(method) - val transfer = new OptimizedReachingDefTransferFunction(flowGraph) - val init = new ReachingDefInit(transfer.gen) - def meet: (mutable.BitSet, mutable.BitSet) => mutable.BitSet = - (x: mutable.BitSet, y: mutable.BitSet) => { x.union(y) } +/** The control flow graph as viewed by the data flow solver. 
+ */ +class ReachingDefFlowGraph(val method: Method) extends FlowGraph[StoredNode]: - new DataFlowProblem[StoredNode, mutable.BitSet](flowGraph, transfer, meet, init, true, mutable.BitSet()) - } + private val logger: Logger = LoggerFactory.getLogger(this.getClass) -} + val entryNode: StoredNode = method + val exitNode: StoredNode = method.methodReturn -/** The control flow graph as viewed by the data flow solver. - */ -class ReachingDefFlowGraph(val method: Method) extends FlowGraph[StoredNode] { - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - val entryNode: StoredNode = method - val exitNode: StoredNode = method.methodReturn - - private val params = method.parameter.sortBy(_.index) - private val firstParam = params.headOption - private val lastParam = params.lastOption - private val firstOutputParam = firstParam.flatMap(_.asOutput.headOption) - private val lastOutputParam = method.parameter.sortBy(_.index).asOutput.lastOption - - private val lastActualCfgNode = exitNode._cfgIn.nextOption() - - val allNodesReversePostOrder: List[StoredNode] = - List(entryNode) ++ method.parameter.toList ++ method.reversePostOrder.toList.filter(x => - x.id != entryNode.id && x.id != exitNode.id - ) ++ method.parameter.asOutput.toList ++ List(exitNode) - - private val allNodesEvenUnreachable = - allNodesReversePostOrder ++ method.cfgNode.l.filterNot(x => allNodesReversePostOrder.contains(x)) - val nodeToNumber: Map[StoredNode, Int] = allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => x -> i }.toMap - val numberToNode: Map[Int, StoredNode] = allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => i -> x }.toMap - - val allNodesPostOrder: List[StoredNode] = allNodesReversePostOrder.reverse - - private val _succ: Map[StoredNode, List[StoredNode]] = initSucc(allNodesReversePostOrder) - private val _pred: Map[StoredNode, List[StoredNode]] = initPred(allNodesReversePostOrder, method) - - override def succ(node: StoredNode): IterableOnce[StoredNode] = { 
- _succ.apply(node) - } - - override def pred(node: StoredNode): IterableOnce[StoredNode] = { - _pred.apply(node) - } - - /** Create a map that allows CFG successors to be retrieved for each node - */ - private def initSucc(ns: List[StoredNode]): Map[StoredNode, List[StoredNode]] = { - ns.map { - case n: Method => n -> firstParamOrBody(n) - case ret: Return => ret -> List(firstOutputParam.getOrElse(exitNode)) - case param: MethodParameterIn => - param -> nextParamOrBody(param) - case paramOut: MethodParameterOut => paramOut -> nextParamOutOrExit(paramOut) - case cfgNode: CfgNode => cfgNode -> cfgNextOrFirstOutParam(cfgNode) - case n => - logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") - n -> List() - }.toMap - } - - /** Create a map that allows CFG predecessors to be retrieved for each node - */ - private def initPred(ns: List[StoredNode], method: Method): Map[StoredNode, List[StoredNode]] = { - ns.map { - case param: MethodParameterIn => param -> previousParamOrEntry(param) - case paramOut: MethodParameterOut => paramOut -> previousOutputParamOrLastNodeOfBody(paramOut) - case n: CfgNode if method.cfgFirst.headOption.contains(n) => n -> List(lastParam.getOrElse(method)) - case n if n == exitNode => n -> lastOutputParamOrLastNodeOfBody() - case n @ (cfgNode: CfgNode) => n -> cfgNode.cfgPrev.l - case n => - logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") - n -> List() - }.toMap - } - - private def firstParamOrBody(n: Method): List[StoredNode] = { - if (firstParam.isDefined) { - firstParam.toList - } else { - cfgNext(n) - } - } - - private def cfgNext(n: CfgNode): List[StoredNode] = - n.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l - - private def nextParamOrBody(param: MethodParameterIn): List[StoredNode] = { - val nextParam = param.method.parameter.index(param.index + 1).headOption - if (nextParam.isDefined) { nextParam.toList } - else { param.method.cfgFirst.l } - } - - private def 
nextParamOutOrExit(paramOut: MethodParameterOut): List[StoredNode] = { - val nextParam = paramOut.method.parameter.index(paramOut.index + 1).asOutput.headOption - if (nextParam.isDefined) { nextParam.toList } - else { List(exitNode) } - } - - private def cfgNextOrFirstOutParam(cfgNode: CfgNode): List[StoredNode] = { - // `.cfgNext` would be wrong here because it filters `METHOD_RETURN` - val successors = cfgNode.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l - if (successors == List(exitNode) && firstOutputParam.isDefined) { - List(firstOutputParam.get) - } else { - successors - } - } - - private def previousParamOrEntry(param: MethodParameterIn): List[StoredNode] = { - val prevParam = param.method.parameter.index(param.index - 1).headOption - if (prevParam.isDefined) { prevParam.toList } - else { List(method) } - } - - private def previousOutputParamOrLastNodeOfBody(paramOut: MethodParameterOut): List[StoredNode] = { - val prevParam = paramOut.method.parameter.index(paramOut.index - 1).asOutput.headOption - if (prevParam.isDefined) { prevParam.toList } - else { lastActualCfgNode.toList } - } - - private def lastOutputParamOrLastNodeOfBody(): List[StoredNode] = { - if (lastOutputParam.isDefined) { lastOutputParam.toList } - else { lastActualCfgNode.toList } - } - -} - -/** For each node of the graph, this transfer function defines how it affects the propagation of definitions. 
+ private val params = method.parameter.sortBy(_.index) + private val firstParam = params.headOption + private val lastParam = params.lastOption + private val firstOutputParam = firstParam.flatMap(_.asOutput.headOption) + private val lastOutputParam = method.parameter.sortBy(_.index).asOutput.lastOption + + private val lastActualCfgNode = exitNode._cfgIn.nextOption() + + val allNodesReversePostOrder: List[StoredNode] = + List(entryNode) ++ method.parameter.toList ++ method.reversePostOrder.toList.filter(x => + x.id != entryNode.id && x.id != exitNode.id + ) ++ method.parameter.asOutput.toList ++ List(exitNode) + + private val allNodesEvenUnreachable = + allNodesReversePostOrder ++ method.cfgNode.l.filterNot(x => + allNodesReversePostOrder.contains(x) + ) + val nodeToNumber: Map[StoredNode, Int] = + allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => x -> i }.toMap + val numberToNode: Map[Int, StoredNode] = + allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => i -> x }.toMap + + val allNodesPostOrder: List[StoredNode] = allNodesReversePostOrder.reverse + + private val _succ: Map[StoredNode, List[StoredNode]] = initSucc(allNodesReversePostOrder) + private val _pred: Map[StoredNode, List[StoredNode]] = + initPred(allNodesReversePostOrder, method) + + override def succ(node: StoredNode): IterableOnce[StoredNode] = + _succ.apply(node) + + override def pred(node: StoredNode): IterableOnce[StoredNode] = + _pred.apply(node) + + /** Create a map that allows CFG successors to be retrieved for each node + */ + private def initSucc(ns: List[StoredNode]): Map[StoredNode, List[StoredNode]] = + ns.map { + case n: Method => n -> firstParamOrBody(n) + case ret: Return => ret -> List(firstOutputParam.getOrElse(exitNode)) + case param: MethodParameterIn => + param -> nextParamOrBody(param) + case paramOut: MethodParameterOut => paramOut -> nextParamOutOrExit(paramOut) + case cfgNode: CfgNode => cfgNode -> cfgNextOrFirstOutParam(cfgNode) + case n => + logger.warn(s"Node 
type ${n.getClass.getSimpleName} should not be part of the CFG") + n -> List() + }.toMap + + /** Create a map that allows CFG predecessors to be retrieved for each node + */ + private def initPred(ns: List[StoredNode], method: Method): Map[StoredNode, List[StoredNode]] = + ns.map { + case param: MethodParameterIn => param -> previousParamOrEntry(param) + case paramOut: MethodParameterOut => + paramOut -> previousOutputParamOrLastNodeOfBody(paramOut) + case n: CfgNode if method.cfgFirst.headOption.contains(n) => + n -> List(lastParam.getOrElse(method)) + case n if n == exitNode => n -> lastOutputParamOrLastNodeOfBody() + case n @ (cfgNode: CfgNode) => n -> cfgNode.cfgPrev.l + case n => + logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") + n -> List() + }.toMap + + private def firstParamOrBody(n: Method): List[StoredNode] = + if firstParam.isDefined then + firstParam.toList + else + cfgNext(n) + + private def cfgNext(n: CfgNode): List[StoredNode] = + n.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l + + private def nextParamOrBody(param: MethodParameterIn): List[StoredNode] = + val nextParam = param.method.parameter.index(param.index + 1).headOption + if nextParam.isDefined then nextParam.toList + else param.method.cfgFirst.l + + private def nextParamOutOrExit(paramOut: MethodParameterOut): List[StoredNode] = + val nextParam = paramOut.method.parameter.index(paramOut.index + 1).asOutput.headOption + if nextParam.isDefined then nextParam.toList + else List(exitNode) + + private def cfgNextOrFirstOutParam(cfgNode: CfgNode): List[StoredNode] = + // `.cfgNext` would be wrong here because it filters `METHOD_RETURN` + val successors = cfgNode.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l + if successors == List(exitNode) && firstOutputParam.isDefined then + List(firstOutputParam.get) + else + successors + + private def previousParamOrEntry(param: MethodParameterIn): List[StoredNode] = + val prevParam = 
param.method.parameter.index(param.index - 1).headOption + if prevParam.isDefined then prevParam.toList + else List(method) + + private def previousOutputParamOrLastNodeOfBody(paramOut: MethodParameterOut) + : List[StoredNode] = + val prevParam = paramOut.method.parameter.index(paramOut.index - 1).asOutput.headOption + if prevParam.isDefined then prevParam.toList + else lastActualCfgNode.toList + + private def lastOutputParamOrLastNodeOfBody(): List[StoredNode] = + if lastOutputParam.isDefined then lastOutputParam.toList + else lastActualCfgNode.toList +end ReachingDefFlowGraph + +/** For each node of the graph, this transfer function defines how it affects the propagation of + * definitions. */ class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) - extends TransferFunction[StoredNode, mutable.BitSet] { + extends TransferFunction[StoredNode, mutable.BitSet]: - private val nodeToNumber = flowGraph.nodeToNumber + private val nodeToNumber = flowGraph.nodeToNumber - val method: Method = flowGraph.method + val method: Method = flowGraph.method - val gen: Map[StoredNode, mutable.BitSet] = - initGen(method).withDefaultValue(mutable.BitSet()) + val gen: Map[StoredNode, mutable.BitSet] = + initGen(method).withDefaultValue(mutable.BitSet()) - val kill: Map[StoredNode, Set[Definition]] = - initKill(method, gen).withDefaultValue(mutable.BitSet()) + val kill: Map[StoredNode, Set[Definition]] = + initKill(method, gen).withDefaultValue(mutable.BitSet()) - /** For a given flow graph node `n` and set of definitions, apply the transfer function to obtain the updated set of - * definitions, considering `gen(n)` and `kill(n)`. - */ - override def apply(n: StoredNode, x: mutable.BitSet): mutable.BitSet = { - gen(n).union(x.diff(kill(n))) - } + /** For a given flow graph node `n` and set of definitions, apply the transfer function to + * obtain the updated set of definitions, considering `gen(n)` and `kill(n)`. 
+ */ + override def apply(n: StoredNode, x: mutable.BitSet): mutable.BitSet = + gen(n).union(x.diff(kill(n))) - /** Initialize the map `gen`, a map that contains generated definitions for each flow graph node. - */ - def initGen(method: Method): Map[StoredNode, mutable.BitSet] = { + /** Initialize the map `gen`, a map that contains generated definitions for each flow graph + * node. + */ + def initGen(method: Method): Map[StoredNode, mutable.BitSet] = - val defsForParams = method.parameter.l.map { param => - param -> mutable.BitSet(Definition.fromNode(param.asInstanceOf[StoredNode], nodeToNumber)) - } + val defsForParams = method.parameter.l.map { param => + param -> mutable.BitSet(Definition.fromNode( + param.asInstanceOf[StoredNode], + nodeToNumber + )) + } - // We filter out field accesses to ensure that they propagate - // taint unharmed. + // We filter out field accesses to ensure that they propagate + // taint unharmed. - val defsForCalls = method.call - .filterNot(x => isFieldAccess(x.name)) - .l - .map { call => - call -> { - val retVal = List(call) - val args = call.argument - .filter(hasValidGenType) + val defsForCalls = method.call + .filterNot(x => isFieldAccess(x.name)) .l - mutable.BitSet( - (retVal ++ args) - .collect { - case x if nodeToNumber.contains(x) => - Definition.fromNode(x.asInstanceOf[StoredNode], nodeToNumber) - }: _* - ) - } - } - (defsForParams ++ defsForCalls).toMap - } - - /** Restricts the types of nodes that represent definitions. - */ - private def hasValidGenType(node: Expression): Boolean = { - node match { - case _: Call => true - case _: Identifier => true - case _ => false - } - } - - /** Initialize the map `kill`, a map that contains killed definitions for each flow graph node. - * - * All operations in our graph are represented by calls and non-operations such as identifiers or field-identifiers - * have empty gen and kill sets, meaning that they just pass on definitions unaltered. 
- */ - private def initKill(method: Method, gen: Map[StoredNode, Set[Definition]]): Map[StoredNode, Set[Definition]] = { - - val allIdentifiers: Map[String, List[CfgNode]] = { - val results = mutable.Map.empty[String, List[CfgNode]] - method.ast - .collect { - case identifier: Identifier => - (identifier.name, identifier) - case methodParameterIn: MethodParameterIn => - (methodParameterIn.name, methodParameterIn) - } - .foreach { case (name, node) => - val oldValues = results.getOrElse(name, Nil) - results.put(name, node :: oldValues) + .map { call => + call -> { + val retVal = List(call) + val args = call.argument + .filter(hasValidGenType) + .l + mutable.BitSet( + (retVal ++ args) + .collect { + case x if nodeToNumber.contains(x) => + Definition.fromNode(x.asInstanceOf[StoredNode], nodeToNumber) + }* + ) + } + } + (defsForParams ++ defsForCalls).toMap + end initGen + + /** Restricts the types of nodes that represent definitions. + */ + private def hasValidGenType(node: Expression): Boolean = + node match + case _: Call => true + case _: Identifier => true + case _ => false + + /** Initialize the map `kill`, a map that contains killed definitions for each flow graph node. + * + * All operations in our graph are represented by calls and non-operations such as identifiers + * or field-identifiers have empty gen and kill sets, meaning that they just pass on + * definitions unaltered. 
+ */ + private def initKill( + method: Method, + gen: Map[StoredNode, Set[Definition]] + ): Map[StoredNode, Set[Definition]] = + + val allIdentifiers: Map[String, List[CfgNode]] = + val results = mutable.Map.empty[String, List[CfgNode]] + method.ast + .collect { + case identifier: Identifier => + (identifier.name, identifier) + case methodParameterIn: MethodParameterIn => + (methodParameterIn.name, methodParameterIn) + } + .foreach { case (name, node) => + val oldValues = results.getOrElse(name, Nil) + results.put(name, node :: oldValues) + } + results.toMap + + val allCalls: Map[String, List[Call]] = + method.call.l + .groupBy(_.code) + .withDefaultValue(List.empty[Call]) + + // We filter out field accesses to ensure that they propagate + // taint unharmed. + + method.call + .filterNot(x => isGenericMemberAccessName(x.name)) + .map { call => + call -> killsForGens(gen(call), allIdentifiers, allCalls) + } + .toMap + end initKill + + /** The only way in which a call can kill another definition is by generating a new definition + * for the same variable. Given the set of generated definitions `gens`, we calculate + * definitions of the same variable for each, that is, we calculate kill(call) based on + * gen(call). + */ + private def killsForGens( + genOfCall: Set[Definition], + allIdentifiers: Map[String, List[CfgNode]], + allCalls: Map[String, List[Call]] + ): Set[Definition] = + + def definitionsOfSameVariable(definition: Definition): Set[Definition] = + val definedNodes = flowGraph.numberToNode(definition) match + case param: MethodParameterIn => + allIdentifiers(param.name) + .filter(x => x.id != param.id) + case identifier: Identifier => + val sameIdentifiers = allIdentifiers(identifier.name) + .filter(x => x.id != identifier.id) + + /** Killing an identifier should also kill field accesses on that identifier. + * For example, a reassignment `x = new Box()` should kill any previous calls + * to `x.value`, `x.length()`, etc. 
+ */ + val sameObjects: Iterable[Call] = allCalls.values.flatten + .filter(_.name == Operators.fieldAccess) + .filter(_.ast.isIdentifier.nameExact(identifier.name).nonEmpty) + + sameIdentifiers ++ sameObjects + case call: Call => + allCalls(call.code) + .filter(x => x.id != call.id) + case _ => Set() + definedNodes + // It can happen that the CFG is broken and contains isolated nodes, + // in which case they are not in `nodeToNumber`. Let's filter those. + .collect { + case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) + }.toSet + end definitionsOfSameVariable + + genOfCall.flatMap { definition => + definitionsOfSameVariable(definition) } - results.toMap - } - - val allCalls: Map[String, List[Call]] = - method.call.l - .groupBy(_.code) - .withDefaultValue(List.empty[Call]) - - // We filter out field accesses to ensure that they propagate - // taint unharmed. - - method.call - .filterNot(x => isGenericMemberAccessName(x.name)) - .map { call => - call -> killsForGens(gen(call), allIdentifiers, allCalls) - } - .toMap - } - - /** The only way in which a call can kill another definition is by generating a new definition for the same variable. - * Given the set of generated definitions `gens`, we calculate definitions of the same variable for each, that is, we - * calculate kill(call) based on gen(call). - */ - private def killsForGens( - genOfCall: Set[Definition], - allIdentifiers: Map[String, List[CfgNode]], - allCalls: Map[String, List[Call]] - ): Set[Definition] = { - - def definitionsOfSameVariable(definition: Definition): Set[Definition] = { - val definedNodes = flowGraph.numberToNode(definition) match { - case param: MethodParameterIn => - allIdentifiers(param.name) - .filter(x => x.id != param.id) - case identifier: Identifier => - val sameIdentifiers = allIdentifiers(identifier.name) - .filter(x => x.id != identifier.id) - - /** Killing an identifier should also kill field accesses on that identifier. 
For example, a reassignment `x = - * new Box()` should kill any previous calls to `x.value`, `x.length()`, etc. - */ - val sameObjects: Iterable[Call] = allCalls.values.flatten - .filter(_.name == Operators.fieldAccess) - .filter(_.ast.isIdentifier.nameExact(identifier.name).nonEmpty) - - sameIdentifiers ++ sameObjects - case call: Call => - allCalls(call.code) - .filter(x => x.id != call.id) - case _ => Set() - } - definedNodes - // It can happen that the CFG is broken and contains isolated nodes, - // in which case they are not in `nodeToNumber`. Let's filter those. - .collect { case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) }.toSet - } - - genOfCall.flatMap { definition => - definitionsOfSameVariable(definition) - } - } - -} - -/** Lone Identifier Optimization: we first determine and store all identifiers that neither refer to a local nor a - * parameter and that appear only once as a call argument and not also in a return statement. For these identifiers, we - * know that they are not used in any other location in the code, and so, we remove them from `gen` sets so that they - * need not be propagated through the entire graph only to determine that they reach the exit node. Instead, when - * creating reaching definition edges, we simply create edges from the identifier to the exit node. + end killsForGens +end ReachingDefTransferFunction + +/** Lone Identifier Optimization: we first determine and store all identifiers that neither refer to + * a local nor a parameter and that appear only once as a call argument and not also in a return + * statement. For these identifiers, we know that they are not used in any other location in the + * code, and so, we remove them from `gen` sets so that they need not be propagated through the + * entire graph only to determine that they reach the exit node. Instead, when creating reaching + * definition edges, we simply create edges from the identifier to the exit node. 
*/ class OptimizedReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) - extends ReachingDefTransferFunction(flowGraph) { - - lazy val loneIdentifiers: Map[Call, List[Definition]] = { - val identifiersInReturns = method._returnViaContainsOut.ast.isIdentifier.name.l - val paramAndLocalNames = method.parameter.name.l ++ method.local.name.l - val callArgPairs = method.call.flatMap { call => - call.argument.isIdentifier - .filterNot(i => paramAndLocalNames.contains(i.name)) - .filterNot(i => identifiersInReturns.contains(i.name)) - .map(arg => (arg.name, call, arg)) - }.l - - callArgPairs - .groupBy(_._1) - .collect { case (_, v) if v.size == 1 => v.map { case (_, call, arg) => (call, arg) }.head } - .toList - .groupBy(_._1) - .map { case (k, v) => - ( - k, - v.filter(x => flowGraph.nodeToNumber.contains(x._2)) - .map(x => Definition.fromNode(x._2, flowGraph.nodeToNumber)) - ) - } - } - - override def initGen(method: Method): Map[StoredNode, mutable.BitSet] = - withoutLoneIdentifiers(super.initGen(method)) - - private def withoutLoneIdentifiers(g: Map[StoredNode, mutable.BitSet]): Map[StoredNode, mutable.BitSet] = { - g.map { case (k, defs) => - k match { - case call: Call if loneIdentifiers.contains(call) => - (call, defs.filterNot(loneIdentifiers(call).contains(_))) - case _ => (k, defs) - } - } - } -} - -class ReachingDefInit(gen: Map[StoredNode, mutable.BitSet]) extends InOutInit[StoredNode, mutable.BitSet] { - override def initIn: Map[StoredNode, mutable.BitSet] = - Map - .empty[StoredNode, mutable.BitSet] - .withDefaultValue(mutable.BitSet()) - - override def initOut: Map[StoredNode, mutable.BitSet] = gen -} + extends ReachingDefTransferFunction(flowGraph): + + lazy val loneIdentifiers: Map[Call, List[Definition]] = + val identifiersInReturns = method._returnViaContainsOut.ast.isIdentifier.name.l + val paramAndLocalNames = method.parameter.name.l ++ method.local.name.l + val callArgPairs = method.call.flatMap { call => + call.argument.isIdentifier + 
.filterNot(i => paramAndLocalNames.contains(i.name)) + .filterNot(i => identifiersInReturns.contains(i.name)) + .map(arg => (arg.name, call, arg)) + }.l + + callArgPairs + .groupBy(_._1) + .collect { + case (_, v) if v.size == 1 => v.map { case (_, call, arg) => (call, arg) }.head + } + .toList + .groupBy(_._1) + .map { case (k, v) => + ( + k, + v.filter(x => flowGraph.nodeToNumber.contains(x._2)) + .map(x => Definition.fromNode(x._2, flowGraph.nodeToNumber)) + ) + } + end loneIdentifiers + + override def initGen(method: Method): Map[StoredNode, mutable.BitSet] = + withoutLoneIdentifiers(super.initGen(method)) + + private def withoutLoneIdentifiers(g: Map[StoredNode, mutable.BitSet]) + : Map[StoredNode, mutable.BitSet] = + g.map { case (k, defs) => + k match + case call: Call if loneIdentifiers.contains(call) => + (call, defs.filterNot(loneIdentifiers(call).contains(_))) + case _ => (k, defs) + } +end OptimizedReachingDefTransferFunction + +class ReachingDefInit(gen: Map[StoredNode, mutable.BitSet]) + extends InOutInit[StoredNode, mutable.BitSet]: + override def initIn: Map[StoredNode, mutable.BitSet] = + Map + .empty[StoredNode, mutable.BitSet] + .withDefaultValue(mutable.BitSet()) + + override def initOut: Map[StoredNode, mutable.BitSet] = gen diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala index c4cb793e..1e2de5dd 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala @@ -1,5 +1,4 @@ package io.appthreat.dataflowengineoss.passes -package object reachingdef { - type Definition = Int -} +package object reachingdef: + type Definition = Int diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala 
b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala index 694999a7..a6cec096 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala @@ -4,51 +4,50 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.accesspath.* import io.shiftleft.semanticcpg.language.toCallMethods import io.shiftleft.semanticcpg.accesspath.{ - AccessElement, - AccessPath, - Elements, - TrackedBase, - TrackedReturnValue, - TrackedUnknown + AccessElement, + AccessPath, + Elements, + TrackedBase, + TrackedReturnValue, + TrackedUnknown } import io.shiftleft.semanticcpg.language.AccessPathHandling import io.shiftleft.semanticcpg.utils.MemberAccess import org.slf4j.LoggerFactory -object AccessPathUsage { +object AccessPathUsage: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - def toTrackedBaseAndAccessPathSimple(node: StoredNode): (TrackedBase, AccessPath) = { - val (base, revPath) = toTrackedBaseAndAccessPathInternal(node) - (base, AccessPath.apply(Elements.normalized(revPath.reverse), Nil)) - } + def toTrackedBaseAndAccessPathSimple(node: StoredNode): (TrackedBase, AccessPath) = + val (base, revPath) = toTrackedBaseAndAccessPathInternal(node) + (base, AccessPath.apply(Elements.normalized(revPath.reverse), Nil)) - private def toTrackedBaseAndAccessPathInternal(node: StoredNode): (TrackedBase, List[AccessElement]) = { - val result = AccessPathHandling.leafToTrackedBaseAndAccessPathInternal(node) - if (result.isDefined) { - result.get - } else { - node match { + private def toTrackedBaseAndAccessPathInternal(node: StoredNode) + : (TrackedBase, List[AccessElement]) = + val result = AccessPathHandling.leafToTrackedBaseAndAccessPathInternal(node) + if result.isDefined then + result.get + else + 
node match - case block: Block => - AccessPathHandling - .lastExpressionInBlock(block) - .map { toTrackedBaseAndAccessPathInternal } - .getOrElse((TrackedUnknown, Nil)) - case call: Call if !MemberAccess.isGenericMemberAccessName(call.name) => (TrackedReturnValue(call), Nil) + case block: Block => + AccessPathHandling + .lastExpressionInBlock(block) + .map { toTrackedBaseAndAccessPathInternal } + .getOrElse((TrackedUnknown, Nil)) + case call: Call if !MemberAccess.isGenericMemberAccessName(call.name) => + (TrackedReturnValue(call), Nil) - case memberAccess: Call => - // assume: MemberAccess.isGenericMemberAccessName(call.name) - val argOne = memberAccess.argumentOption(1) - if (argOne.isEmpty) { - logger.warn(s"Missing first argument on call ${memberAccess.code}.") - return (TrackedUnknown, Nil) - } - val (base, tail) = toTrackedBaseAndAccessPathInternal(argOne.get) - val path = AccessPathHandling.memberAccessToPath(memberAccess, tail) - (base, path) - } - } - } -} + case memberAccess: Call => + // assume: MemberAccess.isGenericMemberAccessName(call.name) + val argOne = memberAccess.argumentOption(1) + if argOne.isEmpty then + logger.warn(s"Missing first argument on call ${memberAccess.code}.") + return (TrackedUnknown, Nil) + val (base, tail) = toTrackedBaseAndAccessPathInternal(argOne.get) + val path = AccessPathHandling.memberAccessToPath(memberAccess, tail) + (base, path) + end if + end toTrackedBaseAndAccessPathInternal +end AccessPathUsage diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala index 067223c4..00c42971 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala @@ -16,295 +16,291 @@ import scala.collection.mutable import scala.jdk.CollectionConverters.* import 
scala.util.{Failure, Success, Try} -/** The data flow engine allows determining paths to a set of sinks from a set of sources. To this end, it solves tasks - * in parallel, creating and submitting new tasks upon completion of tasks. This class deals only with task scheduling, - * while the creation of new tasks from existing tasks is handled by the class `TaskCreator`, and solving of tasks is - * taken care of by the `TaskSolver`. +/** The data flow engine allows determining paths to a set of sinks from a set of sources. To this + * end, it solves tasks in parallel, creating and submitting new tasks upon completion of tasks. + * This class deals only with task scheduling, while the creation of new tasks from existing tasks + * is handled by the class `TaskCreator`, and solving of tasks is taken care of by the + * `TaskSolver`. */ -class Engine(context: EngineContext) { - - import Engine.* - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - private val executorService: ExecutorService = - Executors.newWorkStealingPool(2) - private val completionService = - new ExecutorCompletionService[TaskSummary](executorService) - - /** All results of tasks are accumulated in this table. At the end of the analysis, we extract results from the table - * and return them. - */ - private val mainResultTable: mutable.Map[TaskFingerprint, List[TableEntry]] = mutable.Map() - private var numberOfTasksRunning: Int = 0 - private val started: mutable.HashSet[TaskFingerprint] = mutable.HashSet[TaskFingerprint]() - private val held: mutable.Buffer[ReachableByTask] = mutable.Buffer() - - /** Determine flows from sources to sinks by exploring the graph backwards from sinks to sources. Returns the list of - * results along with a ResultTable, a cache of known paths created during the analysis. 
- */ - def backwards(sinks: List[CfgNode], sources: List[CfgNode]): List[TableEntry] = { - if (sources.isEmpty) { - logger.debug("Attempting to determine flows from empty list of sources.") - } - if (sinks.isEmpty) { - logger.debug("Attempting to determine flows to empty list of sinks.") - } - reset() - val sourcesSet = sources.toSet - val tasks = createOneTaskPerSink(sinks) - solveTasks(tasks, sourcesSet, sinks) - } - - private def reset(): Unit = { - mainResultTable.clear() - numberOfTasksRunning = 0 - started.clear() - held.clear() - } - - private def createOneTaskPerSink(sinks: List[CfgNode]) = { - sinks.map(sink => ReachableByTask(List(TaskFingerprint(sink, List(), 0)), Vector())) - } - - /** Submit tasks to a worker pool, solving them in parallel. Upon receiving results for a task, new tasks are - * submitted accordingly. Once no more tasks can be created, the list of results is returned. - */ - private def solveTasks( - tasks: List[ReachableByTask], - sources: Set[CfgNode], - sinks: List[CfgNode] - ): List[TableEntry] = { - - /** Solving a task produces a list of summaries. The following method is called for each of these summaries. It - * submits new tasks and adds results to the result table. +class Engine(context: EngineContext): + + import Engine.* + + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + private val executorService: ExecutorService = + Executors.newWorkStealingPool(2) + private val completionService = + new ExecutorCompletionService[TaskSummary](executorService) + + /** All results of tasks are accumulated in this table. At the end of the analysis, we extract + * results from the table and return them. 
*/ - def handleSummary(taskSummary: TaskSummary): Unit = { - val newTasks = taskSummary.followupTasks - submitTasks(newTasks, sources) - val newResults = taskSummary.tableEntries - addEntriesToMainTable(newResults) - } - - def addEntriesToMainTable(entries: Vector[(TaskFingerprint, TableEntry)]): Unit = { - entries.groupBy(_._1).foreach { case (fingerprint, entryList) => - val entries = entryList.map(_._2).toList - mainResultTable.updateWith(fingerprint) { - case Some(list) => Some(list ++ entries) - case None => Some(entries) - } - } - } - - def runUntilAllTasksAreSolved(): Unit = { - while (numberOfTasksRunning > 0) { - Try { - completionService.take.get - } match { - case Success(resultsOfTask) => - numberOfTasksRunning -= 1 - handleSummary(resultsOfTask) - case Failure(_) => - numberOfTasksRunning -= 1 + private val mainResultTable: mutable.Map[TaskFingerprint, List[TableEntry]] = mutable.Map() + private var numberOfTasksRunning: Int = 0 + private val started: mutable.HashSet[TaskFingerprint] = mutable.HashSet[TaskFingerprint]() + private val held: mutable.Buffer[ReachableByTask] = mutable.Buffer() + + /** Determine flows from sources to sinks by exploring the graph backwards from sinks to + * sources. Returns the list of results along with a ResultTable, a cache of known paths + * created during the analysis. 
+ */ + def backwards(sinks: List[CfgNode], sources: List[CfgNode]): List[TableEntry] = + if sources.isEmpty then + logger.debug("Attempting to determine flows from empty list of sources.") + if sinks.isEmpty then + logger.debug("Attempting to determine flows to empty list of sinks.") + reset() + val sourcesSet = sources.toSet + val tasks = createOneTaskPerSink(sinks) + solveTasks(tasks, sourcesSet, sinks) + + private def reset(): Unit = + mainResultTable.clear() + numberOfTasksRunning = 0 + started.clear() + held.clear() + + private def createOneTaskPerSink(sinks: List[CfgNode]) = + sinks.map(sink => ReachableByTask(List(TaskFingerprint(sink, List(), 0)), Vector())) + + /** Submit tasks to a worker pool, solving them in parallel. Upon receiving results for a task, + * new tasks are submitted accordingly. Once no more tasks can be created, the list of results + * is returned. + */ + private def solveTasks( + tasks: List[ReachableByTask], + sources: Set[CfgNode], + sinks: List[CfgNode] + ): List[TableEntry] = + + /** Solving a task produces a list of summaries. The following method is called for each of + * these summaries. It submits new tasks and adds results to the result table. 
+ */ + def handleSummary(taskSummary: TaskSummary): Unit = + val newTasks = taskSummary.followupTasks + submitTasks(newTasks, sources) + val newResults = taskSummary.tableEntries + addEntriesToMainTable(newResults) + + def addEntriesToMainTable(entries: Vector[(TaskFingerprint, TableEntry)]): Unit = + entries.groupBy(_._1).foreach { case (fingerprint, entryList) => + val entries = entryList.map(_._2).toList + mainResultTable.updateWith(fingerprint) { + case Some(list) => Some(list ++ entries) + case None => Some(entries) + } + } + + def runUntilAllTasksAreSolved(): Unit = + while numberOfTasksRunning > 0 do + Try { + completionService.take.get + } match + case Success(resultsOfTask) => + numberOfTasksRunning -= 1 + handleSummary(resultsOfTask) + case Failure(_) => + numberOfTasksRunning -= 1 + + submitTasks(tasks.toVector, sources) + val startTimeSec: Long = System.currentTimeMillis / 1000 + runUntilAllTasksAreSolved() + val taskFinishTimeSec: Long = System.currentTimeMillis / 1000 + logger.debug( + "Time measurement -----> Task processing done in " + + (taskFinishTimeSec - startTimeSec) + " seconds" + ) + new HeldTaskCompletion(held.toList, mainResultTable).completeHeldTasks() + val dedupResult = deduplicateFinal(extractResultsFromTable(sinks)) + val allDoneTimeSec: Long = System.currentTimeMillis / 1000 + + logger.debug( + "Time measurement -----> Task processing: " + + (taskFinishTimeSec - startTimeSec) + " seconds" + + ", Deduplication: " + (allDoneTimeSec - taskFinishTimeSec) + + ", Deduped results size: " + dedupResult.length + ) + dedupResult + end solveTasks + + private def submitTasks(tasks: Vector[ReachableByTask], sources: Set[CfgNode]): Unit = + tasks.foreach { task => + if started.contains(task.fingerprint) then + held ++= Vector(task) + else + started.add(task.fingerprint) + numberOfTasksRunning += 1 + completionService.submit(new TaskSolver(task, context, sources)) } - } - } - - submitTasks(tasks.toVector, sources) - val startTimeSec: Long = 
System.currentTimeMillis / 1000 - runUntilAllTasksAreSolved() - val taskFinishTimeSec: Long = System.currentTimeMillis / 1000 - logger.debug( - "Time measurement -----> Task processing done in " + - (taskFinishTimeSec - startTimeSec) + " seconds" - ) - new HeldTaskCompletion(held.toList, mainResultTable).completeHeldTasks() - val dedupResult = deduplicateFinal(extractResultsFromTable(sinks)) - val allDoneTimeSec: Long = System.currentTimeMillis / 1000 - - logger.debug( - "Time measurement -----> Task processing: " + - (taskFinishTimeSec - startTimeSec) + " seconds" + - ", Deduplication: " + (allDoneTimeSec - taskFinishTimeSec) + - ", Deduped results size: " + dedupResult.length - ) - dedupResult - } - - private def submitTasks(tasks: Vector[ReachableByTask], sources: Set[CfgNode]): Unit = { - tasks.foreach { task => - if (started.contains(task.fingerprint)) { - held ++= Vector(task) - } else { - started.add(task.fingerprint) - numberOfTasksRunning += 1 - completionService.submit(new TaskSolver(task, context, sources)) - } - } - } - - private def extractResultsFromTable(sinks: List[CfgNode]): List[TableEntry] = { - sinks.flatMap { sink => - mainResultTable.get(TaskFingerprint(sink, List(), 0)) match { - case Some(results) => results - case _ => Vector() - } - } - } - - private def deduplicateFinal(list: List[TableEntry]): List[TableEntry] = { - list - .groupBy { result => - val head = result.path.head.node - val last = result.path.last.node - (head, last) - } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)) - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match { - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - }).map(_._2) - - if (withMaxLength.length == 1) { - withMaxLength.head - } else { - withMaxLength.minBy { x => - x.path - .map(x => (x.node.id, x.callSiteStack.map(_.id), x.visible, x.isOutputArg, x.outEdgeLabel).toString) - .mkString("-") - } + + private def extractResultsFromTable(sinks: 
List[CfgNode]): List[TableEntry] = + sinks.flatMap { sink => + mainResultTable.get(TaskFingerprint(sink, List(), 0)) match + case Some(results) => results + case _ => Vector() } - } - .toList - } - - /** This must be called when one is done using the engine. - */ - def shutdown(): Unit = { - executorService.shutdown() - } - -} - -object Engine { - - /** Traverse from a node to incoming DDG nodes, taking into account semantics. This method is exposed via the `ddgIn` - * step, but is also called by the engine internally by the `TaskSolver`. - * - * @param curNode - * the node to expand - * @param path - * the path that has been expanded to reach the `curNode` - */ - def expandIn(curNode: CfgNode, path: Vector[PathElement], callSiteStack: List[Call] = List())(implicit - semantics: Semantics - ): Vector[PathElement] = { - ddgInE(curNode, path, callSiteStack).flatMap(x => elemForEdge(x, callSiteStack)) - } - - private def elemForEdge(e: Edge, callSiteStack: List[Call] = List())(implicit - semantics: Semantics - ): Option[PathElement] = { - val curNode = e.inNode().asInstanceOf[CfgNode] - val parNode = e.outNode().asInstanceOf[CfgNode] - val outLabel = Some(e.property(Properties.VARIABLE)).getOrElse("") - - if (!EdgeValidator.isValidEdge(curNode, parNode)) { - return None - } - - curNode match { - case childNode: Expression => - parNode match { - case parentNode: Expression => - val parentNodeCall = parentNode.inCall.l - val sameCallSite = parentNode.inCall.l == childNode.start.inCall.l - val visible = if (sameCallSite) { - val semanticExists = parentNode.semanticsForCallByArg.nonEmpty - val internalMethodsForCall = parentNodeCall.flatMap(methodsForCall).internal - (semanticExists && parentNode.isDefined) || internalMethodsForCall.isEmpty - } else { - parentNode.isDefined + + private def deduplicateFinal(list: List[TableEntry]): List[TableEntry] = + list + .groupBy { result => + val head = result.path.head.node + val last = result.path.last.node + (head, last) + } + .map 
{ case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)) + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) + + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + } + .toList + + /** This must be called when one is done using the engine. + */ + def shutdown(): Unit = + executorService.shutdown() +end Engine + +object Engine: + + /** Traverse from a node to incoming DDG nodes, taking into account semantics. This method is + * exposed via the `ddgIn` step, but is also called by the engine internally by the + * `TaskSolver`. + * + * @param curNode + * the node to expand + * @param path + * the path that has been expanded to reach the `curNode` + */ + def expandIn( + curNode: CfgNode, + path: Vector[PathElement], + callSiteStack: List[Call] = List() + )(implicit semantics: Semantics): Vector[PathElement] = + ddgInE(curNode, path, callSiteStack).flatMap(x => elemForEdge(x, callSiteStack)) + + private def elemForEdge(e: Edge, callSiteStack: List[Call] = List())(implicit + semantics: Semantics + ): Option[PathElement] = + val curNode = e.inNode().asInstanceOf[CfgNode] + val parNode = e.outNode().asInstanceOf[CfgNode] + val outLabel = Some(e.property(Properties.VARIABLE)).getOrElse("") + + if !EdgeValidator.isValidEdge(curNode, parNode) then + return None + + curNode match + case childNode: Expression => + parNode match + case parentNode: Expression => + val parentNodeCall = parentNode.inCall.l + val sameCallSite = parentNode.inCall.l == childNode.start.inCall.l + val visible = if sameCallSite then + val semanticExists = parentNode.semanticsForCallByArg.nonEmpty + val internalMethodsForCall = + parentNodeCall.flatMap(methodsForCall).internal + (semanticExists && 
parentNode.isDefined) || internalMethodsForCall.isEmpty + else + parentNode.isDefined + val isOutputArg = isOutputArgOfInternalMethod(parentNode) + Some(PathElement( + parentNode, + callSiteStack, + visible, + isOutputArg, + outEdgeLabel = outLabel + )) + case parentNode if parentNode != null => + Some(PathElement(parentNode, callSiteStack, outEdgeLabel = outLabel)) + case null => + None + case _ => + Some(PathElement(parNode, callSiteStack, outEdgeLabel = outLabel)) + end match + end elemForEdge + + def isOutputArgOfInternalMethod(arg: Expression)(implicit semantics: Semantics): Boolean = + arg.inCall.l match + case List(call) => + methodsForCall(call).internal.isNotStub.nonEmpty && semanticsForCall(call).isEmpty + case _ => + false + + /** For a given node `node`, return all incoming reaching definition edges, unless the source + * node is (a) a METHOD node, (b) already present on `path`, or (c) a CALL node to a method + * where the semantic indicates that taint is propagated to it. + */ + private def ddgInE( + node: CfgNode, + path: Vector[PathElement], + callSiteStack: List[Call] = List() + ): Vector[Edge] = + node + .inE(EdgeTypes.REACHING_DEF) + .asScala + .filter { e => + e.outNode() match + case srcNode: CfgNode => + !srcNode.isInstanceOf[Method] && !path + .map(x => x.node) + .contains(srcNode) + case _ => false } - val isOutputArg = isOutputArgOfInternalMethod(parentNode) - Some(PathElement(parentNode, callSiteStack, visible, isOutputArg, outEdgeLabel = outLabel)) - case parentNode if parentNode != null => - Some(PathElement(parentNode, callSiteStack, outEdgeLabel = outLabel)) - case null => - None + .toVector + + def argToOutputParams(arg: Expression): Iterator[MethodParameterOut] = + argToMethods(arg).parameter + .index(arg.argumentIndex) + .asOutput + + def argToMethods(arg: Expression): List[Method] = + arg.inCall.l.flatMap { call => + methodsForCall(call) } - case _ => - Some(PathElement(parNode, callSiteStack, outEdgeLabel = outLabel)) - } - } - - 
def isOutputArgOfInternalMethod(arg: Expression)(implicit semantics: Semantics): Boolean = { - arg.inCall.l match { - case List(call) => - methodsForCall(call).internal.isNotStub.nonEmpty && semanticsForCall(call).isEmpty - case _ => - false - } - } - - /** For a given node `node`, return all incoming reaching definition edges, unless the source node is (a) a METHOD - * node, (b) already present on `path`, or (c) a CALL node to a method where the semantic indicates that taint is - * propagated to it. - */ - private def ddgInE(node: CfgNode, path: Vector[PathElement], callSiteStack: List[Call] = List()): Vector[Edge] = { - node - .inE(EdgeTypes.REACHING_DEF) - .asScala - .filter { e => - e.outNode() match { - case srcNode: CfgNode => - !srcNode.isInstanceOf[Method] && !path - .map(x => x.node) - .contains(srcNode) - case _ => false + + def methodsForCall(call: Call): List[Method] = + NoResolve.getCalledMethods(call).toList + + def isCallToInternalMethod(call: Call): Boolean = + methodsForCall(call).internal.nonEmpty + def isCallToInternalMethodWithoutSemantic(call: Call)(implicit semantics: Semantics): Boolean = + isCallToInternalMethod(call) && semanticsForCall(call).isEmpty + + def semanticsForCall(call: Call)(implicit semantics: Semantics): List[FlowSemantic] = + Engine.methodsForCall(call).flatMap { method => + semantics.forMethod(method.fullName) } - } - .toVector - } - - def argToOutputParams(arg: Expression): Iterator[MethodParameterOut] = { - argToMethods(arg).parameter - .index(arg.argumentIndex) - .asOutput - } - - def argToMethods(arg: Expression): List[Method] = { - arg.inCall.l.flatMap { call => - methodsForCall(call) - } - } - - def methodsForCall(call: Call): List[Method] = { - NoResolve.getCalledMethods(call).toList - } - - def isCallToInternalMethod(call: Call): Boolean = { - methodsForCall(call).internal.nonEmpty - } - def isCallToInternalMethodWithoutSemantic(call: Call)(implicit semantics: Semantics): Boolean = { - isCallToInternalMethod(call) && 
semanticsForCall(call).isEmpty - } - - def semanticsForCall(call: Call)(implicit semantics: Semantics): List[FlowSemantic] = { - Engine.methodsForCall(call).flatMap { method => - semantics.forMethod(method.fullName) - } - } - -} +end Engine /** The execution context for the data flow engine. * @param semantics - * pre-determined semantic models for method calls e.g., logical operators, operators for common data structures. + * pre-determined semantic models for method calls e.g., logical operators, operators for common + * data structures. * @param config * additional configurations for the data flow engine. */ -case class EngineContext(semantics: Semantics = DefaultSemantics(), config: EngineConfig = EngineConfig()) +case class EngineContext( + semantics: Semantics = DefaultSemantics(), + config: EngineConfig = EngineConfig() +) /** Various configurations for the data flow engine. * @param maxCallDepth @@ -328,34 +324,33 @@ case class EngineConfig( /** Tracks various performance characteristics of the query engine. */ -object QueryEngineStatistics extends Enumeration { +object QueryEngineStatistics extends Enumeration: - type QueryEngineStatistic = Value + type QueryEngineStatistic = Value - val PATH_CACHE_HITS, PATH_CACHE_MISSES = Value + val PATH_CACHE_HITS, PATH_CACHE_MISSES = Value - private val statistics = new ConcurrentHashMap[QueryEngineStatistic, Long]() + private val statistics = new ConcurrentHashMap[QueryEngineStatistic, Long]() - reset() - - /** Adds the given value to the associated value to the given [[QueryEngineStatistics]] key. - * @param key - * the key associated with the value to transform. - * @param value - * the value to add to the statistic. Can be negative. - */ - def incrementBy(key: QueryEngineStatistic, value: Long): Unit = - statistics.put(key, statistics.getOrDefault(key, 0L) + value) + reset() - /** The results of the measured statistics. - * @return - * a map of each [[QueryEngineStatistic]] and the associated value measurement. 
- */ - def results(): Map[QueryEngineStatistic, Long] = statistics.asScala.toMap + /** Adds the given value to the associated value to the given [[QueryEngineStatistics]] key. + * @param key + * the key associated with the value to transform. + * @param value + * the value to add to the statistic. Can be negative. + */ + def incrementBy(key: QueryEngineStatistic, value: Long): Unit = + statistics.put(key, statistics.getOrDefault(key, 0L) + value) - /** Sets all the tracked values back to 0. - */ - def reset(): Unit = - QueryEngineStatistics.values.map((_, 0L)).foreach { case (v, t) => statistics.put(v, t) } + /** The results of the measured statistics. + * @return + * a map of each [[QueryEngineStatistic]] and the associated value measurement. + */ + def results(): Map[QueryEngineStatistic, Long] = statistics.asScala.toMap -} + /** Sets all the tracked values back to 0. + */ + def reset(): Unit = + QueryEngineStatistics.values.map((_, 0L)).foreach { case (v, t) => statistics.put(v, t) } +end QueryEngineStatistics diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala index 6bf2b1f2..0aa31823 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala @@ -1,171 +1,181 @@ package io.appthreat.dataflowengineoss.queryengine import scala.collection.mutable -import scala.collection.parallel.CollectionConverters._ +import scala.collection.parallel.CollectionConverters.* /** Complete held tasks using the result table. The result table is modified in the process. 
* * Results obtained when completing a held task depend on the following: * - * (a) the `initialPath` of the held task (path from the node where the task was held down to a sink) + * (a) the `initialPath` of the held task (path from the node where the task was held down to a + * sink) * * (b) the entries in the table for `heldTask.fingerprint`. * - * Upon completing a task, new results are stored in the table for each task of its `taskStack`. This means that we may - * not end up with all results when first completing a task because another task needs to be completed first so that - * all results for `heldTask.fingerprint` are available. We address this problem by computing results in a loop until - * no more changes can be observed. + * Upon completing a task, new results are stored in the table for each task of its `taskStack`. + * This means that we may not end up with all results when first completing a task because another + * task needs to be completed first so that all results for `heldTask.fingerprint` are available. + * We address this problem by computing results in a loop until no more changes can be observed. */ class HeldTaskCompletion( heldTasks: List[ReachableByTask], resultTable: mutable.Map[TaskFingerprint, List[TableEntry]] -) { - - /** Add results produced by held task until no more change can be observed. - * - * We use the following simple algorithm (that can possibly be optimized in the future): - * - * For each held `task`, we keep a Boolean `changed(task)`, which indicates whether new results for the `task` were - * produced. We initialize the Booleans to be true. Computation is terminated when all Booleans are false, that is, - * when no more changes in the result table can be observed. - * - * If we do detect a change, we determine all tasks for which changed results exist and recompute their results. We - * compare the results with those produced previously (stored in `resultsProducedByTask`). 
If any new results were - * created, `changed` is set to true for the result's table entry and `resultsProductByTask` is updated. - */ - def completeHeldTasks(): Unit = { - - deduplicateResultTable() - val toProcess = - heldTasks.distinct.sortBy(x => - (x.fingerprint.sink.id, x.fingerprint.callSiteStack.map(_.id).toString, x.callDepth) - ) - var resultsProducedByTask: Map[ReachableByTask, Set[(TaskFingerprint, TableEntry)]] = Map() - - def allChanged = toProcess.map { task => task.fingerprint -> true }.toMap - def noneChanged = toProcess.map { t => t.fingerprint -> false }.toMap - - var changed: Map[TaskFingerprint, Boolean] = allChanged - - while (changed.values.toList.contains(true)) { - val taskResultsPairs = toProcess - .filter(t => changed(t.fingerprint)) - .par - .map { t => - val resultsForTask = resultsForHeldTask(t).toSet - val newResults = resultsForTask -- resultsProducedByTask.getOrElse(t, Set()) - (t, resultsForTask, newResults) - } - .filter { case (_, _, newResults) => newResults.nonEmpty } - .seq - - changed = noneChanged - taskResultsPairs.foreach { case (t, resultsForTask, newResults) => - addCompletedTasksToMainTable(newResults.toList) - newResults.foreach { case (fingerprint, _) => - changed += fingerprint -> true +): + + /** Add results produced by held task until no more change can be observed. + * + * We use the following simple algorithm (that can possibly be optimized in the future): + * + * For each held `task`, we keep a Boolean `changed(task)`, which indicates whether new results + * for the `task` were produced. We initialize the Booleans to be true. Computation is + * terminated when all Booleans are false, that is, when no more changes in the result table + * can be observed. + * + * If we do detect a change, we determine all tasks for which changed results exist and + * recompute their results. We compare the results with those produced previously (stored in + * `resultsProducedByTask`). 
If any new results were created, `changed` is set to true for the + * result's table entry and `resultsProductByTask` is updated. + */ + def completeHeldTasks(): Unit = + + deduplicateResultTable() + val toProcess = + heldTasks.distinct.sortBy(x => + (x.fingerprint.sink.id, x.fingerprint.callSiteStack.map(_.id).toString, x.callDepth) + ) + var resultsProducedByTask: Map[ReachableByTask, Set[(TaskFingerprint, TableEntry)]] = Map() + + def allChanged = toProcess.map { task => task.fingerprint -> true }.toMap + def noneChanged = toProcess.map { t => t.fingerprint -> false }.toMap + + var changed: Map[TaskFingerprint, Boolean] = allChanged + + while changed.values.toList.contains(true) do + val taskResultsPairs = toProcess + .filter(t => changed(t.fingerprint)) + .par + .map { t => + val resultsForTask = resultsForHeldTask(t).toSet + val newResults = resultsForTask -- resultsProducedByTask.getOrElse(t, Set()) + (t, resultsForTask, newResults) + } + .filter { case (_, _, newResults) => newResults.nonEmpty } + .seq + + changed = noneChanged + taskResultsPairs.foreach { case (t, resultsForTask, newResults) => + addCompletedTasksToMainTable(newResults.toList) + newResults.foreach { case (fingerprint, _) => + changed += fingerprint -> true + } + resultsProducedByTask += (t -> resultsForTask) + } + end while + deduplicateResultTable() + end completeHeldTasks + + /** In essence, completing a held task simply means appending the path stored in the held task + * to all results that are available for the held task in the table. In practice, we create one + * result for each task of the parent task's `taskStack`, so that we do not only get a new + * result for the sink, but one for each of the parent nodes on the way. + */ + private def resultsForHeldTask(heldTask: ReachableByTask): List[(TaskFingerprint, TableEntry)] = + // Create a flat list of results by computing results for each + // table entry and appending them. 
+ resultTable.get(heldTask.fingerprint) match + case Some(results) => + results + .flatMap { r => + createResultsForHeldTaskAndTableResult(heldTask, r) + } + case None => List() + + /** This method creates a list of results from a held task and a table entry by appending paths + * of the held task to the path stored in the held task (`initialPath`) up to each of its + * parent tasks. + * + * A possible optimization here is to store computed slices in a lazily populated table and + * attempt to look them up. + */ + private def createResultsForHeldTaskAndTableResult( + heldTask: ReachableByTask, + result: TableEntry + ): List[(TaskFingerprint, TableEntry)] = + val parentTasks = heldTask.taskStack.dropRight(1) + val initialPath = heldTask.initialPath + parentTasks + .map { parentTask => + val stopIndex = initialPath + .map(x => (x.node, x.callSiteStack)) + .indexOf((parentTask.sink, parentTask.callSiteStack)) + 1 + val initialPathOnlyUpToSink = initialPath.slice(0, stopIndex) + val newPath = result.path ++ initialPathOnlyUpToSink + (parentTask, TableEntry(newPath)) + } + .filter { case (_, tableEntry) => containsCycle(tableEntry) } + end createResultsForHeldTaskAndTableResult + + private def containsCycle(tableEntry: TableEntry): Boolean = + val pathSeq = + tableEntry.path.map(x => (x.node, x.callSiteStack, x.isOutputArg, x.outEdgeLabel)) + pathSeq.distinct.size == pathSeq.size + + private def addCompletedTasksToMainTable(results: List[(TaskFingerprint, TableEntry)]): Unit = + results.groupBy(_._1).foreach { case (fingerprint, resultList) => + val entries = resultList.map(_._2) + val old = resultTable.getOrElse(fingerprint, Vector()).toList + resultTable.put(fingerprint, deduplicateTableEntries(old ++ entries)) } - resultsProducedByTask += (t -> resultsForTask) - } - } - deduplicateResultTable() - } - - /** In essence, completing a held task simply means appending the path stored in the held task to all results that are - * available for the held task in the table. 
In practice, we create one result for each task of the parent task's - * `taskStack`, so that we do not only get a new result for the sink, but one for each of the parent nodes on the - * way. - */ - private def resultsForHeldTask(heldTask: ReachableByTask): List[(TaskFingerprint, TableEntry)] = { - // Create a flat list of results by computing results for each - // table entry and appending them. - resultTable.get(heldTask.fingerprint) match { - case Some(results) => - results - .flatMap { r => - createResultsForHeldTaskAndTableResult(heldTask, r) - } - case None => List() - } - } - - /** This method creates a list of results from a held task and a table entry by appending paths of the held task to - * the path stored in the held task (`initialPath`) up to each of its parent tasks. - * - * A possible optimization here is to store computed slices in a lazily populated table and attempt to look them up. - */ - private def createResultsForHeldTaskAndTableResult( - heldTask: ReachableByTask, - result: TableEntry - ): List[(TaskFingerprint, TableEntry)] = { - val parentTasks = heldTask.taskStack.dropRight(1) - val initialPath = heldTask.initialPath - parentTasks - .map { parentTask => - val stopIndex = initialPath - .map(x => (x.node, x.callSiteStack)) - .indexOf((parentTask.sink, parentTask.callSiteStack)) + 1 - val initialPathOnlyUpToSink = initialPath.slice(0, stopIndex) - val newPath = result.path ++ initialPathOnlyUpToSink - (parentTask, TableEntry(newPath)) - } - .filter { case (_, tableEntry) => containsCycle(tableEntry) } - } - - private def containsCycle(tableEntry: TableEntry): Boolean = { - val pathSeq = tableEntry.path.map(x => (x.node, x.callSiteStack, x.isOutputArg, x.outEdgeLabel)) - pathSeq.distinct.size == pathSeq.size - } - - private def addCompletedTasksToMainTable(results: List[(TaskFingerprint, TableEntry)]): Unit = { - results.groupBy(_._1).foreach { case (fingerprint, resultList) => - val entries = resultList.map(_._2) - val old = 
resultTable.getOrElse(fingerprint, Vector()).toList - resultTable.put(fingerprint, deduplicateTableEntries(old ++ entries)) - } - } - - private def deduplicateResultTable(): Unit = { - resultTable.keys.foreach { key => - val results = resultTable(key) - resultTable.put(key, deduplicateTableEntries(results)) - } - } - - /** This method deduplicates the list of entries stored in a table cell. - * - * We treat entries as the same if their start and end point are the same. Points are given by nodes in the graph, - * the `callSiteStack` and the `isOutputArg` flag. - * - * For a group of flows that we treat as the same, we select the flow with the maximum length. If there are multiple - * flows with maximum length, then we compute a string representation of the flows - taking into account all fields - * - and select the flow with maximum length that is smallest in terms of this string representation. - */ - private def deduplicateTableEntries(list: List[TableEntry]): List[TableEntry] = { - list - .groupBy { result => - val head = result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - val last = result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - (head, last) - } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)) - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match { - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - }).map(_._2) - if (withMaxLength.length == 1) { - withMaxLength.head - } else { - withMaxLength.minBy { x => - x.path - .map(x => (x.node.id, x.callSiteStack.map(_.id), x.visible, x.isOutputArg, x.outEdgeLabel).toString) - .mkString("-") - } + private def deduplicateResultTable(): Unit = + resultTable.keys.foreach { key => + val results = resultTable(key) + resultTable.put(key, deduplicateTableEntries(results)) } - } - .toList - } -} + /** This method deduplicates the list of entries stored in a table cell. 
+ * + * We treat entries as the same if their start and end point are the same. Points are given by + * nodes in the graph, the `callSiteStack` and the `isOutputArg` flag. + * + * For a group of flows that we treat as the same, we select the flow with the maximum length. + * If there are multiple flows with maximum length, then we compute a string representation of + * the flows - taking into account all fields + * - and select the flow with maximum length that is smallest in terms of this string + * representation. + */ + private def deduplicateTableEntries(list: List[TableEntry]): List[TableEntry] = + list + .groupBy { result => + val head = + result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + val last = + result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + (head, last) + } + .map { case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)) + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) + + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + } + .toList +end HeldTaskCompletion diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala index cc1c74dd..2b87b7e8 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala @@ -4,8 +4,8 @@ import io.appthreat.dataflowengineoss.globalFromLiteral import io.appthreat.x2cpg.Defines import io.shiftleft.codepropertygraph.Cpg 
import io.shiftleft.codepropertygraph.generated.Operators -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.allAssignmentTypes import io.shiftleft.semanticcpg.utils.MemberAccess.isFieldAccess import org.slf4j.LoggerFactory @@ -15,205 +15,212 @@ import scala.util.{Failure, Success, Try} case class StartingPointWithSource(startingPoint: CfgNode, source: StoredNode) -object SourcesToStartingPoints { +object SourcesToStartingPoints: - private val log = LoggerFactory.getLogger(SourcesToStartingPoints.getClass) + private val log = LoggerFactory.getLogger(SourcesToStartingPoints.getClass) - def sourceTravsToStartingPoints[NodeType](sourceTravs: IterableOnce[NodeType]*): List[StartingPointWithSource] = { - val fjp = new ForkJoinPool(Runtime.getRuntime.availableProcessors() / 2) - try { - fjp.invoke(new SourceTravsToStartingPointsTask(sourceTravs: _*)).distinct - } catch { - case e: RejectedExecutionException => - log.error("Unable to execute 'SourceTravsToStartingPoints` task", e); List() - } finally { - fjp.shutdown() - } - } - -} + def sourceTravsToStartingPoints[NodeType](sourceTravs: IterableOnce[NodeType]*) + : List[StartingPointWithSource] = + val fjp = new ForkJoinPool(Runtime.getRuntime.availableProcessors() / 2) + try + fjp.invoke(new SourceTravsToStartingPointsTask(sourceTravs*)).distinct + catch + case e: RejectedExecutionException => + log.error("Unable to execute 'SourceTravsToStartingPoints` task", e); List() + finally + fjp.shutdown() class SourceTravsToStartingPointsTask[NodeType](sourceTravs: IterableOnce[NodeType]*) - extends RecursiveTask[List[StartingPointWithSource]] { - - private val log = LoggerFactory.getLogger(this.getClass) - - override def compute(): List[StartingPointWithSource] = { - val sources: List[StoredNode] = sourceTravs - 
.flatMap(_.iterator.toList) - .collect { case n: StoredNode => n } - .dedup - .toList - .sortBy(_.id) - val tasks = sources.map(src => (src, new SourceToStartingPoints(src).fork())) - tasks.flatMap { case (src, t: ForkJoinTask[List[CfgNode]]) => - Try(t.get()) match { - case Failure(e) => log.error("Unable to complete 'SourceToStartingPoints' task", e); List() - case Success(sources) => sources.map(s => StartingPointWithSource(s, src)) - } - } - } -} - -/** The code below deals with member variables, and specifically with the situation where literals that initialize - * static members are passed to `reachableBy` as sources. In this case, we determine the first usages of this member in - * each method, traversing the AST from left to right. This isn't fool-proof, e.g., goto-statements would be - * problematic, but it works quite well in practice. + extends RecursiveTask[List[StartingPointWithSource]]: + + private val log = LoggerFactory.getLogger(this.getClass) + + override def compute(): List[StartingPointWithSource] = + val sources: List[StoredNode] = sourceTravs + .flatMap(_.iterator.toList) + .collect { case n: StoredNode => n } + .dedup + .toList + .sortBy(_.id) + val tasks = sources.map(src => (src, new SourceToStartingPoints(src).fork())) + tasks.flatMap { case (src, t: ForkJoinTask[List[CfgNode]]) => + Try(t.get()) match + case Failure(e) => + log.error("Unable to complete 'SourceToStartingPoints' task", e); List() + case Success(sources) => sources.map(s => StartingPointWithSource(s, src)) + } +end SourceTravsToStartingPointsTask + +/** The code below deals with member variables, and specifically with the situation where literals + * that initialize static members are passed to `reachableBy` as sources. In this case, we + * determine the first usages of this member in each method, traversing the AST from left to right. + * This isn't fool-proof, e.g., goto-statements would be problematic, but it works quite well in + * practice. 
*/ -class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode]] { - - private val cpg = Cpg(src.graph()) - - override def compute(): List[CfgNode] = sourceToStartingPoints(src) - - private def sourceToStartingPoints(src: StoredNode): List[CfgNode] = { - src match { - case methodReturn: MethodReturn => - methodReturn.method.callIn.l - case lit: Literal => - List(lit) ++ usages(targetsToClassIdentifierPair(literalToInitializedMembers(lit))) ++ globalFromLiteral(lit) - case member: Member => - usages(targetsToClassIdentifierPair(List(member))) - case x: Declaration => - List(x).collectAll[CfgNode].toList - case x: Identifier => - (withFieldAndIndexAccesses( - List(x).collectAll[CfgNode].toList ++ x.refsTo.collectAll[Local].flatMap(sourceToStartingPoints) - ) ++ x.refsTo.capturedByMethodRef.referencedMethod.flatMap(m => usagesForName(x.name, m))).flatMap { - case x: Call => sourceToStartingPoints(x) - case x => List(x) +class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode]]: + + private val cpg = Cpg(src.graph()) + + override def compute(): List[CfgNode] = sourceToStartingPoints(src) + + private def sourceToStartingPoints(src: StoredNode): List[CfgNode] = + src match + case methodReturn: MethodReturn => + methodReturn.method.callIn.l + case lit: Literal => + List(lit) ++ usages( + targetsToClassIdentifierPair(literalToInitializedMembers(lit)) + ) ++ globalFromLiteral(lit) + case member: Member => + usages(targetsToClassIdentifierPair(List(member))) + case x: Declaration => + List(x).collectAll[CfgNode].toList + case x: Identifier => + (withFieldAndIndexAccesses( + List(x).collectAll[CfgNode].toList ++ x.refsTo.collectAll[Local].flatMap( + sourceToStartingPoints + ) + ) ++ x.refsTo.capturedByMethodRef.referencedMethod.flatMap(m => + usagesForName(x.name, m) + )).flatMap { + case x: Call => sourceToStartingPoints(x) + case x => List(x) + } + case x: Call => + (x._receiverIn.l :+ x).collect { case y: CfgNode => y } + case 
x => List(x).collect { case y: CfgNode => y } + + private def withFieldAndIndexAccesses(nodes: List[CfgNode]): List[CfgNode] = + nodes.flatMap { + case identifier: Identifier => + List(identifier) ++ fieldAndIndexAccesses(identifier) + case x => List(x) } - case x: Call => - (x._receiverIn.l :+ x).collect { case y: CfgNode => y } - case x => List(x).collect { case y: CfgNode => y } - } - } - - private def withFieldAndIndexAccesses(nodes: List[CfgNode]): List[CfgNode] = - nodes.flatMap { - case identifier: Identifier => - List(identifier) ++ fieldAndIndexAccesses(identifier) - case x => List(x) - } - - private def fieldAndIndexAccesses(identifier: Identifier): List[CfgNode] = - identifier.method._identifierViaContainsOut - .nameExact(identifier.name) - .inCall - .collect { case c if isFieldAccess(c.name) => c } - .l - - private def usages(pairs: List[(TypeDecl, AstNode)]): List[CfgNode] = { - pairs.flatMap { case (typeDecl, astNode) => - val nonConstructorMethods = methodsRecursively(typeDecl).iterator - .whereNot(_.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName, "__init__")) - .l - - val usagesInSameClass = - nonConstructorMethods.flatMap { m => firstUsagesOf(astNode, m, typeDecl) } - - val usagesInOtherClasses = cpg.method.flatMap { m => - m.fieldAccess - .where(_.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName)) - .where { x => - astNode match { - case identifier: Identifier => - x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name) - case fieldIdentifier: FieldIdentifier => - x.argument(2).isFieldIdentifier.canonicalNameExact(fieldIdentifier.canonicalName) - case _ => Iterator.empty - } - } - .takeWhile(notLeftHandOfAssignment) - .headOption - }.l - usagesInSameClass ++ usagesInOtherClasses - } - } - - /** For given method, determine the first usage of the given expression. 
- */ - private def firstUsagesOf(astNode: AstNode, m: Method, typeDecl: TypeDecl): List[Expression] = { - astNode match { - case member: Member => - usagesForName(member.name, m) - case identifier: Identifier => - usagesForName(identifier.name, m) - case fieldIdentifier: FieldIdentifier => + + private def fieldAndIndexAccesses(identifier: Identifier): List[CfgNode] = + identifier.method._identifierViaContainsOut + .nameExact(identifier.name) + .inCall + .collect { case c if isFieldAccess(c.name) => c } + .l + + private def usages(pairs: List[(TypeDecl, AstNode)]): List[CfgNode] = + pairs.flatMap { case (typeDecl, astNode) => + val nonConstructorMethods = methodsRecursively(typeDecl).iterator + .whereNot(_.nameExact( + Defines.StaticInitMethodName, + Defines.ConstructorMethodName, + "__init__" + )) + .l + + val usagesInSameClass = + nonConstructorMethods.flatMap { m => firstUsagesOf(astNode, m, typeDecl) } + + val usagesInOtherClasses = cpg.method.flatMap { m => + m.fieldAccess + .where(_.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName)) + .where { x => + astNode match + case identifier: Identifier => + x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name) + case fieldIdentifier: FieldIdentifier => + x.argument(2).isFieldIdentifier.canonicalNameExact( + fieldIdentifier.canonicalName + ) + case _ => Iterator.empty + } + .takeWhile(notLeftHandOfAssignment) + .headOption + }.l + usagesInSameClass ++ usagesInOtherClasses + } + + /** For given method, determine the first usage of the given expression. 
+ */ + private def firstUsagesOf(astNode: AstNode, m: Method, typeDecl: TypeDecl): List[Expression] = + astNode match + case member: Member => + usagesForName(member.name, m) + case identifier: Identifier => + usagesForName(identifier.name, m) + case fieldIdentifier: FieldIdentifier => + val fieldIdentifiers = + m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l + fieldIdentifiers + .canonicalNameExact(fieldIdentifier.canonicalName) + .inFieldAccess + // TODO `isIdentifier` seems to limit us here + .where(_.argument(1).isIdentifier.or( + _.nameExact("this", "self"), + _.typeFullNameExact(typeDecl.fullName) + )) + .takeWhile(notLeftHandOfAssignment) + .l + case _ => List() + + private def usagesForName(name: String, m: Method): List[Expression] = + val identifiers = m.ast.isIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l + val identifierUsages = identifiers.nameExact(name).takeWhile(notLeftHandOfAssignment).l val fieldIdentifiers = m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - fieldIdentifiers - .canonicalNameExact(fieldIdentifier.canonicalName) - .inFieldAccess - // TODO `isIdentifier` seems to limit us here - .where(_.argument(1).isIdentifier.or(_.nameExact("this", "self"), _.typeFullNameExact(typeDecl.fullName))) - .takeWhile(notLeftHandOfAssignment) - .l - case _ => List() - } - } - - private def usagesForName(name: String, m: Method): List[Expression] = { - val identifiers = m.ast.isIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - val identifierUsages = identifiers.nameExact(name).takeWhile(notLeftHandOfAssignment).l - val fieldIdentifiers = m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - val thisRefs = Seq("this", "self") ++ m.typeDecl.name.headOption.toList - val fieldAccessUsages = fieldIdentifiers.isFieldIdentifier - .canonicalNameExact(name) - .inFieldAccess - .where(_.argument(1).codeExact(thisRefs: _*)) - .takeWhile(notLeftHandOfAssignment) - .l - (identifierUsages ++ 
fieldAccessUsages).headOption.toList - } - - /** For a literal, determine if it is used in the initialization of any member variables. Return list of initialized - * members. An initialized member is either an identifier or a field-identifier. - */ - private def literalToInitializedMembers(lit: Literal): List[Expression] = - lit.inAssignment - .or( - _.method.nameExact(Defines.StaticInitMethodName, Defines.ConstructorMethodName, "__init__"), - // in language such as Python, where assignments for members can be directly under a type decl - _.method.typeDecl - ) - .target - .flatMap { - case identifier: Identifier - // If these are the same, then the parent method is the module-level type - if Option(identifier.method.fullName) == identifier.method.typeDecl.fullName.headOption || - // If a member shares the name of the identifier then we consider this as a member - lit.method.typeDecl.member.name.toSet.contains(identifier.name) => - List(identifier) - case call: Call if call.name == Operators.fieldAccess => call.ast.isFieldIdentifier.l - case _ => List[Expression]() - } - .l - - private def methodsRecursively(typeDecl: TypeDecl): List[Method] = { - def methods(x: AstNode): List[Method] = { - x match { - case m: Method => m :: m.astMinusRoot.isMethod.flatMap(methods).l - case _ => List() - } - } - typeDecl.method.flatMap(methods).l - } - - private def isTargetInAssignment(identifier: Identifier): List[Identifier] = { - identifier.start.argumentIndex(1).where(_.inAssignment).l - } - - private def notLeftHandOfAssignment(x: Expression): Boolean = { - !(x.argumentIndex == 1 && x.inCall.exists(y => allAssignmentTypes.contains(y.name))) - } - - private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] = { - targets.flatMap { - case expr: Expression => - expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } - case member: Member => - member.typeDecl.map { typeDecl => (typeDecl, member) } - } - } - -} + val thisRefs = Seq("this", "self") 
++ m.typeDecl.name.headOption.toList + val fieldAccessUsages = fieldIdentifiers.isFieldIdentifier + .canonicalNameExact(name) + .inFieldAccess + .where(_.argument(1).codeExact(thisRefs*)) + .takeWhile(notLeftHandOfAssignment) + .l + (identifierUsages ++ fieldAccessUsages).headOption.toList + + /** For a literal, determine if it is used in the initialization of any member variables. Return + * list of initialized members. An initialized member is either an identifier or a + * field-identifier. + */ + private def literalToInitializedMembers(lit: Literal): List[Expression] = + lit.inAssignment + .or( + _.method.nameExact( + Defines.StaticInitMethodName, + Defines.ConstructorMethodName, + "__init__" + ), + // in language such as Python, where assignments for members can be directly under a type decl + _.method.typeDecl + ) + .target + .flatMap { + case identifier: Identifier + // If these are the same, then the parent method is the module-level type + if Option( + identifier.method.fullName + ) == identifier.method.typeDecl.fullName.headOption || + // If a member shares the name of the identifier then we consider this as a member + lit.method.typeDecl.member.name.toSet.contains(identifier.name) => + List(identifier) + case call: Call if call.name == Operators.fieldAccess => + call.ast.isFieldIdentifier.l + case _ => List[Expression]() + } + .l + + private def methodsRecursively(typeDecl: TypeDecl): List[Method] = + def methods(x: AstNode): List[Method] = + x match + case m: Method => m :: m.astMinusRoot.isMethod.flatMap(methods).l + case _ => List() + typeDecl.method.flatMap(methods).l + + private def isTargetInAssignment(identifier: Identifier): List[Identifier] = + identifier.start.argumentIndex(1).where(_.inAssignment).l + + private def notLeftHandOfAssignment(x: Expression): Boolean = + !(x.argumentIndex == 1 && x.inCall.exists(y => allAssignmentTypes.contains(y.name))) + + private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] 
= + targets.flatMap { + case expr: Expression => + expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } + case member: Member => + member.typeDecl.map { typeDecl => (typeDecl, member) } + } +end SourceToStartingPoints diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala index 7d0411d5..7902fe1e 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala @@ -3,13 +3,13 @@ package io.appthreat.dataflowengineoss.queryengine import io.appthreat.dataflowengineoss.queryengine.Engine.argToOutputParams import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{ - Call, - Expression, - Method, - MethodParameterIn, - MethodParameterOut, - MethodRef, - Return + Call, + Expression, + Method, + MethodParameterIn, + MethodParameterOut, + MethodRef, + Return } import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.NoResolve @@ -17,174 +17,200 @@ import org.slf4j.{Logger, LoggerFactory} /** Creation of new tasks from results of completed tasks. */ -class TaskCreator(context: EngineContext) { - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - /** For a given list of results and sources, generate new tasks. 
- */ - def createFromResults(results: Vector[ReachableByResult]): Vector[ReachableByTask] = { - val newTasks = tasksForParams(results) ++ tasksForUnresolvedOutArgs(results) - removeTasksWithLoopsAndTooHighCallDepth(newTasks) - } - - private def removeTasksWithLoopsAndTooHighCallDepth(tasks: Vector[ReachableByTask]): Vector[ReachableByTask] = { - val tasksWithValidCallDepth = if (context.config.maxCallDepth == -1) { - tasks - } else { - tasks.filter(_.callDepth <= context.config.maxCallDepth) - } - tasksWithValidCallDepth.filter { t => - t.taskStack.dedup.size == t.taskStack.size - } - } - - /** Create new tasks from all results that start in a parameter. In essence, we want to traverse to corresponding - * arguments of call sites, but we need to be careful here not to create unrealizable paths. We achieve this by - * holding a call stack in results. - * - * Case 1: we expanded into a callee that we identified on the way, e.g., a method `y = transform(x)`, and we have - * reached the parameter of that method (`transform`). Upon doing so, we recorded the call site that we expanded in - * `result.callSite`. We would now like to continue exploring from the corresponding argument at that call site only. - * - * Case 2: walking backward from the sink, we have only expanded into callers so far, that is, the call stack is - * empty. In this case, the next tasks need to explore each call site to the method. 
- */ - private def tasksForParams(results: Vector[ReachableByResult]): Vector[ReachableByTask] = { - startsAtParameter(results).flatMap { result => - val param = result.path.head.node.asInstanceOf[MethodParameterIn] - result.callSiteStack match { - case callSite :: tail => - // Case 1 - paramToArgs(param).filter(x => x.inCall.exists(c => c == callSite)).map { arg => - ReachableByTask(result.taskStack :+ TaskFingerprint(arg, tail, result.callDepth - 1), result.path) - } - case _ => - // Case 2 - paramToArgs(param).map { arg => - ReachableByTask(result.taskStack :+ TaskFingerprint(arg, List(), result.callDepth + 1), result.path) - } - } - } - } - - /** Returns only those results that start at a parameter node. - */ - private def startsAtParameter(results: Vector[ReachableByResult]) = { - results.collect { case r: ReachableByResult if r.path.head.node.isInstanceOf[MethodParameterIn] => r } - } - - /** For a given parameter of a method, determine all corresponding arguments at all call sites to the method. - */ - private def paramToArgs(param: MethodParameterIn): List[Expression] = { - val args = paramToArgsOfCallers(param) ++ paramToMethodRefCallReceivers(param) - if (args.size > context.config.maxArgsToAllow) { - logger.warn(s"Too many arguments for parameter: ${args.size}. Not expanding") - logger.warn("Method name: " + param.method.fullName) - List() - } else { - args - } - } - - private def paramToArgsOfCallers(param: MethodParameterIn): List[Expression] = - NoResolve - .getMethodCallsites(param.method) - .collectAll[Call] - .argument(param.index) - .l - - /** Expand to receiver objects of calls that reference the method of the parameter, e.g., if `param` is a parameter of - * `m`, return `foo` in `foo.bar(m)` TODO: I'm not sure whether `methodRef.methodFullNameExact(...)` uses an index. - * If not, then caching these lookups or keeping a map of all method names to their references may make sense. 
- */ - - private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = - new Cpg(param.graph()).methodRef.methodFullNameExact(param.method.fullName).inCall.argument(0).l - - /** Create new tasks from all results that end in an output argument, including return arguments. In this case, we - * want to traverse to corresponding method output parameters and method return nodes respectively. - */ - private def tasksForUnresolvedOutArgs(results: Vector[ReachableByResult]): Vector[ReachableByTask] = { - - val outArgsAndCalls = results - .map(x => (x, x.outputArgument, x.path, x.callDepth)) - .distinct - - val forCalls = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => - outArg.toList.flatMap { - case call: Call => - val methodReturns = call.toList - .flatMap(x => NoResolve.getCalledMethods(x).methodReturn.map(y => (x, y))) - .iterator - - methodReturns.flatMap { case (call, methodReturn) => - val method = methodReturn.method - val returnStatements = methodReturn._reachingDefIn.toList.collect { case r: Return => r } - if (method.isExternal || method.start.isStub.nonEmpty) { - val newPath = path - (call.receiver.l ++ call.argument.l).map { arg => - val taskStack = result.taskStack :+ TaskFingerprint(arg, result.callSiteStack, callDepth) - ReachableByTask(taskStack, newPath) - } - } else { - returnStatements.map { returnStatement => - val newPath = Vector(PathElement(methodReturn, result.callSiteStack)) ++ path - val taskStack = - result.taskStack :+ TaskFingerprint(returnStatement, call :: result.callSiteStack, callDepth + 1) - ReachableByTask(taskStack, newPath) - } +class TaskCreator(context: EngineContext): + + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + /** For a given list of results and sources, generate new tasks. 
+ */ + def createFromResults(results: Vector[ReachableByResult]): Vector[ReachableByTask] = + val newTasks = tasksForParams(results) ++ tasksForUnresolvedOutArgs(results) + removeTasksWithLoopsAndTooHighCallDepth(newTasks) + + private def removeTasksWithLoopsAndTooHighCallDepth(tasks: Vector[ReachableByTask]) + : Vector[ReachableByTask] = + val tasksWithValidCallDepth = if context.config.maxCallDepth == -1 then + tasks + else + tasks.filter(_.callDepth <= context.config.maxCallDepth) + tasksWithValidCallDepth.filter { t => + t.taskStack.dedup.size == t.taskStack.size + } + + /** Create new tasks from all results that start in a parameter. In essence, we want to traverse + * to corresponding arguments of call sites, but we need to be careful here not to create + * unrealizable paths. We achieve this by holding a call stack in results. + * + * Case 1: we expanded into a callee that we identified on the way, e.g., a method `y = + * transform(x)`, and we have reached the parameter of that method (`transform`). Upon doing + * so, we recorded the call site that we expanded in `result.callSite`. We would now like to + * continue exploring from the corresponding argument at that call site only. + * + * Case 2: walking backward from the sink, we have only expanded into callers so far, that is, + * the call stack is empty. In this case, the next tasks need to explore each call site to the + * method. 
+ */ + private def tasksForParams(results: Vector[ReachableByResult]): Vector[ReachableByTask] = + startsAtParameter(results).flatMap { result => + val param = result.path.head.node.asInstanceOf[MethodParameterIn] + result.callSiteStack match + case callSite :: tail => + // Case 1 + paramToArgs(param).filter(x => x.inCall.exists(c => c == callSite)).map { arg => + ReachableByTask( + result.taskStack :+ TaskFingerprint(arg, tail, result.callDepth - 1), + result.path + ) + } + case _ => + // Case 2 + paramToArgs(param).map { arg => + ReachableByTask( + result.taskStack :+ TaskFingerprint(arg, List(), result.callDepth + 1), + result.path + ) + } + } + + /** Returns only those results that start at a parameter node. + */ + private def startsAtParameter(results: Vector[ReachableByResult]) = + results.collect { + case r: ReachableByResult if r.path.head.node.isInstanceOf[MethodParameterIn] => r + } + + /** For a given parameter of a method, determine all corresponding arguments at all call sites + * to the method. + */ + private def paramToArgs(param: MethodParameterIn): List[Expression] = + val args = paramToArgsOfCallers(param) ++ paramToMethodRefCallReceivers(param) + if args.size > context.config.maxArgsToAllow then + logger.warn(s"Too many arguments for parameter: ${args.size}. Not expanding") + logger.warn("Method name: " + param.method.fullName) + List() + else + args + + private def paramToArgsOfCallers(param: MethodParameterIn): List[Expression] = + NoResolve + .getMethodCallsites(param.method) + .collectAll[Call] + .argument(param.index) + .l + + /** Expand to receiver objects of calls that reference the method of the parameter, e.g., if + * `param` is a parameter of `m`, return `foo` in `foo.bar(m)` TODO: I'm not sure whether + * `methodRef.methodFullNameExact(...)` uses an index. If not, then caching these lookups or + * keeping a map of all method names to their references may make sense. 
+ */ + + private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = + new Cpg(param.graph()).methodRef.methodFullNameExact(param.method.fullName).inCall.argument( + 0 + ).l + + /** Create new tasks from all results that end in an output argument, including return + * arguments. In this case, we want to traverse to corresponding method output parameters and + * method return nodes respectively. + */ + private def tasksForUnresolvedOutArgs(results: Vector[ReachableByResult]) + : Vector[ReachableByTask] = + + val outArgsAndCalls = results + .map(x => (x, x.outputArgument, x.path, x.callDepth)) + .distinct + + val forCalls = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => + outArg.toList.flatMap { + case call: Call => + val methodReturns = call.toList + .flatMap(x => NoResolve.getCalledMethods(x).methodReturn.map(y => (x, y))) + .iterator + + methodReturns.flatMap { case (call, methodReturn) => + val method = methodReturn.method + val returnStatements = + methodReturn._reachingDefIn.toList.collect { case r: Return => r } + if method.isExternal || method.start.isStub.nonEmpty then + val newPath = path + (call.receiver.l ++ call.argument.l).map { arg => + val taskStack = result.taskStack :+ TaskFingerprint( + arg, + result.callSiteStack, + callDepth + ) + ReachableByTask(taskStack, newPath) + } + else + returnStatements.map { returnStatement => + val newPath = + Vector(PathElement(methodReturn, result.callSiteStack)) ++ path + val taskStack = + result.taskStack :+ TaskFingerprint( + returnStatement, + call :: result.callSiteStack, + callDepth + 1 + ) + ReachableByTask(taskStack, newPath) + } + end if + } + case _ => Vector.empty } - } - case _ => Vector.empty - } - } - - val forArgs = outArgsAndCalls.flatMap { case (result, args, path, callDepth) => - args.toList.flatMap { - case arg: Expression => - val outParams = if (result.callSiteStack.nonEmpty) { - List[MethodParameterOut]() - } else { - argToOutputParams(arg).l - } 
- outParams - .filterNot(_.method.isExternal) - .map { p => - val newStack = - arg.inCall.headOption.map { x => x :: result.callSiteStack }.getOrElse(result.callSiteStack) - ReachableByTask(result.taskStack :+ TaskFingerprint(p, newStack, callDepth + 1), path) + } + + val forArgs = outArgsAndCalls.flatMap { case (result, args, path, callDepth) => + args.toList.flatMap { + case arg: Expression => + val outParams = if result.callSiteStack.nonEmpty then + List[MethodParameterOut]() + else + argToOutputParams(arg).l + outParams + .filterNot(_.method.isExternal) + .map { p => + val newStack = + arg.inCall.headOption.map { x => + x :: result.callSiteStack + }.getOrElse(result.callSiteStack) + ReachableByTask( + result.taskStack :+ TaskFingerprint(p, newStack, callDepth + 1), + path + ) + } + case _ => Vector.empty } - case _ => Vector.empty - } - } - - val forMethodRefs = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => - outArg.toList.flatMap { - case methodRef: MethodRef => - val methodReturns = methodRef._refOut.collectAll[Method].methodReturn - methodReturns.flatMap { methodReturn => - val returnStatements = methodReturn._reachingDefIn.toList.collect { case r: Return => r } - returnStatements.map { returnStatement => - val newPath = Vector(PathElement(methodReturn, result.callSiteStack)) ++ path - val taskStack = - result.taskStack :+ TaskFingerprint(returnStatement, result.callSiteStack, callDepth + 1) - ReachableByTask(taskStack, newPath) + } + + val forMethodRefs = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => + outArg.toList.flatMap { + case methodRef: MethodRef => + val methodReturns = methodRef._refOut.collectAll[Method].methodReturn + methodReturns.flatMap { methodReturn => + val returnStatements = + methodReturn._reachingDefIn.toList.collect { case r: Return => r } + returnStatements.map { returnStatement => + val newPath = + Vector(PathElement(methodReturn, result.callSiteStack)) ++ path + val taskStack = + 
result.taskStack :+ TaskFingerprint( + returnStatement, + result.callSiteStack, + callDepth + 1 + ) + ReachableByTask(taskStack, newPath) + } + } + case _ => Vector.empty } - } - case _ => Vector.empty - } - } - restrictSize(forCalls) ++ restrictSize(forArgs) ++ restrictSize(forMethodRefs) - } - - private def restrictSize(l: Vector[ReachableByTask]): Vector[ReachableByTask] = { - if (l.size <= context.config.maxOutputArgsExpansion) { - l - } else { - logger.warn("Too many new tasks in expansion of unresolved output arguments") - Vector() - } - } - -} + } + restrictSize(forCalls) ++ restrictSize(forArgs) ++ restrictSize(forMethodRefs) + end tasksForUnresolvedOutArgs + + private def restrictSize(l: Vector[ReachableByTask]): Vector[ReachableByTask] = + if l.size <= context.config.maxOutputArgsExpansion then + l + else + logger.warn("Too many new tasks in expansion of unresolved output arguments") + Vector() +end TaskCreator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala index 28e8f412..f3b366c8 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala @@ -1,7 +1,10 @@ package io.appthreat.dataflowengineoss.queryengine import io.appthreat.dataflowengineoss.semanticsloader.Semantics -import io.appthreat.dataflowengineoss.queryengine.QueryEngineStatistics.{PATH_CACHE_HITS, PATH_CACHE_MISSES} +import io.appthreat.dataflowengineoss.queryengine.QueryEngineStatistics.{ + PATH_CACHE_HITS, + PATH_CACHE_MISSES +} import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.{toCfgNodeMethods, toExpressionMethods, *} @@ -11,8 +14,8 @@ import scala.collection.mutable /** Callable for 
solving a ReachableByTask * - * A Java Callable is "a task that returns a result and may throw an exception", and this is the callable for - * calculating the result for `task`. + * A Java Callable is "a task that returns a result and may throw an exception", and this is the + * callable for calculating the result for `task`. * * @param task * the data flow problem to solve @@ -21,195 +24,216 @@ import scala.collection.mutable * @param sources * the set of sources that we are looking to reach. */ -class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[CfgNode]) extends Callable[TaskSummary] { - - import Engine._ - - /** Entry point of callable. First checks if the maximum call depth has been exceeded, in which case an empty result - * list is returned. Otherwise, the task is solved and its results are returned. - */ - override def call(): TaskSummary = { - implicit val sem: Semantics = context.semantics - val path = Vector(PathElement(task.sink, task.callSiteStack)) - val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map() - results(task.sink, path, table, task.callSiteStack) - // TODO why do we update the call depth here? 
- val finalResults = table.get(task.fingerprint).get.map { r => - r.copy( - taskStack = r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth), - path = r.path ++ task.initialPath - ) - } - val (partial, complete) = finalResults.partition(_.partial) - val newTasks = new TaskCreator(context).createFromResults(partial) - TaskSummary(complete.flatMap(r => resultToTableEntries(r)), newTasks) - } - - private def resultToTableEntries(r: ReachableByResult): List[(TaskFingerprint, TableEntry)] = { - r.taskStack.indices.map { i => - val parentTask = r.taskStack(i) - val pathToSink = r.path.slice(0, r.path.map(_.node).indexOf(parentTask.sink)) - val newPath = pathToSink :+ PathElement(parentTask.sink, parentTask.callSiteStack) - (parentTask, TableEntry(path = newPath)) - }.toList - } - - /** Recursively expand the DDG backwards and return a list of all results, given by at least a source node in - * `sourceSymbols` and the path between the source symbol and the sink. - * - * This method stays within the method (intra-procedural analysis) and terminates at method parameters and at output - * arguments. - * - * @param path - * This is a path from a node to the sink. The first node of the path is expanded by this method - * - * @param table - * The result table is a cache of known results that we can re-use - * - * @param callSiteStack - * This stack holds all call sites we expanded to arrive at the generation of the current task - */ - private def results[NodeType <: CfgNode]( - sink: CfgNode, - path: Vector[PathElement], - table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], - callSiteStack: List[Call] - )(implicit semantics: Semantics): Vector[ReachableByResult] = { - - val curNode = path.head.node - - /** For each parent of the current node, determined via `expandIn`, check if results are available in the result - * table. If not, determine results recursively. 
+class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[CfgNode]) + extends Callable[TaskSummary]: + + import Engine.* + + /** Entry point of callable. First checks if the maximum call depth has been exceeded, in which + * case an empty result list is returned. Otherwise, the task is solved and its results are + * returned. */ - def computeResultsForParents() = { - deduplicateWithinTask(expandIn(curNode.asInstanceOf[CfgNode], path, callSiteStack).iterator.flatMap { parent => - createResultsFromCacheOrCompute(parent, path) - }.toVector) - } - - def deduplicateWithinTask(vec: Vector[ReachableByResult]): Vector[ReachableByResult] = { - vec - .groupBy { result => - val head = result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - val last = result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - (head, last, result.partial, result.callDepth) + override def call(): TaskSummary = + implicit val sem: Semantics = context.semantics + val path = Vector(PathElement(task.sink, task.callSiteStack)) + val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map() + results(task.sink, path, table, task.callSiteStack) + // TODO why do we update the call depth here? 
+ val finalResults = table.get(task.fingerprint).get.map { r => + r.copy( + taskStack = + r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth), + path = r.path ++ task.initialPath + ) } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)).toList - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match { - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - }).map(_._2) - - if (withMaxLength.length == 1) { - withMaxLength.head - } else { - withMaxLength.minBy { x => - x.callDepth.toString + " " + - x.taskStack - .map(x => x.sink.id.toString + ":" + x.callSiteStack.map(_.id).mkString("|")) - .toString + " " + x.path - .map(x => (x.node.id, x.callSiteStack.map(_.id), x.visible, x.isOutputArg, x.outEdgeLabel).toString) - .mkString("-") - } - } - } - .toVector - } - - def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = { - val cachedResult = createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) - if (cachedResult.isDefined) { - QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) - cachedResult.get - } else { - QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) - val newPath = elemToPrepend +: path - results(sink, newPath, table, callSiteStack) - } - } - - /** For a given path, determine whether results for the first element (`first`) are stored in the table, and if so, - * for each result, determine the path up to `first` and prepend it to `path`, giving us new results via table - * lookup. 
+ val (partial, complete) = finalResults.partition(_.partial) + val newTasks = new TaskCreator(context).createFromResults(partial) + TaskSummary(complete.flatMap(r => resultToTableEntries(r)), newTasks) + + private def resultToTableEntries(r: ReachableByResult): List[(TaskFingerprint, TableEntry)] = + r.taskStack.indices.map { i => + val parentTask = r.taskStack(i) + val pathToSink = r.path.slice(0, r.path.map(_.node).indexOf(parentTask.sink)) + val newPath = pathToSink :+ PathElement(parentTask.sink, parentTask.callSiteStack) + (parentTask, TableEntry(path = newPath)) + }.toList + + /** Recursively expand the DDG backwards and return a list of all results, given by at least a + * source node in `sourceSymbols` and the path between the source symbol and the sink. + * + * This method stays within the method (intra-procedural analysis) and terminates at method + * parameters and at output arguments. + * + * @param path + * This is a path from a node to the sink. The first node of the path is expanded by this + * method + * + * @param table + * The result table is a cache of known results that we can re-use + * + * @param callSiteStack + * This stack holds all call sites we expanded to arrive at the generation of the current + * task */ - def createFromTable( + private def results[NodeType <: CfgNode]( + sink: CfgNode, + path: Vector[PathElement], table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], - first: PathElement, - callSiteStack: List[Call], - remainder: Vector[PathElement], - callDepth: Int - ): Option[Vector[ReachableByResult]] = { - table.get(TaskFingerprint(first.node.asInstanceOf[CfgNode], callSiteStack, callDepth)).map { res => - res.map { r => - val stopIndex = r.path.map(x => (x.node, x.callSiteStack)).indexOf((first.node, first.callSiteStack)) - val pathToFirstNode = r.path.slice(0, stopIndex) - val completePath = pathToFirstNode ++ (first +: remainder) - r.copy(path = Vector(completePath.head) ++ completePath.tail) - } - } - } - - def 
createPartialResultForOutputArgOrRet() = { - Vector( - ReachableByResult( - task.taskStack, - PathElement(path.head.node, callSiteStack, isOutputArg = true) +: path.tail, - partial = true - ) - ) - } - - /** Determine results for the current node - */ - val res = curNode match { - // Case 1: we have reached a source => return result and continue traversing (expand into parents) - case x if sources.contains(x.asInstanceOf[NodeType]) => - if (x.isInstanceOf[MethodParameterIn]) { - Vector( - ReachableByResult(task.taskStack, path), - ReachableByResult(task.taskStack, path, partial = true) - ) ++ computeResultsForParents() - } else { - Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents() - } - // Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing - case _: MethodParameterIn => - Vector(ReachableByResult(task.taskStack, path, partial = true)) - // Case 3: we have reached a call to an internal method without semantic (return value) and - // this isn't the start node => return partial result and stop traversing - case call: Call - if isCallToInternalMethodWithoutSemantic(call) - && !isArgOrRetOfMethodWeCameFrom(call, path) => - createPartialResultForOutputArgOrRet() - - // Case 4: we have reached an argument to an internal method without semantic (output argument) and - // this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing - case arg: Expression - if path.size > 1 - && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) - && !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) => - createPartialResultForOutputArgOrRet() - - case _: MethodRef => createPartialResultForOutputArgOrRet() - - // All other cases: expand into parents - case _ => - computeResultsForParents() - } - val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth) - 
table.updateWith(key) { - case Some(existingValue) => Some(existingValue ++ res) - case None => Some(res) - } - res - } - - private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean = - path match { - case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => methodsForCall(call).contains(x.method) - case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => methodsForCall(call).contains(x.method) - case _ => false - } + callSiteStack: List[Call] + )(implicit semantics: Semantics): Vector[ReachableByResult] = + + val curNode = path.head.node + + /** For each parent of the current node, determined via `expandIn`, check if results are + * available in the result table. If not, determine results recursively. + */ + def computeResultsForParents() = + deduplicateWithinTask(expandIn( + curNode.asInstanceOf[CfgNode], + path, + callSiteStack + ).iterator.flatMap { parent => + createResultsFromCacheOrCompute(parent, path) + }.toVector) + + def deduplicateWithinTask(vec: Vector[ReachableByResult]): Vector[ReachableByResult] = + vec + .groupBy { result => + val head = result.path.headOption.map(x => + (x.node, x.callSiteStack, x.isOutputArg) + ).get + val last = result.path.lastOption.map(x => + (x.node, x.callSiteStack, x.isOutputArg) + ).get + (head, last, result.partial, result.callDepth) + } + .map { case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)).toList + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) + + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.callDepth.toString + " " + + x.taskStack + .map(x => + x.sink.id.toString + ":" + x.callSiteStack.map( + _.id + ).mkString("|") + ) + .toString + " " + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + end 
if + } + .toVector + + def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = + val cachedResult = + createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) + if cachedResult.isDefined then + QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) + cachedResult.get + else + QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) + val newPath = elemToPrepend +: path + results(sink, newPath, table, callSiteStack) + + /** For a given path, determine whether results for the first element (`first`) are stored + * in the table, and if so, for each result, determine the path up to `first` and prepend + * it to `path`, giving us new results via table lookup. + */ + def createFromTable( + table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], + first: PathElement, + callSiteStack: List[Call], + remainder: Vector[PathElement], + callDepth: Int + ): Option[Vector[ReachableByResult]] = + table.get( + TaskFingerprint(first.node.asInstanceOf[CfgNode], callSiteStack, callDepth) + ).map { res => + res.map { r => + val stopIndex = r.path.map(x => (x.node, x.callSiteStack)).indexOf(( + first.node, + first.callSiteStack + )) + val pathToFirstNode = r.path.slice(0, stopIndex) + val completePath = pathToFirstNode ++ (first +: remainder) + r.copy(path = Vector(completePath.head) ++ completePath.tail) + } + } -} + def createPartialResultForOutputArgOrRet() = + Vector( + ReachableByResult( + task.taskStack, + PathElement(path.head.node, callSiteStack, isOutputArg = true) +: path.tail, + partial = true + ) + ) + + /** Determine results for the current node + */ + val res = curNode match + // Case 1: we have reached a source => return result and continue traversing (expand into parents) + case x if sources.contains(x.asInstanceOf[NodeType]) => + if x.isInstanceOf[MethodParameterIn] then + Vector( + ReachableByResult(task.taskStack, path), + ReachableByResult(task.taskStack, path, partial = true) + ) ++ 
computeResultsForParents() + else + Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents() + // Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing + case _: MethodParameterIn => + Vector(ReachableByResult(task.taskStack, path, partial = true)) + // Case 3: we have reached a call to an internal method without semantic (return value) and + // this isn't the start node => return partial result and stop traversing + case call: Call + if isCallToInternalMethodWithoutSemantic(call) + && !isArgOrRetOfMethodWeCameFrom(call, path) => + createPartialResultForOutputArgOrRet() + + // Case 4: we have reached an argument to an internal method without semantic (output argument) and + // this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing + case arg: Expression + if path.size > 1 + && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) + && !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) => + createPartialResultForOutputArgOrRet() + + case _: MethodRef => createPartialResultForOutputArgOrRet() + + // All other cases: expand into parents + case _ => + computeResultsForParents() + val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth) + table.updateWith(key) { + case Some(existingValue) => Some(existingValue ++ res) + case None => Some(res) + } + res + end results + + private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean = + path match + case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => + methodsForCall(call).contains(x.method) + case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => + methodsForCall(call).contains(x.method) + case _ => false +end TaskSolver diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala 
b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala index d5143e61..646a96cd 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala @@ -2,105 +2,111 @@ package io.appthreat.dataflowengineoss import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, CfgNode} -package object queryengine { +package object queryengine: - /** The TaskFingerprint uniquely identifies a task. - */ - case class TaskFingerprint(sink: CfgNode, callSiteStack: List[Call], callDepth: Int) - - /** A (partial) result, informing about a path that exists from a source to another node in the graph. - * - * @param taskStack - * The list of tasks that was solved to arrive at this task - * - * @param path - * A path to the sink. - * - * @param partial - * indicate whether this result stands on its own or requires further analysis, e.g., by expanding output arguments - * backwards into method output parameters. - */ - case class ReachableByResult(taskStack: List[TaskFingerprint], path: Vector[PathElement], partial: Boolean = false) { - - def fingerprint: TaskFingerprint = taskStack.last - def sink: CfgNode = fingerprint.sink - def callSiteStack: List[Call] = fingerprint.callSiteStack - - def callDepth: Int = fingerprint.callDepth - - def startingPoint: CfgNode = path.head.node.asInstanceOf[CfgNode] - - /** If the result begins in an output argument, return it. + /** The TaskFingerprint uniquely identifies a task. */ - def outputArgument: Option[CfgNode] = { - path.headOption.collect { - case elem: PathElement if elem.isOutputArg => - elem.node.asInstanceOf[CfgNode] - } - } - } - - /** We represent data flows as sequences of path elements, where each path element consists of a node, flags and the - * label of its outgoing edge. - * - * @param node - * The parent node. 
This is actually always a CfgNode during data flow computation, however, since the source may - * be an arbitrary AST node, we may add an AST node to the start of the flow right before returning flows to the - * user. - * - * @param callSiteStack - * The call stack when this path element was created. Since we may enter the same function via two different call - * sites, path elements should only be treated as the same if they are the same node and we've reached them via the - * same call sequence. - * - * @param visible - * whether this path element should be shown in the flow - * @param isOutputArg - * input and output arguments are the same node in the CPG, so, we need this additional flag to determine whether - * we are on an input or output argument. By default, we consider arguments to be input arguments, meaning that - * when tracking `x` at `f(x)`, we do not expand into `f` but rather upwards to producers of `x`. - * @param outEdgeLabel - * label of the outgoing DDG edge - */ - case class PathElement( - node: AstNode, - callSiteStack: List[Call] = List(), - visible: Boolean = true, - isOutputArg: Boolean = false, - outEdgeLabel: String = "" - ) - - /** @param taskStack - * The list of tasks that was solved to arrive at this task, including the current task, which is to be solved. The - * current task is the last element of the list. - * - * @param initialPath - * The path from the current sink downwards to previous sinks. - */ - case class ReachableByTask(taskStack: List[TaskFingerprint], initialPath: Vector[PathElement]) { - - /** This tasks fingerprint: if two tasks have the same fingerprint, then the TaskSolver MUST return the same result - * for them. This is the basis of our caching scheme. + case class TaskFingerprint(sink: CfgNode, callSiteStack: List[Call], callDepth: Int) + + /** A (partial) result, informing about a path that exists from a source to another node in the + * graph. 
+ * + * @param taskStack + * The list of tasks that was solved to arrive at this task + * + * @param path + * A path to the sink. + * + * @param partial + * indicate whether this result stands on its own or requires further analysis, e.g., by + * expanding output arguments backwards into method output parameters. */ - def fingerprint: TaskFingerprint = taskStack.last - - /** The sink at which we start the analysis (upwards) - */ - def sink: CfgNode = fingerprint.sink - - /** The call sites we have expanded downwards during this analysis. We need to keep track of this so that we do not - * end up expanding one call site and then returning to a different call site, which would produce an unreachable - * path. + case class ReachableByResult( + taskStack: List[TaskFingerprint], + path: Vector[PathElement], + partial: Boolean = false + ): + + def fingerprint: TaskFingerprint = taskStack.last + def sink: CfgNode = fingerprint.sink + def callSiteStack: List[Call] = fingerprint.callSiteStack + + def callDepth: Int = fingerprint.callDepth + + def startingPoint: CfgNode = path.head.node.asInstanceOf[CfgNode] + + /** If the result begins in an output argument, return it. + */ + def outputArgument: Option[CfgNode] = + path.headOption.collect { + case elem: PathElement if elem.isOutputArg => + elem.node.asInstanceOf[CfgNode] + } + end ReachableByResult + + /** We represent data flows as sequences of path elements, where each path element consists of a + * node, flags and the label of its outgoing edge. + * + * @param node + * The parent node. This is actually always a CfgNode during data flow computation, however, + * since the source may be an arbitrary AST node, we may add an AST node to the start of the + * flow right before returning flows to the user. + * + * @param callSiteStack + * The call stack when this path element was created. 
Since we may enter the same function + * via two different call sites, path elements should only be treated as the same if they are + * the same node and we've reached them via the same call sequence. + * + * @param visible + * whether this path element should be shown in the flow + * @param isOutputArg + * input and output arguments are the same node in the CPG, so, we need this additional flag + * to determine whether we are on an input or output argument. By default, we consider + * arguments to be input arguments, meaning that when tracking `x` at `f(x)`, we do not + * expand into `f` but rather upwards to producers of `x`. + * @param outEdgeLabel + * label of the outgoing DDG edge */ - def callSiteStack: List[Call] = fingerprint.callSiteStack - - /** The call depth at which this task was created. + case class PathElement( + node: AstNode, + callSiteStack: List[Call] = List(), + visible: Boolean = true, + isOutputArg: Boolean = false, + outEdgeLabel: String = "" + ) + + /** @param taskStack + * The list of tasks that was solved to arrive at this task, including the current task, + * which is to be solved. The current task is the last element of the list. + * + * @param initialPath + * The path from the current sink downwards to previous sinks. */ - def callDepth: Int = fingerprint.callDepth - - } - - case class TaskSummary(tableEntries: Vector[(TaskFingerprint, TableEntry)], followupTasks: Vector[ReachableByTask]) - case class TableEntry(path: Vector[PathElement]) - -} + case class ReachableByTask(taskStack: List[TaskFingerprint], initialPath: Vector[PathElement]): + + /** This tasks fingerprint: if two tasks have the same fingerprint, then the TaskSolver MUST + * return the same result for them. This is the basis of our caching scheme. 
+ */ + def fingerprint: TaskFingerprint = taskStack.last + + /** The sink at which we start the analysis (upwards) + */ + def sink: CfgNode = fingerprint.sink + + /** The call sites we have expanded downwards during this analysis. We need to keep track of + * this so that we do not end up expanding one call site and then returning to a different + * call site, which would produce an unreachable path. + */ + def callSiteStack: List[Call] = fingerprint.callSiteStack + + /** The call depth at which this task was created. + */ + def callDepth: Int = fingerprint.callDepth + end ReachableByTask + + case class TaskSummary( + tableEntries: Vector[(TaskFingerprint, TableEntry)], + followupTasks: Vector[ReachableByTask] + ) + case class TableEntry(path: Vector[PathElement]) +end queryengine diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala index fdc7b2ab..4e5a26d5 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala @@ -7,100 +7,99 @@ import org.antlr.v4.runtime.tree.ParseTreeWalker import org.antlr.v4.runtime.{CharStream, CharStreams, CommonTokenStream} import scala.collection.mutable -import scala.jdk.CollectionConverters._ - -object Semantics { - - def fromList(elements: List[FlowSemantic]): Semantics = { - new Semantics( - mutable.Map.newBuilder - .addAll(elements.map { e => - e.methodFullName -> e - }) - .result() - ) - } - - def empty: Semantics = fromList(List()) - -} - -class Semantics private (methodToSemantic: mutable.Map[String, FlowSemantic]) { - - /** The map below keeps a mapping between results of a regex and the regex string it matches. e.g. 
- * - * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` - */ - private val regexMatchedFullNames = mutable.HashMap.empty[String, String] - - /** Initialize all the method semantics that use regex with all their regex results before query time. - */ - def loadRegexSemantics(cpg: Cpg): Unit = { - import io.shiftleft.semanticcpg.language._ - - methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => - cpg.method.fullName(regexString).fullName.foreach { methodMatch => - regexMatchedFullNames.put(methodMatch, regexString) - } - } - } - - def elements: List[FlowSemantic] = methodToSemantic.values.toList - - def forMethod(fullName: String): Option[FlowSemantic] = regexMatchedFullNames.get(fullName) match { - case Some(matchedFullName) => methodToSemantic.get(matchedFullName) - case None => methodToSemantic.get(fullName) - } - - def serialize: String = { - elements - .sortBy(_.methodFullName) - .map { elem => - s"\"${elem.methodFullName}\" " + elem.mappings - .collect { case FlowMapping(x, y) => s"$x -> $y" } - .mkString(" ") - } - .mkString("\n") - } - -} -case class FlowSemantic(methodFullName: String, mappings: List[FlowPath] = List.empty, regex: Boolean = false) - -object FlowSemantic { - - def from(methodFullName: String, mappings: List[_], regex: Boolean = false): FlowSemantic = { - FlowSemantic( - methodFullName, - mappings.map { - case (src: Int, dst: Int) => FlowMapping(src, dst) - case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) - case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) - case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => FlowMapping(srcIdx, src, dstIdx, dst) - case x: FlowMapping => x - }, - regex - ) - } - -} +import scala.jdk.CollectionConverters.* + +object Semantics: + + def fromList(elements: List[FlowSemantic]): Semantics = + new Semantics( + mutable.Map.newBuilder + .addAll(elements.map { e => + e.methodFullName -> e + }) + .result() + ) + + def empty: Semantics = 
fromList(List()) + +class Semantics private (methodToSemantic: mutable.Map[String, FlowSemantic]): + + /** The map below keeps a mapping between results of a regex and the regex string it matches. + * e.g. + * + * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` + */ + private val regexMatchedFullNames = mutable.HashMap.empty[String, String] + + /** Initialize all the method semantics that use regex with all their regex results before query + * time. + */ + def loadRegexSemantics(cpg: Cpg): Unit = + import io.shiftleft.semanticcpg.language.* + + methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => + cpg.method.fullName(regexString).fullName.foreach { methodMatch => + regexMatchedFullNames.put(methodMatch, regexString) + } + } + + def elements: List[FlowSemantic] = methodToSemantic.values.toList + + def forMethod(fullName: String): Option[FlowSemantic] = + regexMatchedFullNames.get(fullName) match + case Some(matchedFullName) => methodToSemantic.get(matchedFullName) + case None => methodToSemantic.get(fullName) + + def serialize: String = + elements + .sortBy(_.methodFullName) + .map { elem => + s"\"${elem.methodFullName}\" " + elem.mappings + .collect { case FlowMapping(x, y) => s"$x -> $y" } + .mkString(" ") + } + .mkString("\n") +end Semantics + +case class FlowSemantic( + methodFullName: String, + mappings: List[FlowPath] = List.empty, + regex: Boolean = false +) + +object FlowSemantic: + + def from(methodFullName: String, mappings: List[?], regex: Boolean = false): FlowSemantic = + FlowSemantic( + methodFullName, + mappings.map { + case (src: Int, dst: Int) => FlowMapping(src, dst) + case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) + case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) + case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => + FlowMapping(srcIdx, src, dstIdx, dst) + case x: FlowMapping => x + }, + regex + ) abstract class FlowNode -/** Collects parameters and return 
nodes under a common trait. This trait acknowledges their argument index which is - * relevant when a caller wants to coordinate relevant tainted flows through specific arguments and the return flow. +/** Collects parameters and return nodes under a common trait. This trait acknowledges their + * argument index which is relevant when a caller wants to coordinate relevant tainted flows + * through specific arguments and the return flow. */ -trait ParamOrRetNode extends FlowNode { - - /** Temporary backward compatible idx field. - * - * @return - * the argument index. - */ - def index: Int -} - -/** A parameter where the index of the argument matches the position of the parameter at the callee. The name is used to - * match named arguments if used instead of positional arguments. +trait ParamOrRetNode extends FlowNode: + + /** Temporary backward compatible idx field. + * + * @return + * the argument index. + */ + def index: Int + +/** A parameter where the index of the argument matches the position of the parameter at the callee. + * The name is used to match named arguments if used instead of positional arguments. * * @param index * the position or argument index. @@ -109,9 +108,8 @@ trait ParamOrRetNode extends FlowNode { */ case class ParameterNode(index: Int, name: Option[String] = None) extends ParamOrRetNode -object ParameterNode { - def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) -} +object ParameterNode: + def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) /** Represents explicit mappings or special cases. 
*/ @@ -126,86 +124,77 @@ sealed trait FlowPath */ case class FlowMapping(src: FlowNode, dst: FlowNode) extends FlowPath -object FlowMapping { - def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) - - def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) +object FlowMapping: + def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) - def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) + def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) - def apply(from: Int, toIdx: Int, to: String): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) + def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) -} + def apply(from: Int, toIdx: Int, to: String): FlowMapping = + FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) -/** Represents an instance where parameters are not sanitized, may affect the return value, and do not cross-taint. e.g. - * foo(1, 2) = 1 -> 1, 2 -> 2, 1 -> -1, 2 -> -1 +/** Represents an instance where parameters are not sanitized, may affect the return value, and do + * not cross-taint. e.g. foo(1, 2) = 1 -> 1, 2 -> 2, 1 -> -1, 2 -> -1 * - * The main benefit is that this works for unbounded parameters e.g. VARARGS. Note this does not taint 0 -> 0. + * The main benefit is that this works for unbounded parameters e.g. VARARGS. Note this does not + * taint 0 -> 0. 
*/ object PassThroughMapping extends FlowPath -class Parser() { - - def parse(input: String): List[FlowSemantic] = { - val charStream = CharStreams.fromString(input) - parseCharStream(charStream) - } - - def parseFile(fileName: String): List[FlowSemantic] = { - val charStream = CharStreams.fromFileName(fileName) - parseCharStream(charStream) - } - - private def parseCharStream(charStream: CharStream): List[FlowSemantic] = { - val lexer = new SemanticsLexer(charStream) - val tokenStream = new CommonTokenStream(lexer) - val parser = new SemanticsParser(tokenStream) - val treeWalker = new ParseTreeWalker() +class Parser(): - val tree = parser.taintSemantics() - val listener = new Listener() - treeWalker.walk(listener, tree) - listener.result.toList - } + def parse(input: String): List[FlowSemantic] = + val charStream = CharStreams.fromString(input) + parseCharStream(charStream) - implicit class AntlrFlowExtensions(val ctx: MappingContext) { + def parseFile(fileName: String): List[FlowSemantic] = + val charStream = CharStreams.fromFileName(fileName) + parseCharStream(charStream) - def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined + private def parseCharStream(charStream: CharStream): List[FlowSemantic] = + val lexer = new SemanticsLexer(charStream) + val tokenStream = new CommonTokenStream(lexer) + val parser = new SemanticsParser(tokenStream) + val treeWalker = new ParseTreeWalker() - def srcIdx: Int = ctx.src().argIdx().NUMBER().getText.toInt + val tree = parser.taintSemantics() + val listener = new Listener() + treeWalker.walk(listener, tree) + listener.result.toList - def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) + implicit class AntlrFlowExtensions(val ctx: MappingContext): - def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt + def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined - def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) + def srcIdx: Int = 
ctx.src().argIdx().NUMBER().getText.toInt - } + def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) - private class Listener extends SemanticsBaseListener { + def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt - val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() + def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) - override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = { - ctx.singleSemantic().asScala.foreach { semantic => - val methodName = semantic.methodName().name().getText - val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) - result.addOne(FlowSemantic(methodName, mappings)) - } - } + private class Listener extends SemanticsBaseListener: - private def ctxToParamMapping(ctx: MappingContext): FlowPath = - if (ctx.isPassThrough) { - PassThroughMapping - } else { - val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) - val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) + val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() - FlowMapping(src, dst) - } + override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = + ctx.singleSemantic().asScala.foreach { semantic => + val methodName = semantic.methodName().name().getText + val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) + result.addOne(FlowSemantic(methodName, mappings)) + } - } + private def ctxToParamMapping(ctx: MappingContext): FlowPath = + if ctx.isPassThrough then + PassThroughMapping + else + val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) + val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) -} + FlowMapping(src, dst) + end Listener +end Parser diff --git a/macros/src/main/scala/io/appthreat/console/Query.scala b/macros/src/main/scala/io/appthreat/console/Query.scala index 110d0967..2e52d570 100644 --- a/macros/src/main/scala/io/appthreat/console/Query.scala +++ 
b/macros/src/main/scala/io/appthreat/console/Query.scala @@ -4,7 +4,10 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.StoredNode case class CodeSnippet(content: String, filename: String) -case class MultiFileCodeExamples(positive: List[List[CodeSnippet]], negative: List[List[CodeSnippet]]) +case class MultiFileCodeExamples( + positive: List[List[CodeSnippet]], + negative: List[List[CodeSnippet]] +) case class CodeExamples(positive: List[String], negative: List[String]) case class Query( @@ -13,7 +16,7 @@ case class Query( title: String, description: String, score: Double, - traversal: Cpg => Iterator[_ <: StoredNode], + traversal: Cpg => Iterator[? <: StoredNode], traversalAsString: String = "", tags: List[String] = List(), language: String = "", @@ -21,31 +24,30 @@ case class Query( multiFileCodeExamples: MultiFileCodeExamples = MultiFileCodeExamples(List(), List()) ) -object Query { - def make( - name: String, - author: String, - title: String, - description: String, - score: Double, - traversalWithStrRep: TraversalWithStrRep, - tags: List[String] = List(), - codeExamples: CodeExamples = CodeExamples(List(), List()), - multiFileCodeExamples: MultiFileCodeExamples = MultiFileCodeExamples(List(), List()) - ): Query = { - Query( - name = name, - author = author, - title = title, - description = description, - score = score, - traversal = traversalWithStrRep.traversal, - traversalAsString = traversalWithStrRep.strRep, - tags = tags, - codeExamples = codeExamples, - multiFileCodeExamples = multiFileCodeExamples - ) - } -} +object Query: + def make( + name: String, + author: String, + title: String, + description: String, + score: Double, + traversalWithStrRep: TraversalWithStrRep, + tags: List[String] = List(), + codeExamples: CodeExamples = CodeExamples(List(), List()), + multiFileCodeExamples: MultiFileCodeExamples = MultiFileCodeExamples(List(), List()) + ): Query = + Query( + name = name, + author = author, + 
title = title, + description = description, + score = score, + traversal = traversalWithStrRep.traversal, + traversalAsString = traversalWithStrRep.strRep, + tags = tags, + codeExamples = codeExamples, + multiFileCodeExamples = multiFileCodeExamples + ) +end Query -case class TraversalWithStrRep(traversal: Cpg => Iterator[_ <: StoredNode], strRep: String = "") +case class TraversalWithStrRep(traversal: Cpg => Iterator[? <: StoredNode], strRep: String = "") diff --git a/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala b/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala index 8bdb7b7e..3a74959a 100644 --- a/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala +++ b/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala @@ -5,99 +5,94 @@ import org.reflections8.util.{ClasspathHelper, ConfigurationBuilder} import java.lang.reflect.{Method, Parameter} import scala.annotation.unused -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* trait QueryBundle class QueryDatabase( defaultArgumentProvider: DefaultArgumentProvider = new DefaultArgumentProvider, namespace: String = "io.appthreat.scanners" -) { +): - /** Determine all bundles on the class path - */ - def allBundles: List[Class[_ <: QueryBundle]] = - new Reflections( - new ConfigurationBuilder().setUrls( - ClasspathHelper.forPackage(namespace, ClasspathHelper.contextClassLoader(), ClasspathHelper.staticClassLoader()) - ) - ).getSubTypesOf(classOf[QueryBundle]).asScala.toList - - /** Determine queries across all bundles - */ - def allQueries: List[Query] = { - allBundles.flatMap { bundle => - queriesInBundle(bundle) - } - } + /** Determine all bundles on the class path + */ + def allBundles: List[Class[? 
<: QueryBundle]] = + new Reflections( + new ConfigurationBuilder().setUrls( + ClasspathHelper.forPackage( + namespace, + ClasspathHelper.contextClassLoader(), + ClasspathHelper.staticClassLoader() + ) + ) + ).getSubTypesOf(classOf[QueryBundle]).asScala.toList - /** Return all queries inside `bundle`. - */ - def queriesInBundle[T <: QueryBundle](bundle: Class[T]): List[Query] = { - val instance = bundle.getField("MODULE$").get(null) - queryCreatorsInBundle(bundle).map { case (method, args) => - val query = method.invoke(instance, args: _*).asInstanceOf[Query] - val bundleNamespace = bundle.getPackageName - // the namespace currently looks like `io.appthreat.scanners.c.CopyLoops` - val namespaceParts = bundleNamespace.split('.') - val language = - if (bundleNamespace.startsWith("io.appthreat.ocular.scanners")) { - namespaceParts(4) - } else if (namespaceParts.length > 3) { - namespaceParts(3) - } else { - "" + /** Determine queries across all bundles + */ + def allQueries: List[Query] = + allBundles.flatMap { bundle => + queriesInBundle(bundle) } - query.copy(language = language) - } - } - /** Obtain all (method, args) pairs from bundle, making it possible to override default args before creating the - * query. - */ - def queryCreatorsInBundle[T <: QueryBundle](bundle: Class[T]): List[(Method, List[Any])] = { - val methods = bundle.getMethods.filter(_.getAnnotations.exists(_.isInstanceOf[q])).toList - methods.map { method => - val args = defaultArgs(method, bundle) - (method, args) - } - } + /** Return all queries inside `bundle`. 
+ */ + def queriesInBundle[T <: QueryBundle](bundle: Class[T]): List[Query] = + val instance = bundle.getField("MODULE$").get(null) + queryCreatorsInBundle(bundle).map { case (method, args) => + val query = method.invoke(instance, args*).asInstanceOf[Query] + val bundleNamespace = bundle.getPackageName + // the namespace currently looks like `io.appthreat.scanners.c.CopyLoops` + val namespaceParts = bundleNamespace.split('.') + val language = + if bundleNamespace.startsWith("io.appthreat.ocular.scanners") then + namespaceParts(4) + else if namespaceParts.length > 3 then + namespaceParts(3) + else + "" + query.copy(language = language) + } - private def defaultArgs[T <: QueryBundle](method: Method, bundle: Class[T]): List[Any] = { - method.getParameters.zipWithIndex.map { case (parameter, index) => - defaultArgumentProvider.defaultArgument(method, bundle, parameter, index) - }.toList - } + /** Obtain all (method, args) pairs from bundle, making it possible to override default args + * before creating the query. + */ + def queryCreatorsInBundle[T <: QueryBundle](bundle: Class[T]): List[(Method, List[Any])] = + val methods = bundle.getMethods.filter(_.getAnnotations.exists(_.isInstanceOf[q])).toList + methods.map { method => + val args = defaultArgs(method, bundle) + (method, args) + } -} + private def defaultArgs[T <: QueryBundle](method: Method, bundle: Class[T]): List[Any] = + method.getParameters.zipWithIndex.map { case (parameter, index) => + defaultArgumentProvider.defaultArgument(method, bundle, parameter, index) + }.toList +end QueryDatabase -/** Joern and Ocular require different implicits to be present, and when we encounter these implicits as parameters in a - * query that we invoke via reflection, we need to obtain these implicits from somewhere. 
+/** Joern and Ocular require different implicits to be present, and when we encounter these + * implicits as parameters in a query that we invoke via reflection, we need to obtain these + * implicits from somewhere. * * We achieve this by implementing a `DefaultArgumentProvider` for Ocular, and one for Joern. */ -class DefaultArgumentProvider { +class DefaultArgumentProvider: - def typeSpecificDefaultArg(@unused argTypeFullName: String): Option[Any] = { - None - } + def typeSpecificDefaultArg(@unused argTypeFullName: String): Option[Any] = + None - final def defaultArgument(method: Method, bundle: Class[_], parameter: Parameter, i: Int): Any = { - val instance = bundle.getField("MODULE$").get(null) - val defaultArgOption = typeSpecificDefaultArg(parameter.getType.getTypeName) - defaultArgOption.getOrElse { - val defaultMethodName = s"${method.getName}$$default$$${i + 1}" - try { - val defaultMethod = bundle.getDeclaredMethod(defaultMethodName) - val defaultValue = defaultMethod.invoke(instance) - defaultValue - } catch { - case e: NoSuchMethodException => - throw new RuntimeException( - s"No default value found for parameter `${parameter.toString}` of query creator method `$method` " - ) - } - } - } - -} + final def defaultArgument(method: Method, bundle: Class[?], parameter: Parameter, i: Int): Any = + val instance = bundle.getField("MODULE$").get(null) + val defaultArgOption = typeSpecificDefaultArg(parameter.getType.getTypeName) + defaultArgOption.getOrElse { + val defaultMethodName = s"${method.getName}$$default$$${i + 1}" + try + val defaultMethod = bundle.getDeclaredMethod(defaultMethodName) + val defaultValue = defaultMethod.invoke(instance) + defaultValue + catch + case e: NoSuchMethodException => + throw new RuntimeException( + s"No default value found for parameter `${parameter.toString}` of query creator method `$method` " + ) + } +end DefaultArgumentProvider diff --git a/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala 
b/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala index 99a44586..716d23e5 100644 --- a/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala +++ b/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala @@ -6,17 +6,15 @@ import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import scala.quoted.{Expr, Quotes} -object QueryMacros { +object QueryMacros: - inline def withStrRep(inline traversal: Cpg => Iterator[_ <: StoredNode]): TraversalWithStrRep = - ${ withStrRepImpl('{ traversal }) } + inline def withStrRep(inline traversal: Cpg => Iterator[? <: StoredNode]): TraversalWithStrRep = + ${ withStrRepImpl('{ traversal }) } - private def withStrRepImpl( - travExpr: Expr[Cpg => Iterator[_ <: StoredNode]] - )(using quotes: Quotes): Expr[TraversalWithStrRep] = { - import quotes.reflect._ - val pos = travExpr.asTerm.pos - val code = Position(pos.sourceFile, pos.start, pos.end).sourceCode.getOrElse("N/A") - '{ TraversalWithStrRep(${ travExpr }, ${ Expr(code) }) } - } -} + private def withStrRepImpl( + travExpr: Expr[Cpg => Iterator[? 
<: StoredNode]] + )(using quotes: Quotes): Expr[TraversalWithStrRep] = + import quotes.reflect.* + val pos = travExpr.asTerm.pos + val code = Position(pos.sourceFile, pos.start, pos.end).sourceCode.getOrElse("N/A") + '{ TraversalWithStrRep(${ travExpr }, ${ Expr(code) }) } diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala index d292cd68..9174b9a1 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala @@ -12,23 +12,19 @@ import io.appthreat.x2cpg.X2CpgFrontend import scala.util.Try -class C2Cpg extends X2CpgFrontend[Config] { +class C2Cpg extends X2CpgFrontend[Config]: - private val report: Report = new Report() + private val report: Report = new Report() - def createCpg(config: Config): Try[Cpg] = { - withNewEmptyCpg(config.outputPath, config) { (cpg, config) => - new MetaDataPass(cpg, Languages.NEWC, config.inputPath).createAndApply() - new AstCreationPass(cpg, config, report).createAndApply() - TypeNodePass.withRegisteredTypes(CGlobal.typesSeen(), cpg).createAndApply() - new TypeDeclNodePass(cpg)(config.schemaValidation).createAndApply() - report.print() - } - } + def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config) { (cpg, config) => + new MetaDataPass(cpg, Languages.NEWC, config.inputPath).createAndApply() + new AstCreationPass(cpg, config, report).createAndApply() + TypeNodePass.withRegisteredTypes(CGlobal.typesSeen(), cpg).createAndApply() + new TypeDeclNodePass(cpg)(config.schemaValidation).createAndApply() + report.print() + } - def printIfDefsOnly(config: Config): Unit = { - val stmts = new PreprocessorPass(config).run().mkString(",") - println(stmts) - } - -} + def printIfDefsOnly(config: Config): Unit = + val stmts = new PreprocessorPass(config).run().mkString(",") + println(stmts) diff --git 
a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala index 425c349d..b3440241 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala @@ -1,6 +1,6 @@ package io.appthreat.c2cpg -import Frontend._ +import Frontend.* import io.appthreat.x2cpg.{X2CpgConfig, X2CpgMain} import org.slf4j.LoggerFactory import scopt.OParser @@ -20,128 +20,114 @@ final case class Config( includeFunctionBodies: Boolean = false, includeImageLocations: Boolean = false, useProjectIndex: Boolean = true -) extends X2CpgConfig[Config] { - def withIncludeFiles(includeFiles: Set[String]): Config = { - this.copy(includeFiles = includeFiles).withInheritedFields(this) - } - def withIncludePaths(includePaths: Set[String]): Config = { - this.copy(includePaths = includePaths).withInheritedFields(this) - } - - def withMacroFiles(macroFiles: Set[String]): Config = { - this.copy(macroFiles = macroFiles).withInheritedFields(this) - } - def withDefines(defines: Set[String]): Config = { - this.copy(defines = defines).withInheritedFields(this) - } - - def withIncludeComments(value: Boolean): Config = { - this.copy(includeComments = value).withInheritedFields(this) - } - - def withLogProblems(value: Boolean): Config = { - this.copy(logProblems = value).withInheritedFields(this) - } - - def withLogPreprocessor(value: Boolean): Config = { - this.copy(logPreprocessor = value).withInheritedFields(this) - } - - def withPrintIfDefsOnly(value: Boolean): Config = { - this.copy(printIfDefsOnly = value).withInheritedFields(this) - } - - def withIncludePathsAutoDiscovery(value: Boolean): Config = { - this.copy(includePathsAutoDiscovery = value).withInheritedFields(this) - } - - def withFunctionBodies(value: Boolean): Config = { - this.copy(includeFunctionBodies = value).withInheritedFields(this) - } - - def withImageLocations(value: 
Boolean): Config = { - this.copy(includeImageLocations = value).withInheritedFields(this) - } - - def withProjectIndexes(value: Boolean): Config = { - this.copy(useProjectIndex = value).withInheritedFields(this) - } -} - -private object Frontend { - implicit val defaultConfig: Config = Config() - - val cmdLineParser: OParser[Unit, Config] = { - val builder = OParser.builder[Config] - import builder._ - OParser.sequence( - programName(classOf[C2Cpg].getSimpleName), - opt[Unit]("include-comments") - .text(s"includes all comments into the CPG") - .action((_, c) => c.withIncludeComments(true)), - opt[Unit]("log-problems") - .text(s"enables logging of all parse problems while generating the CPG") - .action((_, c) => c.withLogProblems(true)), - opt[Unit]("log-preprocessor") - .text(s"enables logging of all preprocessor statements while generating the CPG") - .action((_, c) => c.withLogPreprocessor(true)), - opt[Unit]("print-ifdef-only") - .text(s"prints a comma-separated list of all preprocessor ifdef and if statements; does not create a CPG") - .action((_, c) => c.withPrintIfDefsOnly(true)), - opt[String]("include") - .unbounded() - .text("header include paths") - .action((incl, c) => c.withIncludePaths(c.includePaths + incl)), - opt[String]("include-files") - .unbounded() - .text("header include files") - .action((inclf, c) => c.withIncludeFiles(c.includeFiles + inclf)), - opt[String]("macro-files") - .unbounded() - .text("macro files") - .action((macrof, c) => c.withMacroFiles(c.macroFiles + macrof)), - opt[Unit]("no-include-auto-discovery") - .text("disables auto discovery of system header include paths") - .hidden(), - opt[Unit]("with-include-auto-discovery") - .text("enables auto discovery of system header include paths") - .action((_, c) => c.withIncludePathsAutoDiscovery(true)), - opt[Unit]("with-function-bodies") - .text("instructs the parser to parse function and method bodies.") - .action((_, c) => c.withFunctionBodies(true)), - 
opt[Unit]("with-image-locations") - .text( - "allows the parser to create image-locations. An image location explains how a name made it into the translation unit. Eg: via macro expansion or preprocessor." +) extends X2CpgConfig[Config]: + def withIncludeFiles(includeFiles: Set[String]): Config = + this.copy(includeFiles = includeFiles).withInheritedFields(this) + def withIncludePaths(includePaths: Set[String]): Config = + this.copy(includePaths = includePaths).withInheritedFields(this) + + def withMacroFiles(macroFiles: Set[String]): Config = + this.copy(macroFiles = macroFiles).withInheritedFields(this) + def withDefines(defines: Set[String]): Config = + this.copy(defines = defines).withInheritedFields(this) + + def withIncludeComments(value: Boolean): Config = + this.copy(includeComments = value).withInheritedFields(this) + + def withLogProblems(value: Boolean): Config = + this.copy(logProblems = value).withInheritedFields(this) + + def withLogPreprocessor(value: Boolean): Config = + this.copy(logPreprocessor = value).withInheritedFields(this) + + def withPrintIfDefsOnly(value: Boolean): Config = + this.copy(printIfDefsOnly = value).withInheritedFields(this) + + def withIncludePathsAutoDiscovery(value: Boolean): Config = + this.copy(includePathsAutoDiscovery = value).withInheritedFields(this) + + def withFunctionBodies(value: Boolean): Config = + this.copy(includeFunctionBodies = value).withInheritedFields(this) + + def withImageLocations(value: Boolean): Config = + this.copy(includeImageLocations = value).withInheritedFields(this) + + def withProjectIndexes(value: Boolean): Config = + this.copy(useProjectIndex = value).withInheritedFields(this) +end Config + +private object Frontend: + implicit val defaultConfig: Config = Config() + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName(classOf[C2Cpg].getSimpleName), + opt[Unit]("include-comments") + .text(s"includes all comments 
into the CPG") + .action((_, c) => c.withIncludeComments(true)), + opt[Unit]("log-problems") + .text(s"enables logging of all parse problems while generating the CPG") + .action((_, c) => c.withLogProblems(true)), + opt[Unit]("log-preprocessor") + .text(s"enables logging of all preprocessor statements while generating the CPG") + .action((_, c) => c.withLogPreprocessor(true)), + opt[Unit]("print-ifdef-only") + .text( + s"prints a comma-separated list of all preprocessor ifdef and if statements; does not create a CPG" + ) + .action((_, c) => c.withPrintIfDefsOnly(true)), + opt[String]("include") + .unbounded() + .text("header include paths") + .action((incl, c) => c.withIncludePaths(c.includePaths + incl)), + opt[String]("include-files") + .unbounded() + .text("header include files") + .action((inclf, c) => c.withIncludeFiles(c.includeFiles + inclf)), + opt[String]("macro-files") + .unbounded() + .text("macro files") + .action((macrof, c) => c.withMacroFiles(c.macroFiles + macrof)), + opt[Unit]("no-include-auto-discovery") + .text("disables auto discovery of system header include paths") + .hidden(), + opt[Unit]("with-include-auto-discovery") + .text("enables auto discovery of system header include paths") + .action((_, c) => c.withIncludePathsAutoDiscovery(true)), + opt[Unit]("with-function-bodies") + .text("instructs the parser to parse function and method bodies.") + .action((_, c) => c.withFunctionBodies(true)), + opt[Unit]("with-image-locations") + .text( + "allows the parser to create image-locations. An image location explains how a name made it into the translation unit. Eg: via macro expansion or preprocessor." + ) + .action((_, c) => c.withImageLocations(true)), + opt[Unit]("with-project-index") + .text( + "performance optimization, allows the parser to use an existing eclipse project(s) index(es)." 
+ ) + .action((_, c) => c.withProjectIndexes(true)), + opt[String]("define") + .unbounded() + .text("define a name") + .action((d, c) => c.withDefines(c.defines + d)) ) - .action((_, c) => c.withImageLocations(true)), - opt[Unit]("with-project-index") - .text("performance optimization, allows the parser to use an existing eclipse project(s) index(es).") - .action((_, c) => c.withProjectIndexes(true)), - opt[String]("define") - .unbounded() - .text("define a name") - .action((d, c) => c.withDefines(c.defines + d)) - ) - } - -} - -object Main extends X2CpgMain(cmdLineParser, new C2Cpg()) { - - private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) - - def run(config: Config, c2cpg: C2Cpg): Unit = { - if (config.printIfDefsOnly) { - try { - c2cpg.printIfDefsOnly(config) - } catch { - case NonFatal(ex) => - logger.debug("Failed to print preprocessor statements.", ex) - throw ex - } - } else { - c2cpg.run(config) - } - } - -} + end cmdLineParser +end Frontend + +object Main extends X2CpgMain(cmdLineParser, new C2Cpg()): + + private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) + + def run(config: Config, c2cpg: C2Cpg): Unit = + if config.printIfDefsOnly then + try + c2cpg.printIfDefsOnly(config) + catch + case NonFatal(ex) => + logger.debug("Failed to print preprocessor statements.", ex) + throw ex + else + c2cpg.run(config) diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala index d86b21b5..1617bb76 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala @@ -3,7 +3,12 @@ package io.appthreat.c2cpg.astcreation import io.appthreat.c2cpg.Config import io.appthreat.x2cpg.datastructures.Scope import io.appthreat.x2cpg.datastructures.Stack.* -import io.appthreat.x2cpg.{Ast, AstCreatorBase, 
ValidationMode, AstNodeBuilder as X2CpgAstNodeBuilder} +import io.appthreat.x2cpg.{ + Ast, + AstCreatorBase, + ValidationMode, + AstNodeBuilder as X2CpgAstNodeBuilder +} import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal @@ -31,60 +36,84 @@ class AstCreator( with AstNodeBuilder with AstCreatorHelper with MacroHandler - with X2CpgAstNodeBuilder[IASTNode, AstCreator] { + with X2CpgAstNodeBuilder[IASTNode, AstCreator]: - protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) + protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) - protected val scope: Scope[String, (NewNode, String), NewNode] = new Scope() + protected val scope: Scope[String, (NewNode, String), NewNode] = new Scope() - protected val usingDeclarationMappings: mutable.Map[String, String] = mutable.HashMap.empty + protected val usingDeclarationMappings: mutable.Map[String, String] = mutable.HashMap.empty - // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST - // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. - // To achieve this we need this extra stack. - protected val methodAstParentStack: Stack[NewNode] = new Stack() + // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST + // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. + // To achieve this we need this extra stack. 
+ protected val methodAstParentStack: Stack[NewNode] = new Stack() - def createAst(): DiffGraphBuilder = { - val ast = astForTranslationUnit(cdtAst) - Ast.storeInDiffGraph(ast, diffGraph) - diffGraph - } + def createAst(): DiffGraphBuilder = + val ast = astForTranslationUnit(cdtAst) + Ast.storeInDiffGraph(ast, diffGraph) + diffGraph - private def astForTranslationUnit(iASTTranslationUnit: IASTTranslationUnit): Ast = { - val namespaceBlock = globalNamespaceBlock() - methodAstParentStack.push(namespaceBlock) - val translationUnitAst = - astInFakeMethod(namespaceBlock.fullName, fileName(iASTTranslationUnit), iASTTranslationUnit) - val depsAndImportsAsts = astsForDependenciesAndImports(iASTTranslationUnit) - val commentsAsts = astsForComments(iASTTranslationUnit) - val childrenAsts = depsAndImportsAsts ++ Seq(translationUnitAst) ++ commentsAsts - setArgumentIndices(childrenAsts) - Ast(namespaceBlock).withChildren(childrenAsts) - } + private def astForTranslationUnit(iASTTranslationUnit: IASTTranslationUnit): Ast = + val namespaceBlock = globalNamespaceBlock() + methodAstParentStack.push(namespaceBlock) + val translationUnitAst = + astInFakeMethod( + namespaceBlock.fullName, + fileName(iASTTranslationUnit), + iASTTranslationUnit + ) + val depsAndImportsAsts = astsForDependenciesAndImports(iASTTranslationUnit) + val commentsAsts = astsForComments(iASTTranslationUnit) + val childrenAsts = depsAndImportsAsts ++ Seq(translationUnitAst) ++ commentsAsts + setArgumentIndices(childrenAsts) + Ast(namespaceBlock).withChildren(childrenAsts) - /** Creates an AST of all declarations found in the translation unit - wrapped in a fake method. 
- */ - private def astInFakeMethod(fullName: String, path: String, iASTTranslationUnit: IASTTranslationUnit): Ast = { - val allDecls = iASTTranslationUnit.getDeclarations.toList.filterNot(isIncludedNode) - val name = NamespaceTraversal.globalNamespaceName + /** Creates an AST of all declarations found in the translation unit - wrapped in a fake method. + */ + private def astInFakeMethod( + fullName: String, + path: String, + iASTTranslationUnit: IASTTranslationUnit + ): Ast = + val allDecls = iASTTranslationUnit.getDeclarations.toList.filterNot(isIncludedNode) + val name = NamespaceTraversal.globalNamespaceName - val fakeGlobalTypeDecl = - typeDeclNode(iASTTranslationUnit, name, fullName, filename, name, NodeTypes.NAMESPACE_BLOCK, fullName) - methodAstParentStack.push(fakeGlobalTypeDecl) + val fakeGlobalTypeDecl = + typeDeclNode( + iASTTranslationUnit, + name, + fullName, + filename, + name, + NodeTypes.NAMESPACE_BLOCK, + fullName + ) + methodAstParentStack.push(fakeGlobalTypeDecl) - val fakeGlobalMethod = - methodNode(iASTTranslationUnit, name, name, fullName, None, path, Option(NodeTypes.TYPE_DECL), Option(fullName)) - methodAstParentStack.push(fakeGlobalMethod) - scope.pushNewScope(fakeGlobalMethod) + val fakeGlobalMethod = + methodNode( + iASTTranslationUnit, + name, + name, + fullName, + None, + path, + Option(NodeTypes.TYPE_DECL), + Option(fullName) + ) + methodAstParentStack.push(fakeGlobalMethod) + scope.pushNewScope(fakeGlobalMethod) - val blockNode_ = blockNode(iASTTranslationUnit, Defines.empty, registerType(Defines.anyTypeName)) + val blockNode_ = + blockNode(iASTTranslationUnit, Defines.empty, registerType(Defines.anyTypeName)) - val declsAsts = allDecls.flatMap(astsForDeclaration) - setArgumentIndices(declsAsts) + val declsAsts = allDecls.flatMap(astsForDeclaration) + setArgumentIndices(declsAsts) - val methodReturn = newMethodReturnNode(iASTTranslationUnit, Defines.anyTypeName) - Ast(fakeGlobalTypeDecl).withChild( - methodAst(fakeGlobalMethod, 
Seq.empty, blockAst(blockNode_, declsAsts), methodReturn) - ) - } -} + val methodReturn = newMethodReturnNode(iASTTranslationUnit, Defines.anyTypeName) + Ast(fakeGlobalTypeDecl).withChild( + methodAst(fakeGlobalMethod, Seq.empty, blockAst(blockNode_, declsAsts), methodReturn) + ) + end astInFakeMethod +end AstCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala index 9ba850cd..c4075237 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala @@ -9,7 +9,11 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.utils.IOUtils import org.apache.commons.lang.StringUtils import org.eclipse.cdt.core.dom.ast.* -import org.eclipse.cdt.core.dom.ast.c.{ICASTArrayDesignator, ICASTDesignatedInitializer, ICASTFieldDesignator} +import org.eclipse.cdt.core.dom.ast.c.{ + ICASTArrayDesignator, + ICASTDesignatedInitializer, + ICASTFieldDesignator +} import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.c.ICASTKnRFunctionDeclarator import org.eclipse.cdt.internal.core.dom.parser.c.CASTArrayRangeDesignator @@ -24,546 +28,540 @@ import java.nio.file.{Path, Paths} import scala.annotation.nowarn import scala.collection.mutable -object AstCreatorHelper { +object AstCreatorHelper: - // maximum length of code fields in number of characters - private val MaxCodeLength: Int = 1000 - private val MinCodeLength: Int = 50 + // maximum length of code fields in number of characters + private val MaxCodeLength: Int = 1000 + private val MinCodeLength: Int = 50 - implicit class OptionSafeAst(val ast: Ast) extends AnyVal { - def withArgEdge(src: NewNode, dst: Option[NewNode]): Ast = dst match { - case Some(value) => ast.withArgEdge(src, value) - 
case None => ast - } - } -} + implicit class OptionSafeAst(val ast: Ast) extends AnyVal: + def withArgEdge(src: NewNode, dst: Option[NewNode]): Ast = dst match + case Some(value) => ast.withArgEdge(src, value) + case None => ast -trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - import AstCreatorHelper._ - - private var usedVariablePostfix: Int = 0 - - private val IncludeKeyword = "include" - - protected def isIncludedNode(node: IASTNode): Boolean = fileName(node) != filename - - protected def uniqueName(target: String, name: String, fullName: String): (String, String) = { - if (name.isEmpty && (fullName.isEmpty || fullName.endsWith("."))) { - val name = s"anonymous_${target}_$usedVariablePostfix" - val resultingFullName = s"$fullName$name" - usedVariablePostfix = usedVariablePostfix + 1 - (name, resultingFullName) - } else { - (name, fullName) - } - } - - private def fileOffsetTable(node: IASTNode): Array[Int] = { - val path = SourceFiles.toAbsolutePath(fileName(node), config.inputPath) - file2OffsetTable.computeIfAbsent(path, _ => genFileOffsetTable(Paths.get(path))) - } - - private def genFileOffsetTable(absolutePath: Path): Array[Int] = { - val asCharArray = IOUtils.readLinesInFile(absolutePath).mkString("\n").toCharArray - val offsets = mutable.ArrayBuffer.empty[Int] - - for (i <- Range(0, asCharArray.length)) { - if (asCharArray(i) == '\n') { - offsets.append(i + 1) - } - } - offsets.toArray - } - - private def nullSafeFileLocation(node: IASTNode): Option[IASTFileLocation] = - Option(cdtAst.flattenLocationsToFile(node.getNodeLocations)).map(_.asFileLocation()) - - private def nullSafeFileLocationLast(node: IASTNode): Option[IASTFileLocation] = - Option(cdtAst.flattenLocationsToFile(node.getNodeLocations.lastOption.toArray)).map(_.asFileLocation()) - - protected def fileName(node: IASTNode): String = { - val path = nullSafeFileLocation(node).map(_.getFileName).getOrElse(filename) - SourceFiles.toRelativePath(path, 
config.inputPath) - } - - protected def line(node: IASTNode): Option[Integer] = { - nullSafeFileLocation(node).map(_.getStartingLineNumber) - } - - protected def lineEnd(node: IASTNode): Option[Integer] = { - nullSafeFileLocationLast(node).map(_.getEndingLineNumber) - } - - private def offsetToColumn(node: IASTNode, offset: Int): Int = { - val table = fileOffsetTable(node) - val index = java.util.Arrays.binarySearch(table, offset) - val tableIndex = if (index < 0) -(index + 1) else index + 1 - val lineStartOffset = if (tableIndex == 0) { - 0 - } else { - table(tableIndex - 1) - } - val column = offset - lineStartOffset + 1 - column - } - - protected def column(node: IASTNode): Option[Integer] = { - val loc = nullSafeFileLocation(node) - loc.map { x => - offsetToColumn(node, x.getNodeOffset) - } - } - - protected def columnEnd(node: IASTNode): Option[Integer] = { - val loc = nullSafeFileLocation(node) - loc.map { x => - offsetToColumn(node, x.getNodeOffset + x.getNodeLength - 1) - } - } - - protected def registerType(typeName: String): String = { - val fixedTypeName = fixQualifiedName(StringUtils.normalizeSpace(typeName)) - CGlobal.usedTypes.putIfAbsent(fixedTypeName, true) - fixedTypeName - } - - // Sadly, there is no predefined List / Enum of this within Eclipse CDT: - private val reservedTypeKeywords: List[String] = - List( - "const", - "static", - "volatile", - "restrict", - "extern", - "typedef", - "inline", - "constexpr", - "auto", - "virtual", - "enum", - "struct", - "interface", - "class", - "naked", - "export", - "module", - "import" - ) - - protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = { - val tpe = - if (stripKeywords) { - reservedTypeKeywords.foldLeft(rawType) { (cur, repl) => - if (cur.contains(s"$repl ")) { - dereferenceTypeFullName(cur.replace(s"$repl ", "")) - } else { - cur - } - } - } else { - rawType - } - StringUtils.normalizeSpace(tpe) match { - case "" => Defines.anyTypeName - case t if 
t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => Defines.anyTypeName - case t if t.contains(" ->") && t.contains("}::") => - fixQualifiedName(t.substring(t.indexOf("}::") + 3, t.indexOf(" ->"))) - case t if t.contains(" ->") => - fixQualifiedName(t.substring(0, t.indexOf(" ->"))) - case t if t.contains("( ") => - fixQualifiedName(t.substring(0, t.indexOf("( "))) - case t if t.contains("?") => Defines.anyTypeName - case t if t.contains("#") => Defines.anyTypeName - case t if t.contains("{") && t.contains("}") => - val anonType = - s"${uniqueName("type", "", "")._1}${t.substring(0, t.indexOf("{"))}${t.substring(t.indexOf("}") + 1)}" - anonType.replace(" ", "") - case t if t.startsWith("[") && t.endsWith("]") => Defines.anyTypeName - case t if t.contains(Defines.qualifiedNameSeparator) => - fixQualifiedName(t).split(".").lastOption.getOrElse(Defines.anyTypeName) - case t if t.startsWith("unsigned ") => "unsigned " + t.substring(9).replace(" ", "") - case t if t.contains("[") && t.contains("]") => t.replace(" ", "") - case t if t.contains("*") => t.replace(" ", "") - case someType => someType - } - } - - @nowarn - protected def typeFor(node: IASTNode, stripKeywords: Boolean = true): String = { - import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature - node match { - case f: CPPASTFieldReference => - f.getFieldOwner.getEvaluation match { - case evaluation: EvalBinding => cleanType(evaluation.getType.toString, stripKeywords) - case _ => cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) +trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + import AstCreatorHelper.* + + private var usedVariablePostfix: Int = 0 + + private val IncludeKeyword = "include" + + protected def isIncludedNode(node: IASTNode): Boolean = fileName(node) != filename + + protected def uniqueName(target: String, name: String, fullName: String): (String, String) = + if name.isEmpty && 
(fullName.isEmpty || fullName.endsWith(".")) then + val name = s"anonymous_${target}_$usedVariablePostfix" + val resultingFullName = s"$fullName$name" + usedVariablePostfix = usedVariablePostfix + 1 + (name, resultingFullName) + else + (name, fullName) + + private def fileOffsetTable(node: IASTNode): Array[Int] = + val path = SourceFiles.toAbsolutePath(fileName(node), config.inputPath) + file2OffsetTable.computeIfAbsent(path, _ => genFileOffsetTable(Paths.get(path))) + + private def genFileOffsetTable(absolutePath: Path): Array[Int] = + val asCharArray = IOUtils.readLinesInFile(absolutePath).mkString("\n").toCharArray + val offsets = mutable.ArrayBuffer.empty[Int] + + for i <- Range(0, asCharArray.length) do + if asCharArray(i) == '\n' then + offsets.append(i + 1) + offsets.toArray + + private def nullSafeFileLocation(node: IASTNode): Option[IASTFileLocation] = + Option(cdtAst.flattenLocationsToFile(node.getNodeLocations)).map(_.asFileLocation()) + + private def nullSafeFileLocationLast(node: IASTNode): Option[IASTFileLocation] = + Option(cdtAst.flattenLocationsToFile(node.getNodeLocations.lastOption.toArray)).map( + _.asFileLocation() + ) + + protected def fileName(node: IASTNode): String = + val path = nullSafeFileLocation(node).map(_.getFileName).getOrElse(filename) + SourceFiles.toRelativePath(path, config.inputPath) + + protected def line(node: IASTNode): Option[Integer] = + nullSafeFileLocation(node).map(_.getStartingLineNumber) + + protected def lineEnd(node: IASTNode): Option[Integer] = + nullSafeFileLocationLast(node).map(_.getEndingLineNumber) + + private def offsetToColumn(node: IASTNode, offset: Int): Int = + val table = fileOffsetTable(node) + val index = java.util.Arrays.binarySearch(table, offset) + val tableIndex = if index < 0 then -(index + 1) else index + 1 + val lineStartOffset = if tableIndex == 0 then + 0 + else + table(tableIndex - 1) + val column = offset - lineStartOffset + 1 + column + + protected def column(node: IASTNode): 
Option[Integer] = + val loc = nullSafeFileLocation(node) + loc.map { x => + offsetToColumn(node, x.getNodeOffset) } - case f: IASTFieldReference => - cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).startsWith("? ") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val arr = ASTTypeUtil.getNodeType(a).replace("? ", "") - s"$tpe$arr" - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains("} ") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val nodeType = ASTTypeUtil.getNodeType(node) - val arr = nodeType.substring(nodeType.indexOf("["), nodeType.indexOf("]") + 1) - s"$tpe$arr" - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains(" [") => - cleanType(ASTTypeUtil.getNodeType(node)) - case s: CPPASTIdExpression => - s.getEvaluation match { - case evaluation: EvalMemberAccess => cleanType(evaluation.getOwnerType.toString, stripKeywords) - case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) + + protected def columnEnd(node: IASTNode): Option[Integer] = + val loc = nullSafeFileLocation(node) + loc.map { x => + offsetToColumn(node, x.getNodeOffset + x.getNodeLength - 1) } - case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => - cleanType(ASTTypeUtil.getNodeType(node), stripKeywords) - case s: IASTNamedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTCompositeTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTEnumerationSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTElaboratedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case l: IASTLiteralExpression => - cleanType(ASTTypeUtil.getType(l.getExpressionType)) - case e: IASTExpression => - cleanType(ASTTypeUtil.getNodeType(e), stripKeywords) - case _ => - 
cleanType(getNodeSignature(node), stripKeywords) - } - } - - protected def shortenCode(code: String, length: Int = MaxCodeLength): String = - StringUtils.abbreviate(code, math.max(MinCodeLength, length)) - - private def notHandledText(node: IASTNode): String = - s"""Node '${node.getClass.getSimpleName}' not handled yet! + + protected def registerType(typeName: String): String = + val fixedTypeName = fixQualifiedName(StringUtils.normalizeSpace(typeName)) + CGlobal.usedTypes.putIfAbsent(fixedTypeName, true) + fixedTypeName + + // Sadly, there is no predefined List / Enum of this within Eclipse CDT: + private val reservedTypeKeywords: List[String] = + List( + "const", + "static", + "volatile", + "restrict", + "extern", + "typedef", + "inline", + "constexpr", + "auto", + "virtual", + "enum", + "struct", + "interface", + "class", + "naked", + "export", + "module", + "import" + ) + + protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = + val tpe = + if stripKeywords then + reservedTypeKeywords.foldLeft(rawType) { (cur, repl) => + if cur.contains(s"$repl ") then + dereferenceTypeFullName(cur.replace(s"$repl ", "")) + else + cur + } + else + rawType + StringUtils.normalizeSpace(tpe) match + case "" => Defines.anyTypeName + case t if t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => + Defines.anyTypeName + case t if t.contains(" ->") && t.contains("}::") => + fixQualifiedName(t.substring(t.indexOf("}::") + 3, t.indexOf(" ->"))) + case t if t.contains(" ->") => + fixQualifiedName(t.substring(0, t.indexOf(" ->"))) + case t if t.contains("( ") => + fixQualifiedName(t.substring(0, t.indexOf("( "))) + case t if t.contains("?") => Defines.anyTypeName + case t if t.contains("#") => Defines.anyTypeName + case t if t.contains("{") && t.contains("}") => + val anonType = + s"${uniqueName("type", "", "")._1}${t.substring(0, t.indexOf("{"))}${t.substring(t.indexOf("}") + 1)}" + anonType.replace(" ", "") + case t if t.startsWith("[") && 
t.endsWith("]") => Defines.anyTypeName + case t if t.contains(Defines.qualifiedNameSeparator) => + fixQualifiedName(t).split(".").lastOption.getOrElse(Defines.anyTypeName) + case t if t.startsWith("unsigned ") => "unsigned " + t.substring(9).replace(" ", "") + case t if t.contains("[") && t.contains("]") => t.replace(" ", "") + case t if t.contains("*") => t.replace(" ", "") + case someType => someType + end match + end cleanType + + @nowarn + protected def typeFor(node: IASTNode, stripKeywords: Boolean = true): String = + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + node match + case f: CPPASTFieldReference => + f.getFieldOwner.getEvaluation match + case evaluation: EvalBinding => + cleanType(evaluation.getType.toString, stripKeywords) + case _ => cleanType( + ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), + stripKeywords + ) + case f: IASTFieldReference => + cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).startsWith("? ") => + val tpe = getNodeSignature(a).replace("[]", "").strip() + val arr = ASTTypeUtil.getNodeType(a).replace("? 
", "") + s"$tpe$arr" + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains("} ") => + val tpe = getNodeSignature(a).replace("[]", "").strip() + val nodeType = ASTTypeUtil.getNodeType(node) + val arr = nodeType.substring(nodeType.indexOf("["), nodeType.indexOf("]") + 1) + s"$tpe$arr" + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains(" [") => + cleanType(ASTTypeUtil.getNodeType(node)) + case s: CPPASTIdExpression => + s.getEvaluation match + case evaluation: EvalMemberAccess => + cleanType(evaluation.getOwnerType.toString, stripKeywords) + case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) + case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => + cleanType(ASTTypeUtil.getNodeType(node), stripKeywords) + case s: IASTNamedTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTCompositeTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTEnumerationSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTElaboratedTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case l: IASTLiteralExpression => + cleanType(ASTTypeUtil.getType(l.getExpressionType)) + case e: IASTExpression => + cleanType(ASTTypeUtil.getNodeType(e), stripKeywords) + case _ => + cleanType(getNodeSignature(node), stripKeywords) + end match + end typeFor + + protected def shortenCode(code: String, length: Int = MaxCodeLength): String = + StringUtils.abbreviate(code, math.max(MinCodeLength, length)) + + private def notHandledText(node: IASTNode): String = + s"""Node '${node.getClass.getSimpleName}' not handled yet! 
| Code: '${node.getRawSignature}' | File: '$filename' | Line: ${line(node).getOrElse(-1)} | """.stripMargin - protected def notHandledYet(node: IASTNode): Ast = { - if (!node.isInstanceOf[IASTProblem] && !node.isInstanceOf[IASTProblemHolder]) { - val text = notHandledText(node) - logger.debug(text) - } - Ast(unknownNode(node, nodeSignature(node))) - } - - protected def nullSafeCode(node: IASTNode): String = { - Option(node).map(nodeSignature).getOrElse("") - } - - protected def nullSafeAst(node: IASTExpression, argIndex: Int): Ast = { - val r = nullSafeAst(node) - r.root match { - case Some(x: ExpressionNew) => - x.argumentIndex = argIndex - case _ => - } - r - } - - protected def nullSafeAst(node: IASTExpression): Ast = - Option(node).map(astForNode).getOrElse(Ast()) - - protected def nullSafeAst(node: IASTStatement, argIndex: Int = -1): Seq[Ast] = { - Option(node).map(astsForStatement(_, argIndex)).getOrElse(Seq.empty) - } - - protected def dereferenceTypeFullName(fullName: String): String = - fullName.replace("*", "") - - protected def fixQualifiedName(name: String): String = - name.stripPrefix(Defines.qualifiedNameSeparator).replace(Defines.qualifiedNameSeparator, ".") - - protected def isQualifiedName(name: String): Boolean = - name.startsWith(Defines.qualifiedNameSeparator) - - protected def lastNameOfQualifiedName(name: String): String = { - val cleanedName = if (name.contains("<") && name.contains(">")) { - name.substring(0, name.indexOf("<")) - } else { - name - } - cleanedName.split(Defines.qualifiedNameSeparator).lastOption.getOrElse(cleanedName) - } - - protected def fullName(node: IASTNode): String = { - val qualifiedName: String = node match { - case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => - val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] - evaluation.getBinding match { - case f: CPPFunction if f.getDeclarations != null => - usingDeclarationMappings.getOrElse( - 
fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), - f.getDeclarations.headOption.map(n => ASTStringUtil.getSimpleName(n.getName)).getOrElse(f.getName) - ) - case f: CPPFunction if f.getDefinition != null => - usingDeclarationMappings.getOrElse( - fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), - ASTStringUtil.getSimpleName(f.getDefinition.getName) - ) - case other => other.getName + protected def notHandledYet(node: IASTNode): Ast = + if !node.isInstanceOf[IASTProblem] && !node.isInstanceOf[IASTProblemHolder] then + val text = notHandledText(node) + logger.debug(text) + Ast(unknownNode(node, nodeSignature(node))) + + protected def nullSafeCode(node: IASTNode): String = + Option(node).map(nodeSignature).getOrElse("") + + protected def nullSafeAst(node: IASTExpression, argIndex: Int): Ast = + val r = nullSafeAst(node) + r.root match + case Some(x: ExpressionNew) => + x.argumentIndex = argIndex + case _ => + r + + protected def nullSafeAst(node: IASTExpression): Ast = + Option(node).map(astForNode).getOrElse(Ast()) + + protected def nullSafeAst(node: IASTStatement, argIndex: Int = -1): Seq[Ast] = + Option(node).map(astsForStatement(_, argIndex)).getOrElse(Seq.empty) + + protected def dereferenceTypeFullName(fullName: String): String = + fullName.replace("*", "") + + protected def fixQualifiedName(name: String): String = + name.stripPrefix(Defines.qualifiedNameSeparator).replace( + Defines.qualifiedNameSeparator, + "." 
+ ) + + protected def isQualifiedName(name: String): Boolean = + name.startsWith(Defines.qualifiedNameSeparator) + + protected def lastNameOfQualifiedName(name: String): String = + val cleanedName = if name.contains("<") && name.contains(">") then + name.substring(0, name.indexOf("<")) + else + name + cleanedName.split(Defines.qualifiedNameSeparator).lastOption.getOrElse(cleanedName) + + protected def fullName(node: IASTNode): String = + val qualifiedName: String = node match + case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => + val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] + evaluation.getBinding match + case f: CPPFunction if f.getDeclarations != null => + usingDeclarationMappings.getOrElse( + fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), + f.getDeclarations.headOption.map(n => + ASTStringUtil.getSimpleName(n.getName) + ).getOrElse(f.getName) + ) + case f: CPPFunction if f.getDefinition != null => + usingDeclarationMappings.getOrElse( + fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), + ASTStringUtil.getSimpleName(f.getDefinition.getName) + ) + case other => other.getName + case alias: ICPPASTNamespaceAlias => alias.getMappingName.toString + case namespace: ICPPASTNamespaceDefinition + if ASTStringUtil.getSimpleName(namespace.getName).nonEmpty => + s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" + case namespace: ICPPASTNamespaceDefinition + if ASTStringUtil.getSimpleName(namespace.getName).isEmpty => + s"${fullName(namespace.getParent)}.${uniqueName("namespace", "", "")._1}" + case compType: IASTCompositeTypeSpecifier + if ASTStringUtil.getSimpleName(compType.getName).nonEmpty => + s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" + case compType: IASTCompositeTypeSpecifier + if ASTStringUtil.getSimpleName(compType.getName).isEmpty => + val name = compType.getParent match + case decl: IASTSimpleDeclaration => + 
decl.getDeclarators.headOption + .map(n => ASTStringUtil.getSimpleName(n.getName)) + .getOrElse(uniqueName("composite_type", "", "")._1) + case _ => uniqueName("composite_type", "", "")._1 + s"${fullName(compType.getParent)}.$name" + case enumSpecifier: IASTEnumerationSpecifier => + s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" + case f: ICPPASTLambdaExpression => + s"${fullName(f.getParent)}." + case f: IASTFunctionDeclarator + if ASTStringUtil.getSimpleName( + f.getName + ).isEmpty && f.getNestedDeclarator != null => + s"${fullName(f.getParent)}.${shortName(f.getNestedDeclarator)}" + case f: IASTFunctionDeclarator => + s"${fullName(f.getParent)}.${ASTStringUtil.getSimpleName(f.getName)}" + case f: IASTFunctionDefinition if f.getDeclarator != null => + s"${fullName(f.getParent)}.${ASTStringUtil.getQualifiedName(f.getDeclarator.getName)}" + case f: IASTFunctionDefinition => + s"${fullName(f.getParent)}.${shortName(f)}" + case e: IASTElaboratedTypeSpecifier => + s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" + case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) + case _: IASTTranslationUnit => "" + case u: IASTUnaryExpression => nodeSignature(u.getOperand) + case other if other.getParent != null => fullName(other.getParent) + case other if other != null => notHandledYet(other); "" + case null => "" + fixQualifiedName(qualifiedName).stripPrefix(".") + end fullName + + protected def shortName(node: IASTNode): String = + val name = node match + case d: IASTDeclarator + if ASTStringUtil.getSimpleName( + d.getName + ).isEmpty && d.getNestedDeclarator != null => + shortName(d.getNestedDeclarator) + case d: IASTDeclarator => ASTStringUtil.getSimpleName(d.getName) + case f: ICPPASTFunctionDefinition + if ASTStringUtil + .getSimpleName(f.getDeclarator.getName) + .isEmpty && f.getDeclarator.getNestedDeclarator != null => + shortName(f.getDeclarator.getNestedDeclarator) + case f: 
ICPPASTFunctionDefinition => + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(f.getDeclarator.getName)) + case f: IASTFunctionDefinition + if ASTStringUtil + .getSimpleName(f.getDeclarator.getName) + .isEmpty && f.getDeclarator.getNestedDeclarator != null => + shortName(f.getDeclarator.getNestedDeclarator) + case f: IASTFunctionDefinition => ASTStringUtil.getSimpleName(f.getDeclarator.getName) + case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => + val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] + evaluation.getBinding match + case f: CPPFunction if f.getDeclarations != null => + f.getDeclarations.headOption.map(n => + ASTStringUtil.getSimpleName(n.getName) + ).getOrElse(f.getName) + case f: CPPFunction if f.getDefinition != null => + ASTStringUtil.getSimpleName(f.getDefinition.getName) + case other => + other.getName + case d: IASTIdExpression => + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) + case u: IASTUnaryExpression => shortName(u.getOperand) + case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) + case s: IASTSimpleDeclSpecifier => s.getRawSignature + case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) + case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) + case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) + case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case other => notHandledYet(other); "" + name + end shortName + + private def pointersAsString( + spec: IASTDeclSpecifier, + parentDecl: IASTDeclarator, + stripKeywords: Boolean + ): String = + val tpe = typeFor(spec, stripKeywords) + val pointers = parentDecl.getPointerOperators + val arr = parentDecl match + case p: IASTArrayDeclarator => + p.getArrayModifiers.toList.map(_.getRawSignature).mkString + case _ => "" + if pointers.isEmpty then s"$tpe$arr" + else + val refs = + "*" * (pointers.length - 
pointers.count(_.isInstanceOf[ICPPASTReferenceOperator])) + s"$tpe$arr$refs".strip() + + protected def astsForDependenciesAndImports(iASTTranslationUnit: IASTTranslationUnit) + : Seq[Ast] = + val allIncludes = iASTTranslationUnit.getIncludeDirectives.toList.filterNot(isIncludedNode) + allIncludes.map { include => + val name = include.getName.toString + val _dependencyNode = newDependencyNode(name, name, IncludeKeyword) + val importNode = newImportNode(nodeSignature(include), name, name, include) + diffGraph.addNode(_dependencyNode) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + Ast(importNode) } - case alias: ICPPASTNamespaceAlias => alias.getMappingName.toString - case namespace: ICPPASTNamespaceDefinition if ASTStringUtil.getSimpleName(namespace.getName).nonEmpty => - s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" - case namespace: ICPPASTNamespaceDefinition if ASTStringUtil.getSimpleName(namespace.getName).isEmpty => - s"${fullName(namespace.getParent)}.${uniqueName("namespace", "", "")._1}" - case compType: IASTCompositeTypeSpecifier if ASTStringUtil.getSimpleName(compType.getName).nonEmpty => - s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" - case compType: IASTCompositeTypeSpecifier if ASTStringUtil.getSimpleName(compType.getName).isEmpty => - val name = compType.getParent match { - case decl: IASTSimpleDeclaration => - decl.getDeclarators.headOption - .map(n => ASTStringUtil.getSimpleName(n.getName)) - .getOrElse(uniqueName("composite_type", "", "")._1) - case _ => uniqueName("composite_type", "", "")._1 + + protected def astsForComments(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = + if config.includeComments then + iASTTranslationUnit.getComments.toList.filterNot(isIncludedNode).map(comment => + astForComment(comment) + ) + else + Seq.empty + + private def astForDecltypeSpecifier(decl: ICPPASTDecltypeSpecifier): Ast = + val op = ".typeOf" + val 
cpgUnary = callNode(decl, nodeSignature(decl), op, op, DispatchTypes.STATIC_DISPATCH) + val operand = nullSafeAst(decl.getDecltypeExpression) + callAst(cpgUnary, List(operand)) + + private def astForCASTDesignatedInitializer(d: ICASTDesignatedInitializer): Ast = + val node = blockNode(d, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val op = Operators.assignment + val calls = withIndex(d.getDesignators) { (des, o) => + val callNode_ = + callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(o) + val left = astForNode(des) + val right = astForNode(d.getOperand) + callAst(callNode_, List(left, right)) } - s"${fullName(compType.getParent)}.$name" - case enumSpecifier: IASTEnumerationSpecifier => - s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" - case f: ICPPASTLambdaExpression => - s"${fullName(f.getParent)}." - case f: IASTFunctionDeclarator - if ASTStringUtil.getSimpleName(f.getName).isEmpty && f.getNestedDeclarator != null => - s"${fullName(f.getParent)}.${shortName(f.getNestedDeclarator)}" - case f: IASTFunctionDeclarator => - s"${fullName(f.getParent)}.${ASTStringUtil.getSimpleName(f.getName)}" - case f: IASTFunctionDefinition if f.getDeclarator != null => - s"${fullName(f.getParent)}.${ASTStringUtil.getQualifiedName(f.getDeclarator.getName)}" - case f: IASTFunctionDefinition => - s"${fullName(f.getParent)}.${shortName(f)}" - case e: IASTElaboratedTypeSpecifier => - s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" - case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) - case _: IASTTranslationUnit => "" - case u: IASTUnaryExpression => nodeSignature(u.getOperand) - case other if other.getParent != null => fullName(other.getParent) - case other if other != null => notHandledYet(other); "" - case null => "" - } - fixQualifiedName(qualifiedName).stripPrefix(".") - } - - protected def shortName(node: IASTNode): String = { - val name = 
node match { - case d: IASTDeclarator if ASTStringUtil.getSimpleName(d.getName).isEmpty && d.getNestedDeclarator != null => - shortName(d.getNestedDeclarator) - case d: IASTDeclarator => ASTStringUtil.getSimpleName(d.getName) - case f: ICPPASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: ICPPASTFunctionDefinition => lastNameOfQualifiedName(ASTStringUtil.getSimpleName(f.getDeclarator.getName)) - case f: IASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: IASTFunctionDefinition => ASTStringUtil.getSimpleName(f.getDeclarator.getName) - case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => - val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] - evaluation.getBinding match { - case f: CPPFunction if f.getDeclarations != null => - f.getDeclarations.headOption.map(n => ASTStringUtil.getSimpleName(n.getName)).getOrElse(f.getName) - case f: CPPFunction if f.getDefinition != null => - ASTStringUtil.getSimpleName(f.getDefinition.getName) - case other => - other.getName + scope.popScope() + blockAst(node, calls.toList) + + private def astForCPPASTDesignatedInitializer(d: ICPPASTDesignatedInitializer): Ast = + val node = blockNode(d, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val op = Operators.assignment + val calls = withIndex(d.getDesignators) { (des, o) => + val callNode_ = + callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(o) + val left = astForNode(des) + val right = astForNode(d.getOperand) + callAst(callNode_, List(left, right)) } - case d: IASTIdExpression => lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) - case u: IASTUnaryExpression => shortName(u.getOperand) - 
case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) - case s: IASTSimpleDeclSpecifier => s.getRawSignature - case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) - case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) - case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) - case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case other => notHandledYet(other); "" - } - name - } - - private def pointersAsString(spec: IASTDeclSpecifier, parentDecl: IASTDeclarator, stripKeywords: Boolean): String = { - val tpe = typeFor(spec, stripKeywords) - val pointers = parentDecl.getPointerOperators - val arr = parentDecl match { - case p: IASTArrayDeclarator => p.getArrayModifiers.toList.map(_.getRawSignature).mkString - case _ => "" - } - if (pointers.isEmpty) { s"$tpe$arr" } - else { - val refs = "*" * (pointers.length - pointers.count(_.isInstanceOf[ICPPASTReferenceOperator])) - s"$tpe$arr$refs".strip() - } - } - - protected def astsForDependenciesAndImports(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = { - val allIncludes = iASTTranslationUnit.getIncludeDirectives.toList.filterNot(isIncludedNode) - allIncludes.map { include => - val name = include.getName.toString - val _dependencyNode = newDependencyNode(name, name, IncludeKeyword) - val importNode = newImportNode(nodeSignature(include), name, name, include) - diffGraph.addNode(_dependencyNode) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - Ast(importNode) - } - } - - protected def astsForComments(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = { - if (config.includeComments) { - iASTTranslationUnit.getComments.toList.filterNot(isIncludedNode).map(comment => astForComment(comment)) - } else { - Seq.empty - } - } - - private def astForDecltypeSpecifier(decl: ICPPASTDecltypeSpecifier): Ast = { - val op = ".typeOf" - val cpgUnary = callNode(decl, nodeSignature(decl), op, op, 
DispatchTypes.STATIC_DISPATCH) - val operand = nullSafeAst(decl.getDecltypeExpression) - callAst(cpgUnary, List(operand)) - } - - private def astForCASTDesignatedInitializer(d: ICASTDesignatedInitializer): Ast = { - val node = blockNode(d, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val op = Operators.assignment - val calls = withIndex(d.getDesignators) { (des, o) => - val callNode_ = - callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(o) - val left = astForNode(des) - val right = astForNode(d.getOperand) - callAst(callNode_, List(left, right)) - } - scope.popScope() - blockAst(node, calls.toList) - } - - private def astForCPPASTDesignatedInitializer(d: ICPPASTDesignatedInitializer): Ast = { - val node = blockNode(d, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val op = Operators.assignment - val calls = withIndex(d.getDesignators) { (des, o) => - val callNode_ = - callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(o) - val left = astForNode(des) - val right = astForNode(d.getOperand) - callAst(callNode_, List(left, right)) - } - scope.popScope() - blockAst(node, calls.toList) - } - - private def astForCPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): Ast = { - val name = ".constructorInitializer" - val callNode_ = - callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) - val args = c.getArguments.toList.map(a => astForNode(a)) - callAst(callNode_, args) - } - - private def astForCASTArrayRangeDesignator(des: CASTArrayRangeDesignator): Ast = { - val op = Operators.arrayInitializer - val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) - val floorAst = nullSafeAst(des.getRangeFloor) - val ceilingAst = nullSafeAst(des.getRangeCeiling) - callAst(callNode_, List(floorAst, ceilingAst)) - } - - private def astForCPPASTArrayRangeDesignator(des: CPPASTArrayRangeDesignator): Ast = { - val op = 
Operators.arrayInitializer - val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) - val floorAst = nullSafeAst(des.getRangeFloor) - val ceilingAst = nullSafeAst(des.getRangeCeiling) - callAst(callNode_, List(floorAst, ceilingAst)) - } - - protected def astForNode(node: IASTNode): Ast = { - if (config.includeFunctionBodies) astForNodeFull(node) else astForNodePartial(node) - } - - protected def astForNodeFull(node: IASTNode): Ast = { - node match { - case expr: IASTExpression => astForExpression(expr) - case name: IASTName => astForIdentifier(name) - case decl: IASTDeclSpecifier => astForIdentifier(decl) - case l: IASTInitializerList => astForInitializerList(l) - case c: ICPPASTConstructorInitializer => astForCPPASTConstructorInitializer(c) - case d: ICASTDesignatedInitializer => astForCASTDesignatedInitializer(d) - case d: ICPPASTDesignatedInitializer => astForCPPASTDesignatedInitializer(d) - case d: CASTArrayRangeDesignator => astForCASTArrayRangeDesignator(d) - case d: CPPASTArrayRangeDesignator => astForCPPASTArrayRangeDesignator(d) - case d: ICASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) - case d: ICPPASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) - case d: ICPPASTFieldDesignator => astForNode(d.getName) - case d: ICASTFieldDesignator => astForNode(d.getName) - case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) - case arrMod: IASTArrayModifier => astForArrayModifier(arrMod) - case _ => notHandledYet(node) - } - } - - protected def astForNodePartial(node: IASTNode): Ast = { - node match { - case expr: IASTExpression => astForExpression(expr) - case name: IASTName => astForIdentifier(name) - case decl: IASTDeclSpecifier => astForIdentifier(decl) - case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) - case _ => notHandledYet(node) - } - } - - protected def typeForDeclSpecifier(spec: IASTNode, stripKeywords: Boolean = true, index: Int = 0): String = { - val tpe = spec 
match { - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTFunctionDefinition] => - val parentDecl = s.getParent.asInstanceOf[IASTFunctionDefinition].getDeclarator - ASTStringUtil.getReturnTypeString(s, parentDecl) - case s: IASTSimpleDeclaration if s.getParent.isInstanceOf[ICASTKnRFunctionDeclarator] => - val decl = s.getDeclarators.toList(index) - pointersAsString(s.getDeclSpecifier, decl, stripKeywords) - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTSimpleDeclSpecifier => - ASTStringUtil.getReturnTypeString(s, null) - case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTCompositeTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTEnumerationSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - 
pointersAsString(s, parentDecl, stripKeywords) - case s: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTElaboratedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTElaboratedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTElaboratedTypeSpecifier => ASTStringUtil.getSignatureString(s, null) - // TODO: handle other types of IASTDeclSpecifier - case _ => Defines.anyTypeName - } - if (tpe.isEmpty) Defines.anyTypeName else tpe - } - -} + scope.popScope() + blockAst(node, calls.toList) + + private def astForCPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): Ast = + val name = ".constructorInitializer" + val callNode_ = + callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) + val args = c.getArguments.toList.map(a => astForNode(a)) + callAst(callNode_, args) + + private def astForCASTArrayRangeDesignator(des: CASTArrayRangeDesignator): Ast = + val op = Operators.arrayInitializer + val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) + val floorAst = nullSafeAst(des.getRangeFloor) + val ceilingAst = nullSafeAst(des.getRangeCeiling) + callAst(callNode_, List(floorAst, ceilingAst)) + + private def astForCPPASTArrayRangeDesignator(des: CPPASTArrayRangeDesignator): Ast = + val op = Operators.arrayInitializer + val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) + val floorAst = nullSafeAst(des.getRangeFloor) + val ceilingAst = nullSafeAst(des.getRangeCeiling) + callAst(callNode_, List(floorAst, ceilingAst)) + + protected def astForNode(node: IASTNode): Ast = + if config.includeFunctionBodies then 
astForNodeFull(node) else astForNodePartial(node) + + protected def astForNodeFull(node: IASTNode): Ast = + node match + case expr: IASTExpression => astForExpression(expr) + case name: IASTName => astForIdentifier(name) + case decl: IASTDeclSpecifier => astForIdentifier(decl) + case l: IASTInitializerList => astForInitializerList(l) + case c: ICPPASTConstructorInitializer => astForCPPASTConstructorInitializer(c) + case d: ICASTDesignatedInitializer => astForCASTDesignatedInitializer(d) + case d: ICPPASTDesignatedInitializer => astForCPPASTDesignatedInitializer(d) + case d: CASTArrayRangeDesignator => astForCASTArrayRangeDesignator(d) + case d: CPPASTArrayRangeDesignator => astForCPPASTArrayRangeDesignator(d) + case d: ICASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) + case d: ICPPASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) + case d: ICPPASTFieldDesignator => astForNode(d.getName) + case d: ICASTFieldDesignator => astForNode(d.getName) + case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) + case arrMod: IASTArrayModifier => astForArrayModifier(arrMod) + case _ => notHandledYet(node) + + protected def astForNodePartial(node: IASTNode): Ast = + node match + case expr: IASTExpression => astForExpression(expr) + case name: IASTName => astForIdentifier(name) + case decl: IASTDeclSpecifier => astForIdentifier(decl) + case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) + case _ => notHandledYet(node) + + protected def typeForDeclSpecifier( + spec: IASTNode, + stripKeywords: Boolean = true, + index: Int = 0 + ): String = + val tpe = spec match + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTFunctionDefinition] => + val parentDecl = 
s.getParent.asInstanceOf[IASTFunctionDefinition].getDeclarator + ASTStringUtil.getReturnTypeString(s, parentDecl) + case s: IASTSimpleDeclaration if s.getParent.isInstanceOf[ICASTKnRFunctionDeclarator] => + val decl = s.getDeclarators.toList(index) + pointersAsString(s.getDeclSpecifier, decl, stripKeywords) + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTSimpleDeclSpecifier => + ASTStringUtil.getReturnTypeString(s, null) + case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTCompositeTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTEnumerationSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTElaboratedTypeSpecifier + if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, 
parentDecl, stripKeywords) + case s: IASTElaboratedTypeSpecifier + if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTElaboratedTypeSpecifier => ASTStringUtil.getSignatureString(s, null) + // TODO: handle other types of IASTDeclSpecifier + case _ => Defines.anyTypeName + if tpe.isEmpty then Defines.anyTypeName else tpe + end typeForDeclSpecifier +end AstCreatorHelper diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala index d38593fe..3af4e498 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala @@ -8,323 +8,361 @@ import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.IGNUASTCompoundStatementExpression import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName -trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private def astForBinaryExpression(bin: IASTBinaryExpression): Ast = { - val op = bin.getOperator match { - case IASTBinaryExpression.op_multiply => Operators.multiplication - case IASTBinaryExpression.op_divide => Operators.division - case IASTBinaryExpression.op_modulo => Operators.modulo - case IASTBinaryExpression.op_plus => Operators.addition - case IASTBinaryExpression.op_minus => Operators.subtraction - case IASTBinaryExpression.op_shiftLeft => Operators.shiftLeft - case IASTBinaryExpression.op_shiftRight => Operators.arithmeticShiftRight - case IASTBinaryExpression.op_lessThan => Operators.lessThan - case IASTBinaryExpression.op_greaterThan => Operators.greaterThan - 
case IASTBinaryExpression.op_lessEqual => Operators.lessEqualsThan - case IASTBinaryExpression.op_greaterEqual => Operators.greaterEqualsThan - case IASTBinaryExpression.op_binaryAnd => Operators.and - case IASTBinaryExpression.op_binaryXor => Operators.xor - case IASTBinaryExpression.op_binaryOr => Operators.or - case IASTBinaryExpression.op_logicalAnd => Operators.logicalAnd - case IASTBinaryExpression.op_logicalOr => Operators.logicalOr - case IASTBinaryExpression.op_assign => Operators.assignment - case IASTBinaryExpression.op_multiplyAssign => Operators.assignmentMultiplication - case IASTBinaryExpression.op_divideAssign => Operators.assignmentDivision - case IASTBinaryExpression.op_moduloAssign => Operators.assignmentModulo - case IASTBinaryExpression.op_plusAssign => Operators.assignmentPlus - case IASTBinaryExpression.op_minusAssign => Operators.assignmentMinus - case IASTBinaryExpression.op_shiftLeftAssign => Operators.assignmentShiftLeft - case IASTBinaryExpression.op_shiftRightAssign => Operators.assignmentArithmeticShiftRight - case IASTBinaryExpression.op_binaryAndAssign => Operators.assignmentAnd - case IASTBinaryExpression.op_binaryXorAssign => Operators.assignmentXor - case IASTBinaryExpression.op_binaryOrAssign => Operators.assignmentOr - case IASTBinaryExpression.op_equals => Operators.equals - case IASTBinaryExpression.op_notequals => Operators.notEquals - case IASTBinaryExpression.op_pmdot => Operators.indirectFieldAccess - case IASTBinaryExpression.op_pmarrow => Operators.indirectFieldAccess - case IASTBinaryExpression.op_max => ".max" - case IASTBinaryExpression.op_min => ".min" - case IASTBinaryExpression.op_ellipses => ".op_ellipses" - case _ => ".unknown" - } - - val callNode_ = callNode(bin, nodeSignature(bin), op, op, DispatchTypes.STATIC_DISPATCH) - val left = nullSafeAst(bin.getOperand1) - val right = nullSafeAst(bin.getOperand2) - callAst(callNode_, List(left, right)) - } - - private def astForExpressionList(exprList: 
IASTExpressionList): Ast = { - val name = ".expressionList" - val callNode_ = - callNode(exprList, nodeSignature(exprList), name, name, DispatchTypes.STATIC_DISPATCH) - val childAsts = exprList.getExpressions.map(nullSafeAst) - callAst(callNode_, childAsts.toIndexedSeq) - } - - private def astForCallExpression(call: IASTFunctionCallExpression): Ast = { - val rec = call.getFunctionNameExpression match { - case unaryExpression: IASTUnaryExpression if unaryExpression.getOperand.isInstanceOf[IASTBinaryExpression] => - astForBinaryExpression(unaryExpression.getOperand.asInstanceOf[IASTBinaryExpression]) - case unaryExpression: IASTUnaryExpression if unaryExpression.getOperand.isInstanceOf[IASTFieldReference] => - astForFieldReference(unaryExpression.getOperand.asInstanceOf[IASTFieldReference]) - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTArraySubscriptExpression] => - astForArrayIndexExpression(unaryExpression.getOperand.asInstanceOf[IASTArraySubscriptExpression]) - case unaryExpression: IASTUnaryExpression if unaryExpression.getOperand.isInstanceOf[IASTConditionalExpression] => - astForUnaryExpression(unaryExpression) - case unaryExpression: IASTUnaryExpression if unaryExpression.getOperand.isInstanceOf[IASTUnaryExpression] => - astForUnaryExpression(unaryExpression.getOperand.asInstanceOf[IASTUnaryExpression]) - case lambdaExpression: ICPPASTLambdaExpression => - astForMethodRefForLambda(lambdaExpression) - case other => astForExpression(other) - } - - val (dd, name) = call.getFunctionNameExpression match { - case _: ICPPASTLambdaExpression => - (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewMethodRef].methodFullName) - case _ if rec.root.exists(_.isInstanceOf[NewIdentifier]) => - (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewIdentifier].name) - case _ - if rec.root.exists(_.isInstanceOf[NewCall]) && call.getFunctionNameExpression - .isInstanceOf[IASTFieldReference] => - ( - 
DispatchTypes.STATIC_DISPATCH, - nodeSignature(call.getFunctionNameExpression.asInstanceOf[IASTFieldReference].getFieldName) - ) - case _ if rec.root.exists(_.isInstanceOf[NewCall]) => - (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewCall].code) - case reference: IASTIdExpression => - (DispatchTypes.STATIC_DISPATCH, nodeSignature(reference)) - case _ => - (DispatchTypes.STATIC_DISPATCH, "") - } - - val shortName = fixQualifiedName(name) - val fullName = typeFor(call.getFunctionNameExpression) match { - case t if t != Defines.anyTypeName => s"${dereferenceTypeFullName(t)}.$shortName" - case _ => shortName - } - val cpgCall = callNode(call, nodeSignature(call), shortName, fullName, dd) - val args = call.getArguments.toList.map(a => astForNode(a)) - rec.root match { - // Optimization: do not include the receiver if the receiver is just the function name, - // e.g., for `f(x)`, don't include an `f` identifier node as a first child. Since we - // have so many call sites in CPGs, this drastically reduces the number of nodes. - // Moreover, the data flow tracker does not need to track `f`, which would not make - // much sense anyway. 
- case Some(r: NewIdentifier) if r.name == shortName => - callAst(cpgCall, args) - case Some(_) => - callAst(cpgCall, args, Option(rec)) - case None => - callAst(cpgCall, args) - } - } - - private def astForUnaryExpression(unary: IASTUnaryExpression): Ast = { - val operatorMethod = unary.getOperator match { - case IASTUnaryExpression.op_prefixIncr => Operators.preIncrement - case IASTUnaryExpression.op_prefixDecr => Operators.preDecrement - case IASTUnaryExpression.op_plus => Operators.plus - case IASTUnaryExpression.op_minus => Operators.minus - case IASTUnaryExpression.op_star => Operators.indirection - case IASTUnaryExpression.op_amper => Operators.addressOf - case IASTUnaryExpression.op_tilde => Operators.not - case IASTUnaryExpression.op_not => Operators.logicalNot - case IASTUnaryExpression.op_sizeof => Operators.sizeOf - case IASTUnaryExpression.op_postFixIncr => Operators.postIncrement - case IASTUnaryExpression.op_postFixDecr => Operators.postDecrement - case IASTUnaryExpression.op_throw => ".throw" - case IASTUnaryExpression.op_typeid => ".typeOf" - case IASTUnaryExpression.op_bracketedPrimary => ".bracketedPrimary" - case _ => ".unknown" - } - - if ( - unary.getOperator == IASTUnaryExpression.op_bracketedPrimary && - !unary.getOperand.isInstanceOf[IASTExpressionList] - ) { - nullSafeAst(unary.getOperand) - } else { - val cpgUnary = - callNode(unary, nodeSignature(unary), operatorMethod, operatorMethod, DispatchTypes.STATIC_DISPATCH) - val operand = nullSafeAst(unary.getOperand) - callAst(cpgUnary, List(operand)) - } - } - - private def astForTypeIdExpression(typeId: IASTTypeIdExpression): Ast = { - typeId.getOperator match { - case op - if op == IASTTypeIdExpression.op_sizeof || - op == IASTTypeIdExpression.op_sizeofParameterPack || - op == IASTTypeIdExpression.op_typeid || - op == IASTTypeIdExpression.op_alignof || - op == IASTTypeIdExpression.op_typeof => - val call = - callNode(typeId, nodeSignature(typeId), Operators.sizeOf, Operators.sizeOf, 
DispatchTypes.STATIC_DISPATCH) - val arg = astForNode(typeId.getTypeId.getDeclSpecifier) - callAst(call, List(arg)) - case _ => notHandledYet(typeId) - } - } - - private def astForConditionalExpression(expr: IASTConditionalExpression): Ast = { - val name = Operators.conditional - val call = callNode(expr, nodeSignature(expr), name, name, DispatchTypes.STATIC_DISPATCH) - - val condAst = nullSafeAst(expr.getLogicalConditionExpression) - val posAst = nullSafeAst(expr.getPositiveResultExpression) - val negAst = nullSafeAst(expr.getNegativeResultExpression) - - val children = List(condAst, posAst, negAst) - callAst(call, children) - } - - private def astForArrayIndexExpression(arrayIndexExpression: IASTArraySubscriptExpression): Ast = { - val name = Operators.indirectIndexAccess - val cpgArrayIndexing = - callNode(arrayIndexExpression, nodeSignature(arrayIndexExpression), name, name, DispatchTypes.STATIC_DISPATCH) - - val expr = astForExpression(arrayIndexExpression.getArrayExpression) - val arg = astForNode(arrayIndexExpression.getArgument) - callAst(cpgArrayIndexing, List(expr, arg)) - } - - private def astForCastExpression(castExpression: IASTCastExpression): Ast = { - val cpgCastExpression = - callNode( - castExpression, - nodeSignature(castExpression), - Operators.cast, - Operators.cast, - DispatchTypes.STATIC_DISPATCH - ) - - val expr = astForExpression(castExpression.getOperand) - val argNode = castExpression.getTypeId - val arg = unknownNode(argNode, nodeSignature(argNode)) - - callAst(cpgCastExpression, List(Ast(arg), expr)) - } - - private def astsForConstructorInitializer(initializer: IASTInitializer): List[Ast] = { - initializer match { - case init: ICPPASTConstructorInitializer => init.getArguments.toList.map(x => astForNode(x)) - case _ => Nil // null or unexpected type - } - } - - private def astsForInitializerPlacements(initializerPlacements: Array[IASTInitializerClause]): List[Ast] = { - if (initializerPlacements != null) 
initializerPlacements.toList.map(x => astForNode(x)) - else Nil - } - - private def astForNewExpression(newExpression: ICPPASTNewExpression): Ast = { - val name = ".new" - val cpgNewExpression = - callNode(newExpression, nodeSignature(newExpression), name, name, DispatchTypes.STATIC_DISPATCH) - - val typeId = newExpression.getTypeId - if (newExpression.isArrayAllocation) { - val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) - Ast(cpgNewExpression).withChild(cpgTypeId).withArgEdge(cpgNewExpression, cpgTypeId.root.get) - } else { - val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) - val args = astsForConstructorInitializer(newExpression.getInitializer) ++ - astsForInitializerPlacements(newExpression.getPlacementArguments) - callAst(cpgNewExpression, List(cpgTypeId) ++ args) - } - } - - private def astForDeleteExpression(delExpression: ICPPASTDeleteExpression): Ast = { - val name = Operators.delete - val cpgDeleteNode = - callNode(delExpression, nodeSignature(delExpression), name, name, DispatchTypes.STATIC_DISPATCH) - val arg = astForExpression(delExpression.getOperand) - callAst(cpgDeleteNode, List(arg)) - } - - private def astForTypeIdInitExpression(typeIdInit: IASTTypeIdInitializerExpression): Ast = { - val name = Operators.cast - val cpgCastExpression = - callNode(typeIdInit, nodeSignature(typeIdInit), name, name, DispatchTypes.STATIC_DISPATCH) - - val typeAst = unknownNode(typeIdInit.getTypeId, nodeSignature(typeIdInit.getTypeId)) - val expr = astForNode(typeIdInit.getInitializer) - callAst(cpgCastExpression, List(Ast(typeAst), expr)) - } - - private def astForConstructorExpression(c: ICPPASTSimpleTypeConstructorExpression): Ast = { - val name = c.getDeclSpecifier.toString - val callNode_ = callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) - val arg = astForNode(c.getInitializer) - callAst(callNode_, List(arg)) - } - - private def astForCompoundStatementExpression(compoundExpression: IGNUASTCompoundStatementExpression): Ast 
= - nullSafeAst(compoundExpression.getCompoundStatement).headOption.getOrElse(Ast()) - - private def astForPackExpansionExpression(packExpansionExpression: ICPPASTPackExpansionExpression): Ast = - astForExpression(packExpansionExpression.getPattern) - - protected def astForExpression(expression: IASTExpression): Ast = { - if (config.includeFunctionBodies) - astForExpressionFull(expression) - else astForExpressionPartial(expression) - } - - protected def astForExpressionFull(expression: IASTExpression): Ast = { - val r = expression match { - case lit: IASTLiteralExpression => astForLiteral(lit) - case un: IASTUnaryExpression => astForUnaryExpression(un) - case bin: IASTBinaryExpression => astForBinaryExpression(bin) - case exprList: IASTExpressionList => astForExpressionList(exprList) - case idExpr: IASTIdExpression => astForIdExpression(idExpr) - case call: IASTFunctionCallExpression => astForCallExpression(call) - case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) - case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) - case expr: IASTConditionalExpression => astForConditionalExpression(expr) - case arr: IASTArraySubscriptExpression => astForArrayIndexExpression(arr) - case castExpression: IASTCastExpression => astForCastExpression(castExpression) - case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) - case delExpression: ICPPASTDeleteExpression => astForDeleteExpression(delExpression) - case typeIdInit: IASTTypeIdInitializerExpression => astForTypeIdInitExpression(typeIdInit) - case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) - case lambdaExpression: ICPPASTLambdaExpression => astForMethodRefForLambda(lambdaExpression) - case cExpr: IGNUASTCompoundStatementExpression => astForCompoundStatementExpression(cExpr) - case pExpr: ICPPASTPackExpansionExpression => astForPackExpansionExpression(pExpr) - case _ => notHandledYet(expression) - } - asChildOfMacroCall(expression, r) - } 
- - protected def astForExpressionPartial(expression: IASTExpression): Ast = { - val r = expression match { - case call: IASTFunctionCallExpression => astForCallExpression(call) - case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) - case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) - case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) - case typeIdInit: IASTTypeIdInitializerExpression => astForTypeIdInitExpression(typeIdInit) - case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) - case _ => notHandledYet(expression) - } - asChildOfMacroCall(expression, r) - } - - private def astForIdExpression(idExpression: IASTIdExpression): Ast = idExpression.getName match { - case name: CPPASTQualifiedName => astForQualifiedName(name) - case _ => astForIdentifier(idExpression) - } - - protected def astForStaticAssert(a: ICPPASTStaticAssertDeclaration): Ast = { - val name = "static_assert" - val call = callNode(a, nodeSignature(a), name, name, DispatchTypes.STATIC_DISPATCH) - val cond = nullSafeAst(a.getCondition) - val messg = nullSafeAst(a.getMessage) - callAst(call, List(cond, messg)) - } - -} +trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private def astForBinaryExpression(bin: IASTBinaryExpression): Ast = + val op = bin.getOperator match + case IASTBinaryExpression.op_multiply => Operators.multiplication + case IASTBinaryExpression.op_divide => Operators.division + case IASTBinaryExpression.op_modulo => Operators.modulo + case IASTBinaryExpression.op_plus => Operators.addition + case IASTBinaryExpression.op_minus => Operators.subtraction + case IASTBinaryExpression.op_shiftLeft => Operators.shiftLeft + case IASTBinaryExpression.op_shiftRight => Operators.arithmeticShiftRight + case IASTBinaryExpression.op_lessThan => Operators.lessThan + case IASTBinaryExpression.op_greaterThan => Operators.greaterThan + case 
IASTBinaryExpression.op_lessEqual => Operators.lessEqualsThan + case IASTBinaryExpression.op_greaterEqual => Operators.greaterEqualsThan + case IASTBinaryExpression.op_binaryAnd => Operators.and + case IASTBinaryExpression.op_binaryXor => Operators.xor + case IASTBinaryExpression.op_binaryOr => Operators.or + case IASTBinaryExpression.op_logicalAnd => Operators.logicalAnd + case IASTBinaryExpression.op_logicalOr => Operators.logicalOr + case IASTBinaryExpression.op_assign => Operators.assignment + case IASTBinaryExpression.op_multiplyAssign => Operators.assignmentMultiplication + case IASTBinaryExpression.op_divideAssign => Operators.assignmentDivision + case IASTBinaryExpression.op_moduloAssign => Operators.assignmentModulo + case IASTBinaryExpression.op_plusAssign => Operators.assignmentPlus + case IASTBinaryExpression.op_minusAssign => Operators.assignmentMinus + case IASTBinaryExpression.op_shiftLeftAssign => Operators.assignmentShiftLeft + case IASTBinaryExpression.op_shiftRightAssign => + Operators.assignmentArithmeticShiftRight + case IASTBinaryExpression.op_binaryAndAssign => Operators.assignmentAnd + case IASTBinaryExpression.op_binaryXorAssign => Operators.assignmentXor + case IASTBinaryExpression.op_binaryOrAssign => Operators.assignmentOr + case IASTBinaryExpression.op_equals => Operators.equals + case IASTBinaryExpression.op_notequals => Operators.notEquals + case IASTBinaryExpression.op_pmdot => Operators.indirectFieldAccess + case IASTBinaryExpression.op_pmarrow => Operators.indirectFieldAccess + case IASTBinaryExpression.op_max => ".max" + case IASTBinaryExpression.op_min => ".min" + case IASTBinaryExpression.op_ellipses => ".op_ellipses" + case _ => ".unknown" + + val callNode_ = callNode(bin, nodeSignature(bin), op, op, DispatchTypes.STATIC_DISPATCH) + val left = nullSafeAst(bin.getOperand1) + val right = nullSafeAst(bin.getOperand2) + callAst(callNode_, List(left, right)) + end astForBinaryExpression + + private def astForExpressionList(exprList: 
IASTExpressionList): Ast = + val name = ".expressionList" + val callNode_ = + callNode(exprList, nodeSignature(exprList), name, name, DispatchTypes.STATIC_DISPATCH) + val childAsts = exprList.getExpressions.map(nullSafeAst) + callAst(callNode_, childAsts.toIndexedSeq) + + private def astForCallExpression(call: IASTFunctionCallExpression): Ast = + val rec = call.getFunctionNameExpression match + case unaryExpression: IASTUnaryExpression + if unaryExpression.getOperand.isInstanceOf[IASTBinaryExpression] => + astForBinaryExpression( + unaryExpression.getOperand.asInstanceOf[IASTBinaryExpression] + ) + case unaryExpression: IASTUnaryExpression + if unaryExpression.getOperand.isInstanceOf[IASTFieldReference] => + astForFieldReference(unaryExpression.getOperand.asInstanceOf[IASTFieldReference]) + case unaryExpression: IASTUnaryExpression + if unaryExpression.getOperand.isInstanceOf[IASTArraySubscriptExpression] => + astForArrayIndexExpression( + unaryExpression.getOperand.asInstanceOf[IASTArraySubscriptExpression] + ) + case unaryExpression: IASTUnaryExpression + if unaryExpression.getOperand.isInstanceOf[IASTConditionalExpression] => + astForUnaryExpression(unaryExpression) + case unaryExpression: IASTUnaryExpression + if unaryExpression.getOperand.isInstanceOf[IASTUnaryExpression] => + astForUnaryExpression(unaryExpression.getOperand.asInstanceOf[IASTUnaryExpression]) + case lambdaExpression: ICPPASTLambdaExpression => + astForMethodRefForLambda(lambdaExpression) + case other => astForExpression(other) + + val (dd, name) = call.getFunctionNameExpression match + case _: ICPPASTLambdaExpression => + ( + DispatchTypes.STATIC_DISPATCH, + rec.root.get.asInstanceOf[NewMethodRef].methodFullName + ) + case _ if rec.root.exists(_.isInstanceOf[NewIdentifier]) => + (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewIdentifier].name) + case _ + if rec.root.exists(_.isInstanceOf[NewCall]) && call.getFunctionNameExpression + .isInstanceOf[IASTFieldReference] => + ( + 
DispatchTypes.STATIC_DISPATCH, + nodeSignature( + call.getFunctionNameExpression.asInstanceOf[IASTFieldReference].getFieldName + ) + ) + case _ if rec.root.exists(_.isInstanceOf[NewCall]) => + (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewCall].code) + case reference: IASTIdExpression => + (DispatchTypes.STATIC_DISPATCH, nodeSignature(reference)) + case _ => + (DispatchTypes.STATIC_DISPATCH, "") + + val shortName = fixQualifiedName(name) + val fullName = typeFor(call.getFunctionNameExpression) match + case t if t != Defines.anyTypeName => s"${dereferenceTypeFullName(t)}.$shortName" + case _ => shortName + val cpgCall = callNode(call, nodeSignature(call), shortName, fullName, dd) + val args = call.getArguments.toList.map(a => astForNode(a)) + rec.root match + // Optimization: do not include the receiver if the receiver is just the function name, + // e.g., for `f(x)`, don't include an `f` identifier node as a first child. Since we + // have so many call sites in CPGs, this drastically reduces the number of nodes. + // Moreover, the data flow tracker does not need to track `f`, which would not make + // much sense anyway. 
+ case Some(r: NewIdentifier) if r.name == shortName => + callAst(cpgCall, args) + case Some(_) => + callAst(cpgCall, args, Option(rec)) + case None => + callAst(cpgCall, args) + end astForCallExpression + + private def astForUnaryExpression(unary: IASTUnaryExpression): Ast = + val operatorMethod = unary.getOperator match + case IASTUnaryExpression.op_prefixIncr => Operators.preIncrement + case IASTUnaryExpression.op_prefixDecr => Operators.preDecrement + case IASTUnaryExpression.op_plus => Operators.plus + case IASTUnaryExpression.op_minus => Operators.minus + case IASTUnaryExpression.op_star => Operators.indirection + case IASTUnaryExpression.op_amper => Operators.addressOf + case IASTUnaryExpression.op_tilde => Operators.not + case IASTUnaryExpression.op_not => Operators.logicalNot + case IASTUnaryExpression.op_sizeof => Operators.sizeOf + case IASTUnaryExpression.op_postFixIncr => Operators.postIncrement + case IASTUnaryExpression.op_postFixDecr => Operators.postDecrement + case IASTUnaryExpression.op_throw => ".throw" + case IASTUnaryExpression.op_typeid => ".typeOf" + case IASTUnaryExpression.op_bracketedPrimary => ".bracketedPrimary" + case _ => ".unknown" + + if + unary.getOperator == IASTUnaryExpression.op_bracketedPrimary && + !unary.getOperand.isInstanceOf[IASTExpressionList] + then + nullSafeAst(unary.getOperand) + else + val cpgUnary = + callNode( + unary, + nodeSignature(unary), + operatorMethod, + operatorMethod, + DispatchTypes.STATIC_DISPATCH + ) + val operand = nullSafeAst(unary.getOperand) + callAst(cpgUnary, List(operand)) + end astForUnaryExpression + + private def astForTypeIdExpression(typeId: IASTTypeIdExpression): Ast = + typeId.getOperator match + case op + if op == IASTTypeIdExpression.op_sizeof || + op == IASTTypeIdExpression.op_sizeofParameterPack || + op == IASTTypeIdExpression.op_typeid || + op == IASTTypeIdExpression.op_alignof || + op == IASTTypeIdExpression.op_typeof => + val call = + callNode( + typeId, + nodeSignature(typeId), + 
Operators.sizeOf, + Operators.sizeOf, + DispatchTypes.STATIC_DISPATCH + ) + val arg = astForNode(typeId.getTypeId.getDeclSpecifier) + callAst(call, List(arg)) + case _ => notHandledYet(typeId) + + private def astForConditionalExpression(expr: IASTConditionalExpression): Ast = + val name = Operators.conditional + val call = callNode(expr, nodeSignature(expr), name, name, DispatchTypes.STATIC_DISPATCH) + + val condAst = nullSafeAst(expr.getLogicalConditionExpression) + val posAst = nullSafeAst(expr.getPositiveResultExpression) + val negAst = nullSafeAst(expr.getNegativeResultExpression) + + val children = List(condAst, posAst, negAst) + callAst(call, children) + + private def astForArrayIndexExpression(arrayIndexExpression: IASTArraySubscriptExpression) + : Ast = + val name = Operators.indirectIndexAccess + val cpgArrayIndexing = + callNode( + arrayIndexExpression, + nodeSignature(arrayIndexExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + + val expr = astForExpression(arrayIndexExpression.getArrayExpression) + val arg = astForNode(arrayIndexExpression.getArgument) + callAst(cpgArrayIndexing, List(expr, arg)) + + private def astForCastExpression(castExpression: IASTCastExpression): Ast = + val cpgCastExpression = + callNode( + castExpression, + nodeSignature(castExpression), + Operators.cast, + Operators.cast, + DispatchTypes.STATIC_DISPATCH + ) + + val expr = astForExpression(castExpression.getOperand) + val argNode = castExpression.getTypeId + val arg = unknownNode(argNode, nodeSignature(argNode)) + + callAst(cpgCastExpression, List(Ast(arg), expr)) + + private def astsForConstructorInitializer(initializer: IASTInitializer): List[Ast] = + initializer match + case init: ICPPASTConstructorInitializer => + init.getArguments.toList.map(x => astForNode(x)) + case _ => Nil // null or unexpected type + + private def astsForInitializerPlacements(initializerPlacements: Array[IASTInitializerClause]) + : List[Ast] = + if initializerPlacements != null then 
initializerPlacements.toList.map(x => astForNode(x)) + else Nil + + private def astForNewExpression(newExpression: ICPPASTNewExpression): Ast = + val name = ".new" + val cpgNewExpression = + callNode( + newExpression, + nodeSignature(newExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + + val typeId = newExpression.getTypeId + if newExpression.isArrayAllocation then + val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) + Ast(cpgNewExpression).withChild(cpgTypeId).withArgEdge( + cpgNewExpression, + cpgTypeId.root.get + ) + else + val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) + val args = astsForConstructorInitializer(newExpression.getInitializer) ++ + astsForInitializerPlacements(newExpression.getPlacementArguments) + callAst(cpgNewExpression, List(cpgTypeId) ++ args) + end astForNewExpression + + private def astForDeleteExpression(delExpression: ICPPASTDeleteExpression): Ast = + val name = Operators.delete + val cpgDeleteNode = + callNode( + delExpression, + nodeSignature(delExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + val arg = astForExpression(delExpression.getOperand) + callAst(cpgDeleteNode, List(arg)) + + private def astForTypeIdInitExpression(typeIdInit: IASTTypeIdInitializerExpression): Ast = + val name = Operators.cast + val cpgCastExpression = + callNode( + typeIdInit, + nodeSignature(typeIdInit), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + + val typeAst = unknownNode(typeIdInit.getTypeId, nodeSignature(typeIdInit.getTypeId)) + val expr = astForNode(typeIdInit.getInitializer) + callAst(cpgCastExpression, List(Ast(typeAst), expr)) + + private def astForConstructorExpression(c: ICPPASTSimpleTypeConstructorExpression): Ast = + val name = c.getDeclSpecifier.toString + val callNode_ = callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) + val arg = astForNode(c.getInitializer) + callAst(callNode_, List(arg)) + + private def astForCompoundStatementExpression( + 
compoundExpression: IGNUASTCompoundStatementExpression + ): Ast = + nullSafeAst(compoundExpression.getCompoundStatement).headOption.getOrElse(Ast()) + + private def astForPackExpansionExpression( + packExpansionExpression: ICPPASTPackExpansionExpression + ): Ast = + astForExpression(packExpansionExpression.getPattern) + + protected def astForExpression(expression: IASTExpression): Ast = + if config.includeFunctionBodies then + astForExpressionFull(expression) + else astForExpressionPartial(expression) + + protected def astForExpressionFull(expression: IASTExpression): Ast = + val r = expression match + case lit: IASTLiteralExpression => astForLiteral(lit) + case un: IASTUnaryExpression => astForUnaryExpression(un) + case bin: IASTBinaryExpression => astForBinaryExpression(bin) + case exprList: IASTExpressionList => astForExpressionList(exprList) + case idExpr: IASTIdExpression => astForIdExpression(idExpr) + case call: IASTFunctionCallExpression => astForCallExpression(call) + case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) + case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) + case expr: IASTConditionalExpression => astForConditionalExpression(expr) + case arr: IASTArraySubscriptExpression => astForArrayIndexExpression(arr) + case castExpression: IASTCastExpression => astForCastExpression(castExpression) + case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) + case delExpression: ICPPASTDeleteExpression => astForDeleteExpression(delExpression) + case typeIdInit: IASTTypeIdInitializerExpression => + astForTypeIdInitExpression(typeIdInit) + case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) + case lambdaExpression: ICPPASTLambdaExpression => + astForMethodRefForLambda(lambdaExpression) + case cExpr: IGNUASTCompoundStatementExpression => + astForCompoundStatementExpression(cExpr) + case pExpr: ICPPASTPackExpansionExpression => astForPackExpansionExpression(pExpr) + case _ => 
notHandledYet(expression) + asChildOfMacroCall(expression, r) + end astForExpressionFull + + protected def astForExpressionPartial(expression: IASTExpression): Ast = + val r = expression match + case call: IASTFunctionCallExpression => astForCallExpression(call) + case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) + case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) + case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) + case typeIdInit: IASTTypeIdInitializerExpression => + astForTypeIdInitExpression(typeIdInit) + case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) + case _ => notHandledYet(expression) + asChildOfMacroCall(expression, r) + + private def astForIdExpression(idExpression: IASTIdExpression): Ast = idExpression.getName match + case name: CPPASTQualifiedName => astForQualifiedName(name) + case _ => astForIdentifier(idExpression) + + protected def astForStaticAssert(a: ICPPASTStaticAssertDeclaration): Ast = + val name = "static_assert" + val call = callNode(a, nodeSignature(a), name, name, DispatchTypes.STATIC_DISPATCH) + val cond = nullSafeAst(a.getCondition) + val messg = nullSafeAst(a.getMessage) + callAst(call, List(cond, messg)) +end AstForExpressionsCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala index 0e93d46d..5a7b4918 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala @@ -7,7 +7,10 @@ import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTLambdaExpression import org.eclipse.cdt.core.dom.ast.gnu.c.ICASTKnRFunctionDeclarator import 
org.eclipse.cdt.internal.core.dom.parser.c.{CASTFunctionDeclarator, CASTParameterDeclaration} -import org.eclipse.cdt.internal.core.dom.parser.cpp.{CPPASTFunctionDeclarator, CPPASTParameterDeclaration} +import org.eclipse.cdt.internal.core.dom.parser.cpp.{ + CPPASTFunctionDeclarator, + CPPASTParameterDeclaration +} import org.eclipse.cdt.internal.core.model.ASTStringUtil import io.appthreat.x2cpg.datastructures.Stack.* import org.apache.commons.lang.StringUtils @@ -15,222 +18,246 @@ import org.apache.commons.lang.StringUtils import scala.annotation.tailrec import scala.collection.mutable -trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private val seenFunctionSignatures = mutable.HashSet.empty[String] - - private def createFunctionTypeAndTypeDecl( - node: IASTNode, - method: NewMethod, - methodName: String, - methodFullName: String, - signature: String - ): Ast = { - val normalizedName = StringUtils.normalizeSpace(methodName) - val normalizedFullName = StringUtils.normalizeSpace(methodFullName) - - val parentNode: NewTypeDecl = methodAstParentStack.collectFirst { case t: NewTypeDecl => t }.getOrElse { - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - val typeDeclNode_ = typeDeclNode( - node, - normalizedName, - normalizedFullName, - method.filename, - normalizedName, - astParentType, - astParentFullName - ) - Ast.storeInDiffGraph(Ast(typeDeclNode_), diffGraph) - typeDeclNode_ - } - - method.astParentFullName = parentNode.fullName - method.astParentType = parentNode.label - val functionBinding = NewBinding().name(normalizedName).methodFullName(normalizedFullName).signature(signature) - Ast(functionBinding).withBindsEdge(parentNode, functionBinding).withRefEdge(functionBinding, method) - } - - private def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match { - case arr: IASTArrayDeclarator => 
parameters(arr.getNestedDeclarator) - case decl: CPPASTFunctionDeclarator => decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) - case decl: CASTFunctionDeclarator => decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) - case defn: IASTFunctionDefinition => parameters(defn.getDeclarator) - case lambdaExpression: ICPPASTLambdaExpression => parameters(lambdaExpression.getDeclarator) - case knr: ICASTKnRFunctionDeclarator => knr.getParameterDeclarations.toIndexedSeq - case _: IASTDeclarator => Seq.empty - case other if other != null => notHandledYet(other); Seq.empty - case null => Seq.empty - } - - @tailrec - private def isVariadic(functionNode: IASTNode): Boolean = functionNode match { - case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() - case decl: CASTFunctionDeclarator => decl.takesVarArgs() - case defn: IASTFunctionDefinition => isVariadic(defn.getDeclarator) - case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) - case _ => false - } - - private def parameterListSignature(func: IASTNode): String = { - val variadic = if (isVariadic(func)) "..." else "" - val elements = parameters(func).map { - case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) - case other => typeForDeclSpecifier(other) - } - s"(${elements.mkString(",")}$variadic)" - } - - private def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = { - parameterNodes.lastOption.foreach { - case p: NewMethodParameterIn if isVariadic(func) => - p.isVariadic = true - p.code = s"${p.code}..." 
- case _ => - } - } - - protected def astForMethodRefForLambda(lambdaExpression: ICPPASTLambdaExpression): Ast = { - val filename = fileName(lambdaExpression) - - val returnType = lambdaExpression.getDeclarator match { - case declarator: IASTDeclarator => - declarator.getTrailingReturnType match { - case id: IASTTypeId => typeForDeclSpecifier(id.getDeclSpecifier) - case null => Defines.anyTypeName +trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private val seenFunctionSignatures = mutable.HashSet.empty[String] + + private def createFunctionTypeAndTypeDecl( + node: IASTNode, + method: NewMethod, + methodName: String, + methodFullName: String, + signature: String + ): Ast = + val normalizedName = StringUtils.normalizeSpace(methodName) + val normalizedFullName = StringUtils.normalizeSpace(methodFullName) + + val parentNode: NewTypeDecl = methodAstParentStack.collectFirst { case t: NewTypeDecl => + t + }.getOrElse { + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val typeDeclNode_ = typeDeclNode( + node, + normalizedName, + normalizedFullName, + method.filename, + normalizedName, + astParentType, + astParentFullName + ) + Ast.storeInDiffGraph(Ast(typeDeclNode_), diffGraph) + typeDeclNode_ + } + + method.astParentFullName = parentNode.fullName + method.astParentType = parentNode.label + val functionBinding = NewBinding().name(normalizedName).methodFullName( + normalizedFullName + ).signature(signature) + Ast(functionBinding).withBindsEdge(parentNode, functionBinding).withRefEdge( + functionBinding, + method + ) + end createFunctionTypeAndTypeDecl + + private def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match + case arr: IASTArrayDeclarator => parameters(arr.getNestedDeclarator) + case decl: CPPASTFunctionDeclarator => + decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) + case 
decl: CASTFunctionDeclarator => + decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) + case defn: IASTFunctionDefinition => parameters(defn.getDeclarator) + case lambdaExpression: ICPPASTLambdaExpression => parameters(lambdaExpression.getDeclarator) + case knr: ICASTKnRFunctionDeclarator => knr.getParameterDeclarations.toIndexedSeq + case _: IASTDeclarator => Seq.empty + case other if other != null => notHandledYet(other); Seq.empty + case null => Seq.empty + + @tailrec + private def isVariadic(functionNode: IASTNode): Boolean = functionNode match + case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() + case decl: CASTFunctionDeclarator => decl.takesVarArgs() + case defn: IASTFunctionDefinition => isVariadic(defn.getDeclarator) + case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) + case _ => false + + private def parameterListSignature(func: IASTNode): String = + val variadic = if isVariadic(func) then "..." else "" + val elements = parameters(func).map { + case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) + case other => typeForDeclSpecifier(other) + } + s"(${elements.mkString(",")}$variadic)" + + private def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = + parameterNodes.lastOption.foreach { + case p: NewMethodParameterIn if isVariadic(func) => + p.isVariadic = true + p.code = s"${p.code}..." 
+ case _ => + } + + protected def astForMethodRefForLambda(lambdaExpression: ICPPASTLambdaExpression): Ast = + val filename = fileName(lambdaExpression) + + val returnType = lambdaExpression.getDeclarator match + case declarator: IASTDeclarator => + declarator.getTrailingReturnType match + case id: IASTTypeId => typeForDeclSpecifier(id.getDeclSpecifier) + case null => Defines.anyTypeName + case null => Defines.anyTypeName + val (name, fullname) = uniqueName("lambda", "", fullName(lambdaExpression)) + val signature = s"$returnType $fullname ${parameterListSignature(lambdaExpression)}" + val code = nodeSignature(lambdaExpression) + val methodNode_ = + methodNode(lambdaExpression, name, code, fullname, Some(signature), filename) + + scope.pushNewScope(methodNode_) + val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => + parameterNode(p, i) } - case null => Defines.anyTypeName - } - val (name, fullname) = uniqueName("lambda", "", fullName(lambdaExpression)) - val signature = s"$returnType $fullname ${parameterListSignature(lambdaExpression)}" - val code = nodeSignature(lambdaExpression) - val methodNode_ = methodNode(lambdaExpression, name, code, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, lambdaExpression) - - scope.popScope() - - val astForLambda = methodAst( - methodNode_, - parameterNodes.map(Ast(_)), - astForMethodBody(Option(lambdaExpression.getBody)), - newMethodReturnNode(lambdaExpression, registerType(returnType)) - ) - val typeDeclAst = createFunctionTypeAndTypeDecl(lambdaExpression, methodNode_, name, fullname, signature) - Ast.storeInDiffGraph(astForLambda.merge(typeDeclAst), diffGraph) - - Ast(methodRefNode(lambdaExpression, code, fullname, methodNode_.astParentFullName)) - } - - protected def astForFunctionDeclarator(funcDecl: IASTFunctionDeclarator): 
Ast = { - val returnType = typeForDeclSpecifier(funcDecl.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier) - val fullname = fullName(funcDecl) - val templateParams = templateParameters(funcDecl).getOrElse("") - val signature = - s"$returnType $fullname$templateParams ${parameterListSignature(funcDecl)}" - - if (seenFunctionSignatures.add(signature)) { - val name = shortName(funcDecl) - val code = nodeSignature(funcDecl.getParent) - val filename = fileName(funcDecl) - val methodNode_ = methodNode(funcDecl, name, code, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - - val parameterNodes = withIndex(parameters(funcDecl)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, funcDecl) - - scope.popScope() - - val stubAst = methodStubAst(methodNode_, parameterNodes, newMethodReturnNode(funcDecl, registerType(returnType))) - val typeDeclAst = createFunctionTypeAndTypeDecl(funcDecl, methodNode_, name, fullname, signature) - stubAst.merge(typeDeclAst) - } else { - Ast() - } - } - - protected def astForFunctionDefinition(funcDef: IASTFunctionDefinition): Ast = { - val filename = fileName(funcDef) - val returnType = typeForDeclSpecifier(funcDef.getDeclSpecifier) - val name = shortName(funcDef) - val fullname = fullName(funcDef) - val templateParams = templateParameters(funcDef).getOrElse("") - - val signature = - s"$returnType $fullname$templateParams ${parameterListSignature(funcDef)}" - seenFunctionSignatures.add(signature) - - val code = nodeSignature(funcDef) - val methodNode_ = methodNode(funcDef, name, code, fullname, Some(signature), filename) - - methodAstParentStack.push(methodNode_) - scope.pushNewScope(methodNode_) - - val parameterNodes = withIndex(parameters(funcDef)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, funcDef) - - val astForMethod = methodAst( - methodNode_, - parameterNodes.map(Ast(_)), - astForMethodBody(Option(funcDef.getBody)), - newMethodReturnNode(funcDef, 
registerType(returnType)) - ) - - scope.popScope() - methodAstParentStack.pop() - - val typeDeclAst = createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullname, signature) - astForMethod.merge(typeDeclAst) - } - - private def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = { - val (name, code, tpe, variadic) = parameter match { - case p: CASTParameterDeclaration => - ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), - nodeSignature(p), - cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), - false + setVariadic(parameterNodes, lambdaExpression) + + scope.popScope() + + val astForLambda = methodAst( + methodNode_, + parameterNodes.map(Ast(_)), + astForMethodBody(Option(lambdaExpression.getBody)), + newMethodReturnNode(lambdaExpression, registerType(returnType)) ) - case p: CPPASTParameterDeclaration => - ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), - nodeSignature(p), - cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), - p.getDeclarator.declaresParameterPack() + val typeDeclAst = + createFunctionTypeAndTypeDecl(lambdaExpression, methodNode_, name, fullname, signature) + Ast.storeInDiffGraph(astForLambda.merge(typeDeclAst), diffGraph) + + Ast(methodRefNode(lambdaExpression, code, fullname, methodNode_.astParentFullName)) + end astForMethodRefForLambda + + protected def astForFunctionDeclarator(funcDecl: IASTFunctionDeclarator): Ast = + val returnType = typeForDeclSpecifier( + funcDecl.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier ) - case s: IASTSimpleDeclaration => - ( - s.getDeclarators.headOption - .map(n => ASTStringUtil.getSimpleName(n.getName)) - .getOrElse(uniqueName("parameter", "", "")._1), - nodeSignature(s), - cleanType(typeForDeclSpecifier(s)), - false + val fullname = fullName(funcDecl) + val templateParams = templateParameters(funcDecl).getOrElse("") + val signature = + s"$returnType $fullname$templateParams ${parameterListSignature(funcDecl)}" + + if 
seenFunctionSignatures.add(signature) then + val name = shortName(funcDecl) + val code = nodeSignature(funcDecl.getParent) + val filename = fileName(funcDecl) + val methodNode_ = methodNode(funcDecl, name, code, fullname, Some(signature), filename) + + scope.pushNewScope(methodNode_) + + val parameterNodes = withIndex(parameters(funcDecl)) { (p, i) => + parameterNode(p, i) + } + setVariadic(parameterNodes, funcDecl) + + scope.popScope() + + val stubAst = methodStubAst( + methodNode_, + parameterNodes, + newMethodReturnNode(funcDecl, registerType(returnType)) + ) + val typeDeclAst = + createFunctionTypeAndTypeDecl(funcDecl, methodNode_, name, fullname, signature) + stubAst.merge(typeDeclAst) + else + Ast() + end if + end astForFunctionDeclarator + + protected def astForFunctionDefinition(funcDef: IASTFunctionDefinition): Ast = + val filename = fileName(funcDef) + val returnType = typeForDeclSpecifier(funcDef.getDeclSpecifier) + val name = shortName(funcDef) + val fullname = fullName(funcDef) + val templateParams = templateParameters(funcDef).getOrElse("") + + val signature = + s"$returnType $fullname$templateParams ${parameterListSignature(funcDef)}" + seenFunctionSignatures.add(signature) + + val code = nodeSignature(funcDef) + val methodNode_ = methodNode(funcDef, name, code, fullname, Some(signature), filename) + + methodAstParentStack.push(methodNode_) + scope.pushNewScope(methodNode_) + + val parameterNodes = withIndex(parameters(funcDef)) { (p, i) => + parameterNode(p, i) + } + setVariadic(parameterNodes, funcDef) + + val astForMethod = methodAst( + methodNode_, + parameterNodes.map(Ast(_)), + astForMethodBody(Option(funcDef.getBody)), + newMethodReturnNode(funcDef, registerType(returnType)) ) - case other => - (nodeSignature(other), nodeSignature(other), cleanType(typeForDeclSpecifier(other)), false) - } - - val parameterNode = - parameterInNode(parameter, name, code, paramIndex, variadic, EvaluationStrategies.BY_VALUE, registerType(tpe)) - 
scope.addToScope(name, (parameterNode, tpe)) - parameterNode - } - - private def astForMethodBody(body: Option[IASTStatement]): Ast = body match { - case Some(b: IASTCompoundStatement) => astForBlockStatement(b) - case Some(b) => astForNode(b) - case None => blockAst(NewBlock()) - } -} + scope.popScope() + methodAstParentStack.pop() + + val typeDeclAst = + createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullname, signature) + astForMethod.merge(typeDeclAst) + end astForFunctionDefinition + + private def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = + val (name, code, tpe, variadic) = parameter match + case p: CASTParameterDeclaration => + ( + ASTStringUtil.getSimpleName(p.getDeclarator.getName), + nodeSignature(p), + cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), + false + ) + case p: CPPASTParameterDeclaration => + ( + ASTStringUtil.getSimpleName(p.getDeclarator.getName), + nodeSignature(p), + cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), + p.getDeclarator.declaresParameterPack() + ) + case s: IASTSimpleDeclaration => + ( + s.getDeclarators.headOption + .map(n => ASTStringUtil.getSimpleName(n.getName)) + .getOrElse(uniqueName("parameter", "", "")._1), + nodeSignature(s), + cleanType(typeForDeclSpecifier(s)), + false + ) + case other => + ( + nodeSignature(other), + nodeSignature(other), + cleanType(typeForDeclSpecifier(other)), + false + ) + + val parameterNode = + parameterInNode( + parameter, + name, + code, + paramIndex, + variadic, + EvaluationStrategies.BY_VALUE, + registerType(tpe) + ) + scope.addToScope(name, (parameterNode, tpe)) + parameterNode + end parameterNode + + private def astForMethodBody(body: Option[IASTStatement]): Ast = body match + case Some(b: IASTCompoundStatement) => astForBlockStatement(b) + case Some(b) => astForNode(b) + case None => blockAst(NewBlock()) +end AstForFunctionsCreator diff --git 
a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala index e0aa91cb..f6f79f4d 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala @@ -6,112 +6,118 @@ import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName import org.eclipse.cdt.internal.core.model.ASTStringUtil -trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - protected def astForComment(comment: IASTComment): Ast = - Ast(newCommentNode(comment, nodeSignature(comment), fileName(comment))) - - protected def astForLiteral(lit: IASTLiteralExpression): Ast = { - val tpe = cleanType(ASTTypeUtil.getType(lit.getExpressionType)) - Ast(literalNode(lit, nodeSignature(lit), registerType(tpe))) - } - - protected def astForIdentifier(ident: IASTNode): Ast = { - val identifierName = ident match { - case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) - case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty && id.getBinding != null => id.getBinding.getName - case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty => uniqueName("name", "", "")._1 - case _ => nodeSignature(ident) - } - val variableOption = scope.lookupVariable(identifierName) - val identifierTypeName = variableOption match { - case Some((_, variableTypeName)) => variableTypeName - case None if ident.isInstanceOf[IASTName] && ident.asInstanceOf[IASTName].getBinding != null => - val id = ident.asInstanceOf[IASTName] - id.getBinding match { - case v: IVariable => - v.getType match { - case f: IFunctionType => f.getReturnType.toString - case other => other.toString - } - case other => other.getName - } - case None if 
ident.isInstanceOf[IASTName] => - typeFor(ident.getParent) - case None => typeFor(ident) - } - - val node = identifierNode(ident, identifierName, nodeSignature(ident), registerType(cleanType(identifierTypeName))) - variableOption match { - case Some((variable, _)) => - Ast(node).withRefEdge(node, variable) - case None => Ast(node) - } - } - - protected def astForFieldReference(fieldRef: IASTFieldReference): Ast = { - val op = if (fieldRef.isPointerDereference) Operators.indirectFieldAccess else Operators.fieldAccess - val ma = callNode(fieldRef, nodeSignature(fieldRef), op, op, DispatchTypes.STATIC_DISPATCH) - val owner = astForExpression(fieldRef.getFieldOwner) - val member = fieldIdentifierNode(fieldRef, fieldRef.getFieldName.toString, fieldRef.getFieldName.toString) - callAst(ma, List(owner, Ast(member))) - } - - protected def astForArrayModifier(arrMod: IASTArrayModifier): Ast = - astForNode(arrMod.getConstantExpression) - - protected def astForInitializerList(l: IASTInitializerList): Ast = { - val op = Operators.arrayInitializer - val initCallNode = callNode(l, nodeSignature(l), op, op, DispatchTypes.STATIC_DISPATCH) - - val MAX_INITIALIZERS = 1000 - val clauses = l.getClauses.slice(0, MAX_INITIALIZERS) - - val args = clauses.toList.map(x => astForNode(x)) - - val ast = callAst(initCallNode, args) - if (l.getClauses.length > MAX_INITIALIZERS) { - val placeholder = - literalNode(l, "", Defines.anyTypeName).argumentIndex(MAX_INITIALIZERS) - ast.withChild(Ast(placeholder)).withArgEdge(initCallNode, placeholder) - } else { - ast - } - } - - protected def astForQualifiedName(qualId: CPPASTQualifiedName): Ast = { - val op = Operators.fieldAccess - val ma = callNode(qualId, nodeSignature(qualId), op, op, DispatchTypes.STATIC_DISPATCH) - - def fieldAccesses(names: List[IASTNode], argIndex: Int = -1): Ast = names match { - case Nil => Ast() - case head :: Nil => - astForNode(head) - case head :: tail => - val code = 
s"${nodeSignature(head)}::${tail.map(nodeSignature).mkString("::")}" - val callNode_ = - callNode(head, nodeSignature(head), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(argIndex) - callNode_.code = code - val arg1 = astForNode(head) - val arg2 = fieldAccesses(tail) - callAst(callNode_, List(arg1, arg2)) - } - - val qualifier = fieldAccesses(qualId.getQualifier.toIndexedSeq.toList) - - val owner = if (qualifier != Ast()) { - qualifier - } else { - Ast(literalNode(qualId.getLastName, "", Defines.anyTypeName)) - } - - val member = fieldIdentifierNode( - qualId.getLastName, - fixQualifiedName(qualId.getLastName.toString), - qualId.getLastName.toString - ) - callAst(ma, List(owner, Ast(member))) - } - -} +trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForComment(comment: IASTComment): Ast = + Ast(newCommentNode(comment, nodeSignature(comment), fileName(comment))) + + protected def astForLiteral(lit: IASTLiteralExpression): Ast = + val tpe = cleanType(ASTTypeUtil.getType(lit.getExpressionType)) + Ast(literalNode(lit, nodeSignature(lit), registerType(tpe))) + + protected def astForIdentifier(ident: IASTNode): Ast = + val identifierName = ident match + case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) + case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty && id.getBinding != null => + id.getBinding.getName + case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty => + uniqueName("name", "", "")._1 + case _ => nodeSignature(ident) + val variableOption = scope.lookupVariable(identifierName) + val identifierTypeName = variableOption match + case Some((_, variableTypeName)) => variableTypeName + case None + if ident.isInstanceOf[IASTName] && ident.asInstanceOf[ + IASTName + ].getBinding != null => + val id = ident.asInstanceOf[IASTName] + id.getBinding match + case v: IVariable => + v.getType match + case f: IFunctionType => f.getReturnType.toString + case 
other => other.toString + case other => other.getName + case None if ident.isInstanceOf[IASTName] => + typeFor(ident.getParent) + case None => typeFor(ident) + + val node = identifierNode( + ident, + identifierName, + nodeSignature(ident), + registerType(cleanType(identifierTypeName)) + ) + variableOption match + case Some((variable, _)) => + Ast(node).withRefEdge(node, variable) + case None => Ast(node) + end astForIdentifier + + protected def astForFieldReference(fieldRef: IASTFieldReference): Ast = + val op = if fieldRef.isPointerDereference then Operators.indirectFieldAccess + else Operators.fieldAccess + val ma = callNode(fieldRef, nodeSignature(fieldRef), op, op, DispatchTypes.STATIC_DISPATCH) + val owner = astForExpression(fieldRef.getFieldOwner) + val member = fieldIdentifierNode( + fieldRef, + fieldRef.getFieldName.toString, + fieldRef.getFieldName.toString + ) + callAst(ma, List(owner, Ast(member))) + + protected def astForArrayModifier(arrMod: IASTArrayModifier): Ast = + astForNode(arrMod.getConstantExpression) + + protected def astForInitializerList(l: IASTInitializerList): Ast = + val op = Operators.arrayInitializer + val initCallNode = callNode(l, nodeSignature(l), op, op, DispatchTypes.STATIC_DISPATCH) + + val MAX_INITIALIZERS = 1000 + val clauses = l.getClauses.slice(0, MAX_INITIALIZERS) + + val args = clauses.toList.map(x => astForNode(x)) + + val ast = callAst(initCallNode, args) + if l.getClauses.length > MAX_INITIALIZERS then + val placeholder = + literalNode(l, "", Defines.anyTypeName).argumentIndex( + MAX_INITIALIZERS + ) + ast.withChild(Ast(placeholder)).withArgEdge(initCallNode, placeholder) + else + ast + + protected def astForQualifiedName(qualId: CPPASTQualifiedName): Ast = + val op = Operators.fieldAccess + val ma = callNode(qualId, nodeSignature(qualId), op, op, DispatchTypes.STATIC_DISPATCH) + + def fieldAccesses(names: List[IASTNode], argIndex: Int = -1): Ast = names match + case Nil => Ast() + case head :: Nil => + astForNode(head) + 
case head :: tail => + val code = s"${nodeSignature(head)}::${tail.map(nodeSignature).mkString("::")}" + val callNode_ = + callNode(head, nodeSignature(head), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(argIndex) + callNode_.code = code + val arg1 = astForNode(head) + val arg2 = fieldAccesses(tail) + callAst(callNode_, List(arg1, arg2)) + + val qualifier = fieldAccesses(qualId.getQualifier.toIndexedSeq.toList) + + val owner = if qualifier != Ast() then + qualifier + else + Ast(literalNode(qualId.getLastName, "", Defines.anyTypeName)) + + val member = fieldIdentifierNode( + qualId.getLastName, + fixQualifiedName(qualId.getLastName.toString), + qualId.getLastName.toString + ) + callAst(ma, List(owner, Ast(member))) + end astForQualifiedName +end AstForPrimitivesCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala index 625f5843..c46c5075 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala @@ -11,252 +11,247 @@ import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIfStatement import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTNamespaceAlias import org.eclipse.cdt.internal.core.model.ASTStringUtil -trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => +trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => - import AstCreatorHelper.OptionSafeAst + import AstCreatorHelper.OptionSafeAst - protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = { - val code = nodeSignature(blockStmt) - val blockCode = if (code == "{}" || code.isEmpty) Defines.empty else code - val node = 
blockNode(blockStmt, blockCode, registerType(Defines.voidTypeName)) - .order(order) - .argumentIndex(order) - scope.pushNewScope(node) - var currOrder = 1 - val childAsts = blockStmt.getStatements.flatMap { stmt => - val r = astsForStatement(stmt, currOrder) - currOrder = currOrder + r.length - r - } - scope.popScope() - blockAst(node, childAsts.toList) - } + protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = + val code = nodeSignature(blockStmt) + val blockCode = if code == "{}" || code.isEmpty then Defines.empty else code + val node = blockNode(blockStmt, blockCode, registerType(Defines.voidTypeName)) + .order(order) + .argumentIndex(order) + scope.pushNewScope(node) + var currOrder = 1 + val childAsts = blockStmt.getStatements.flatMap { stmt => + val r = astsForStatement(stmt, currOrder) + currOrder = currOrder + r.length + r + } + scope.popScope() + blockAst(node, childAsts.toList) - private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] = - decl.getDeclaration match { - case simplDecl: IASTSimpleDeclaration - if simplDecl.getDeclarators.headOption.exists(_.isInstanceOf[IASTFunctionDeclarator]) => - Seq(astForFunctionDeclarator(simplDecl.getDeclarators.head.asInstanceOf[IASTFunctionDeclarator])) - case simplDecl: IASTSimpleDeclaration => - val locals = - simplDecl.getDeclarators.zipWithIndex.toList.map { case (d, i) => astForDeclarator(simplDecl, d, i) } - val calls = - simplDecl.getDeclarators.filter(_.getInitializer != null).toList.map { d => - astForInitializer(d, d.getInitializer) - } - locals ++ calls - case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) - case usingDeclaration: ICPPASTUsingDeclaration => handleUsingDeclaration(usingDeclaration) - case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) - case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) - case 
asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) - case _: ICPPASTUsingDirective => Seq.empty - case decl => Seq(astForNode(decl)) - } + private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] = + decl.getDeclaration match + case simplDecl: IASTSimpleDeclaration + if simplDecl.getDeclarators.headOption.exists( + _.isInstanceOf[IASTFunctionDeclarator] + ) => + Seq(astForFunctionDeclarator( + simplDecl.getDeclarators.head.asInstanceOf[IASTFunctionDeclarator] + )) + case simplDecl: IASTSimpleDeclaration => + val locals = + simplDecl.getDeclarators.zipWithIndex.toList.map { case (d, i) => + astForDeclarator(simplDecl, d, i) + } + val calls = + simplDecl.getDeclarators.filter(_.getInitializer != null).toList.map { d => + astForInitializer(d, d.getInitializer) + } + locals ++ calls + case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) + case usingDeclaration: ICPPASTUsingDeclaration => + handleUsingDeclaration(usingDeclaration) + case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) + case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) + case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) + case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) + case _: ICPPASTUsingDirective => Seq.empty + case decl => Seq(astForNode(decl)) - private def astForReturnStatement(ret: IASTReturnStatement): Ast = { - val cpgReturn = returnNode(ret, nodeSignature(ret)) - val expr = nullSafeAst(ret.getReturnValue) - Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root) - } + private def astForReturnStatement(ret: IASTReturnStatement): Ast = + val cpgReturn = returnNode(ret, nodeSignature(ret)) + val expr = nullSafeAst(ret.getReturnValue) + Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root) - private def astForBreakStatement(br: IASTBreakStatement): Ast = { - Ast(controlStructureNode(br, ControlStructureTypes.BREAK, nodeSignature(br))) - } + private def 
astForBreakStatement(br: IASTBreakStatement): Ast = + Ast(controlStructureNode(br, ControlStructureTypes.BREAK, nodeSignature(br))) - private def astForContinueStatement(cont: IASTContinueStatement): Ast = { - Ast(controlStructureNode(cont, ControlStructureTypes.CONTINUE, nodeSignature(cont))) - } + private def astForContinueStatement(cont: IASTContinueStatement): Ast = + Ast(controlStructureNode(cont, ControlStructureTypes.CONTINUE, nodeSignature(cont))) - private def astForGotoStatement(goto: IASTGotoStatement): Ast = { - val code = s"goto ${ASTStringUtil.getSimpleName(goto.getName)};" - Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) - } + private def astForGotoStatement(goto: IASTGotoStatement): Ast = + val code = s"goto ${ASTStringUtil.getSimpleName(goto.getName)};" + Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) - private def astsForGnuGotoStatement(goto: IGNUASTGotoStatement): Seq[Ast] = { - // This is for GNU GOTO labels as values. - // See: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html - // For such GOTOs we cannot statically determine the target label. As a quick - // hack we simply put edges to all labels found indicated by *. This might be an over-taint. - val code = s"goto *;" - val gotoNode = Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) - val exprNode = nullSafeAst(goto.getLabelNameExpression) - Seq(gotoNode, exprNode) - } + private def astsForGnuGotoStatement(goto: IGNUASTGotoStatement): Seq[Ast] = + // This is for GNU GOTO labels as values. + // See: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html + // For such GOTOs we cannot statically determine the target label. As a quick + // hack we simply put edges to all labels found indicated by *. This might be an over-taint. 
+ val code = s"goto *;" + val gotoNode = Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) + val exprNode = nullSafeAst(goto.getLabelNameExpression) + Seq(gotoNode, exprNode) - private def astsForLabelStatement(label: IASTLabelStatement): Seq[Ast] = { - val cpgLabel = newJumpTargetNode(label) - val nestedStmts = nullSafeAst(label.getNestedStatement) - Ast(cpgLabel) +: nestedStmts - } + private def astsForLabelStatement(label: IASTLabelStatement): Seq[Ast] = + val cpgLabel = newJumpTargetNode(label) + val nestedStmts = nullSafeAst(label.getNestedStatement) + Ast(cpgLabel) +: nestedStmts - private def astForDoStatement(doStmt: IASTDoStatement): Ast = { - val code = nodeSignature(doStmt) - val doNode = controlStructureNode(doStmt, ControlStructureTypes.DO, code) - val conditionAst = astForConditionExpression(doStmt.getCondition) - val bodyAst = nullSafeAst(doStmt.getBody) - controlStructureAst(doNode, Some(conditionAst), bodyAst, placeConditionLast = true) - } + private def astForDoStatement(doStmt: IASTDoStatement): Ast = + val code = nodeSignature(doStmt) + val doNode = controlStructureNode(doStmt, ControlStructureTypes.DO, code) + val conditionAst = astForConditionExpression(doStmt.getCondition) + val bodyAst = nullSafeAst(doStmt.getBody) + controlStructureAst(doNode, Some(conditionAst), bodyAst, placeConditionLast = true) - private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = { - val code = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})" - val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code) - val conditionAst = astForConditionExpression(switchStmt.getControllerExpression) - val stmtAsts = nullSafeAst(switchStmt.getBody) - controlStructureAst(switchNode, Some(conditionAst), stmtAsts) - } + private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = + val code = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})" + val switchNode = 
controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code) + val conditionAst = astForConditionExpression(switchStmt.getControllerExpression) + val stmtAsts = nullSafeAst(switchStmt.getBody) + controlStructureAst(switchNode, Some(conditionAst), stmtAsts) - private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = { - val labelNode = newJumpTargetNode(caseStmt) - val stmt = astForConditionExpression(caseStmt.getExpression) - Seq(Ast(labelNode), stmt) - } + private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = + val labelNode = newJumpTargetNode(caseStmt) + val stmt = astForConditionExpression(caseStmt.getExpression) + Seq(Ast(labelNode), stmt) - private def astForDefaultStatement(caseStmt: IASTDefaultStatement): Ast = { - Ast(newJumpTargetNode(caseStmt)) - } + private def astForDefaultStatement(caseStmt: IASTDefaultStatement): Ast = + Ast(newJumpTargetNode(caseStmt)) - private def astForTryStatement(tryStmt: ICPPASTTryBlockStatement): Ast = { - val cpgTry = controlStructureNode(tryStmt, ControlStructureTypes.TRY, "try") - val body = nullSafeAst(tryStmt.getTryBody) - // All catches must have order 2 for correct control flow generation. - // TODO fix this. Multiple siblings with the same order are invalid - val catches = tryStmt.getCatchHandlers.flatMap { stmt => - astsForStatement(stmt.getCatchBody, 2) - }.toIndexedSeq - Ast(cpgTry).withChildren(body).withChildren(catches) - } + private def astForTryStatement(tryStmt: ICPPASTTryBlockStatement): Ast = + val cpgTry = controlStructureNode(tryStmt, ControlStructureTypes.TRY, "try") + val body = nullSafeAst(tryStmt.getTryBody) + // All catches must have order 2 for correct control flow generation. + // TODO fix this. 
Multiple siblings with the same order are invalid + val catches = tryStmt.getCatchHandlers.flatMap { stmt => + astsForStatement(stmt.getCatchBody, 2) + }.toIndexedSeq + Ast(cpgTry).withChildren(body).withChildren(catches) - protected def astsForStatement(statement: IASTStatement, argIndex: Int = -1): Seq[Ast] = { - val r = statement match { - case expr: IASTExpressionStatement => Seq(astForExpression(expr.getExpression)) - case block: IASTCompoundStatement => Seq(astForBlockStatement(block, argIndex)) - case ifStmt: IASTIfStatement => Seq(astForIf(ifStmt)) - case whileStmt: IASTWhileStatement => Seq(astForWhile(whileStmt)) - case forStmt: IASTForStatement => Seq(astForFor(forStmt)) - case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt)) - case doStmt: IASTDoStatement => Seq(astForDoStatement(doStmt)) - case switchStmt: IASTSwitchStatement => Seq(astForSwitchStatement(switchStmt)) - case ret: IASTReturnStatement => Seq(astForReturnStatement(ret)) - case br: IASTBreakStatement => Seq(astForBreakStatement(br)) - case cont: IASTContinueStatement => Seq(astForContinueStatement(cont)) - case goto: IASTGotoStatement => Seq(astForGotoStatement(goto)) - case goto: IGNUASTGotoStatement => astsForGnuGotoStatement(goto) - case defStmt: IASTDefaultStatement => Seq(astForDefaultStatement(defStmt)) - case tryStmt: ICPPASTTryBlockStatement => Seq(astForTryStatement(tryStmt)) - case caseStmt: IASTCaseStatement => astsForCaseStatement(caseStmt) - case decl: IASTDeclarationStatement => astsForDeclarationStatement(decl) - case label: IASTLabelStatement => astsForLabelStatement(label) - case _: IASTNullStatement => Seq.empty - case _ => Seq(astForNode(statement)) - } - r.map(x => asChildOfMacroCall(statement, x)) - } + protected def astsForStatement(statement: IASTStatement, argIndex: Int = -1): Seq[Ast] = + val r = statement match + case expr: IASTExpressionStatement => Seq(astForExpression(expr.getExpression)) + case block: IASTCompoundStatement => 
Seq(astForBlockStatement(block, argIndex)) + case ifStmt: IASTIfStatement => Seq(astForIf(ifStmt)) + case whileStmt: IASTWhileStatement => Seq(astForWhile(whileStmt)) + case forStmt: IASTForStatement => Seq(astForFor(forStmt)) + case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt)) + case doStmt: IASTDoStatement => Seq(astForDoStatement(doStmt)) + case switchStmt: IASTSwitchStatement => Seq(astForSwitchStatement(switchStmt)) + case ret: IASTReturnStatement => Seq(astForReturnStatement(ret)) + case br: IASTBreakStatement => Seq(astForBreakStatement(br)) + case cont: IASTContinueStatement => Seq(astForContinueStatement(cont)) + case goto: IASTGotoStatement => Seq(astForGotoStatement(goto)) + case goto: IGNUASTGotoStatement => astsForGnuGotoStatement(goto) + case defStmt: IASTDefaultStatement => Seq(astForDefaultStatement(defStmt)) + case tryStmt: ICPPASTTryBlockStatement => Seq(astForTryStatement(tryStmt)) + case caseStmt: IASTCaseStatement => astsForCaseStatement(caseStmt) + case decl: IASTDeclarationStatement => astsForDeclarationStatement(decl) + case label: IASTLabelStatement => astsForLabelStatement(label) + case _: IASTNullStatement => Seq.empty + case _ => Seq(astForNode(statement)) + r.map(x => asChildOfMacroCall(statement, x)) + end astsForStatement - private def astForConditionExpression(expr: IASTExpression, explicitArgumentIndex: Option[Int] = None): Ast = { - val ast = expr match { - case exprList: IASTExpressionList => - val compareAstBlock = blockNode(expr, Defines.empty, registerType(Defines.voidTypeName)) - scope.pushNewScope(compareAstBlock) - val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst) - setArgumentIndices(compareBlockAstChildren) - val compareBlockAst = blockAst(compareAstBlock, compareBlockAstChildren) - scope.popScope() - compareBlockAst - case other => - nullSafeAst(other) - } - explicitArgumentIndex.foreach { i => - ast.root.foreach { case expr: ExpressionNew => expr.argumentIndex = i } 
- } - ast - } + private def astForConditionExpression( + expr: IASTExpression, + explicitArgumentIndex: Option[Int] = None + ): Ast = + val ast = expr match + case exprList: IASTExpressionList => + val compareAstBlock = + blockNode(expr, Defines.empty, registerType(Defines.voidTypeName)) + scope.pushNewScope(compareAstBlock) + val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst) + setArgumentIndices(compareBlockAstChildren) + val compareBlockAst = blockAst(compareAstBlock, compareBlockAstChildren) + scope.popScope() + compareBlockAst + case other => + nullSafeAst(other) + explicitArgumentIndex.foreach { i => + ast.root.foreach { case expr: ExpressionNew => expr.argumentIndex = i } + } + ast + end astForConditionExpression - private def astForFor(forStmt: IASTForStatement): Ast = { - val codeInit = nullSafeCode(forStmt.getInitializerStatement) - val codeCond = nullSafeCode(forStmt.getConditionExpression) - val codeIter = nullSafeCode(forStmt.getIterationExpression) + private def astForFor(forStmt: IASTForStatement): Ast = + val codeInit = nullSafeCode(forStmt.getInitializerStatement) + val codeCond = nullSafeCode(forStmt.getConditionExpression) + val codeIter = nullSafeCode(forStmt.getIterationExpression) - val code = s"for ($codeInit$codeCond;$codeIter)" - val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) + val code = s"for ($codeInit$codeCond;$codeIter)" + val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName)) - scope.pushNewScope(initAstBlock) - val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList) - scope.popScope() + val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName)) + scope.pushNewScope(initAstBlock) + val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList) + scope.popScope() - val 
compareAst = astForConditionExpression(forStmt.getConditionExpression, Some(2)) - val updateAst = nullSafeAst(forStmt.getIterationExpression, 3) - val bodyAsts = nullSafeAst(forStmt.getBody, 4) - forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts) - } + val compareAst = astForConditionExpression(forStmt.getConditionExpression, Some(2)) + val updateAst = nullSafeAst(forStmt.getIterationExpression, 3) + val bodyAsts = nullSafeAst(forStmt.getBody, 4) + forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts) - private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = { - val codeDecl = nullSafeCode(forStmt.getDeclaration) - val codeInit = nullSafeCode(forStmt.getInitializerClause) + private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = + val codeDecl = nullSafeCode(forStmt.getDeclaration) + val codeInit = nullSafeCode(forStmt.getInitializerClause) - val code = s"for ($codeDecl:$codeInit)" - val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) + val code = s"for ($codeDecl:$codeInit)" + val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - val initAst = astForNode(forStmt.getInitializerClause) - val declAst = astsForDeclaration(forStmt.getDeclaration) - val stmtAst = nullSafeAst(forStmt.getBody) - controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst) - } + val initAst = astForNode(forStmt.getInitializerClause) + val declAst = astsForDeclaration(forStmt.getDeclaration) + val stmtAst = nullSafeAst(forStmt.getBody) + controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst) - private def astForWhile(whileStmt: IASTWhileStatement): Ast = { - val code = s"while (${nullSafeCode(whileStmt.getCondition)})" - val compareAst = astForConditionExpression(whileStmt.getCondition) - val bodyAst = nullSafeAst(whileStmt.getBody) - whileAst(Some(compareAst), bodyAst, Some(code)) - } + private def astForWhile(whileStmt: 
IASTWhileStatement): Ast = + val code = s"while (${nullSafeCode(whileStmt.getCondition)})" + val compareAst = astForConditionExpression(whileStmt.getCondition) + val bodyAst = nullSafeAst(whileStmt.getBody) + whileAst(Some(compareAst), bodyAst, Some(code)) - private def astForIf(ifStmt: IASTIfStatement): Ast = { - val (code, conditionAst) = ifStmt match { - case s @ (_: CASTIfStatement | _: CPPASTIfStatement) if s.getConditionExpression != null => - val c = s"if (${nullSafeCode(s.getConditionExpression)})" - val compareAst = astForConditionExpression(s.getConditionExpression) - (c, compareAst) - case s: CPPASTIfStatement if s.getConditionExpression == null => - val c = s"if (${nullSafeCode(s.getConditionDeclaration)})" - val exprBlock = blockNode(s.getConditionDeclaration, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(exprBlock) - val a = astsForDeclaration(s.getConditionDeclaration) - setArgumentIndices(a) - scope.popScope() - (c, blockAst(exprBlock, a.toList)) - } + private def astForIf(ifStmt: IASTIfStatement): Ast = + val (code, conditionAst) = ifStmt match + case s @ (_: CASTIfStatement | _: CPPASTIfStatement) + if s.getConditionExpression != null => + val c = s"if (${nullSafeCode(s.getConditionExpression)})" + val compareAst = astForConditionExpression(s.getConditionExpression) + (c, compareAst) + case s: CPPASTIfStatement if s.getConditionExpression == null => + val c = s"if (${nullSafeCode(s.getConditionDeclaration)})" + val exprBlock = + blockNode(s.getConditionDeclaration, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(exprBlock) + val a = astsForDeclaration(s.getConditionDeclaration) + setArgumentIndices(a) + scope.popScope() + (c, blockAst(exprBlock, a.toList)) - val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code) + val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code) - val thenAst = ifStmt.getThenClause match { - case block: IASTCompoundStatement => astForBlockStatement(block) - case 
other if other != null => - val thenBlock = blockNode(other, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(thenBlock) - val a = astsForStatement(other) - setArgumentIndices(a) - scope.popScope() - blockAst(thenBlock, a.toList) - case _ => Ast() - } + val thenAst = ifStmt.getThenClause match + case block: IASTCompoundStatement => astForBlockStatement(block) + case other if other != null => + val thenBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(thenBlock) + val a = astsForStatement(other) + setArgumentIndices(a) + scope.popScope() + blockAst(thenBlock, a.toList) + case _ => Ast() - val elseAst = ifStmt.getElseClause match { - case block: IASTCompoundStatement => - val elseNode = controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") - val elseAst = astForBlockStatement(block) - Ast(elseNode).withChild(elseAst) - case other if other != null => - val elseNode = controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") - val elseBlock = blockNode(other, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(elseBlock) - val a = astsForStatement(other) - setArgumentIndices(a) - scope.popScope() - Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) - case _ => Ast() - } - controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst)) - } -} + val elseAst = ifStmt.getElseClause match + case block: IASTCompoundStatement => + val elseNode = + controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") + val elseAst = astForBlockStatement(block) + Ast(elseNode).withChild(elseAst) + case other if other != null => + val elseNode = + controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") + val elseBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(elseBlock) + val a = astsForStatement(other) + setArgumentIndices(a) + scope.popScope() + Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) + case _ 
=> Ast() + controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst)) + end astForIf +end AstForStatementsCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala index 15c20343..5365f4f8 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala @@ -9,373 +9,464 @@ import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTAliasDeclaration import org.eclipse.cdt.internal.core.model.ASTStringUtil import io.appthreat.x2cpg.datastructures.Stack.* -trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private def parentIsClassDef(node: IASTNode): Boolean = Option(node.getParent) match { - case Some(_: IASTCompositeTypeSpecifier) => true - case _ => false - } - - private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = - nodeSignature(decl).startsWith("typedef") - - protected def templateParameters(e: IASTNode): Option[String] = { - val templateDeclaration = e match { - case _: IASTElaboratedTypeSpecifier | _: IASTFunctionDeclarator | _: IASTCompositeTypeSpecifier - if e.getParent != null => - Option(e.getParent.getParent) - case _: IASTFunctionDefinition if e.getParent != null => Option(e.getParent) - case _ => None - } - - val decl = templateDeclaration.collect { case t: ICPPASTTemplateDeclaration => t } - val templateParams = decl.map(d => ASTStringUtil.getTemplateParameterArray(d.getTemplateParameters)) - templateParams.map(_.mkString("<", ",", ">")) - } - - private def astForNamespaceDefinition(namespaceDefinition: ICPPASTNamespaceDefinition): Ast = { - val (name, fullname) = - uniqueName("namespace", namespaceDefinition.getName.getLastName.toString, fullName(namespaceDefinition)) - val code = 
nodeSignature(namespaceDefinition) - val cpgNamespace = newNamespaceBlockNode(namespaceDefinition, name, fullname, code, fileName(namespaceDefinition)) - scope.pushNewScope(cpgNamespace) - - val childrenAsts = namespaceDefinition.getDeclarations.flatMap { decl => - val declAsts = astsForDeclaration(decl) - declAsts - }.toIndexedSeq - - val namespaceAst = Ast(cpgNamespace).withChildren(childrenAsts) - scope.popScope() - namespaceAst - } - - protected def astForNamespaceAlias(namespaceAlias: ICPPASTNamespaceAlias): Ast = { - val name = ASTStringUtil.getSimpleName(namespaceAlias.getAlias) - val fullname = fullName(namespaceAlias) - - if (!isQualifiedName(name)) { - usingDeclarationMappings.put(name, fullname) - } - - val code = nodeSignature(namespaceAlias) - val cpgNamespace = newNamespaceBlockNode(namespaceAlias, name, fullname, code, fileName(namespaceAlias)) - Ast(cpgNamespace) - } - - protected def astForDeclarator(declaration: IASTSimpleDeclaration, declarator: IASTDeclarator, index: Int): Ast = { - val name = ASTStringUtil.getSimpleName(declarator.getName) - declaration match { - case d if isTypeDef(d) && shortName(d.getDeclSpecifier).nonEmpty => - val filename = fileName(declaration) - val tpe = registerType(typeFor(declarator)) - Ast(typeDeclNode(declarator, name, registerType(name), filename, nodeSignature(d), alias = Option(tpe))) - case d if parentIsClassDef(d) => - val tpe = declarator match { - case _: IASTArrayDeclarator => registerType(typeFor(declarator)) - case _ => registerType(typeForDeclSpecifier(declaration.getDeclSpecifier)) +trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private def parentIsClassDef(node: IASTNode): Boolean = Option(node.getParent) match + case Some(_: IASTCompositeTypeSpecifier) => true + case _ => false + + private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = + nodeSignature(decl).startsWith("typedef") + + protected def templateParameters(e: IASTNode): 
Option[String] = + val templateDeclaration = e match + case _: IASTElaboratedTypeSpecifier | _: IASTFunctionDeclarator | _: IASTCompositeTypeSpecifier + if e.getParent != null => + Option(e.getParent.getParent) + case _: IASTFunctionDefinition if e.getParent != null => Option(e.getParent) + case _ => None + + val decl = templateDeclaration.collect { case t: ICPPASTTemplateDeclaration => t } + val templateParams = + decl.map(d => ASTStringUtil.getTemplateParameterArray(d.getTemplateParameters)) + templateParams.map(_.mkString("<", ",", ">")) + + private def astForNamespaceDefinition(namespaceDefinition: ICPPASTNamespaceDefinition): Ast = + val (name, fullname) = + uniqueName( + "namespace", + namespaceDefinition.getName.getLastName.toString, + fullName(namespaceDefinition) + ) + val code = nodeSignature(namespaceDefinition) + val cpgNamespace = newNamespaceBlockNode( + namespaceDefinition, + name, + fullname, + code, + fileName(namespaceDefinition) + ) + scope.pushNewScope(cpgNamespace) + + val childrenAsts = namespaceDefinition.getDeclarations.flatMap { decl => + val declAsts = astsForDeclaration(decl) + declAsts + }.toIndexedSeq + + val namespaceAst = Ast(cpgNamespace).withChildren(childrenAsts) + scope.popScope() + namespaceAst + end astForNamespaceDefinition + + protected def astForNamespaceAlias(namespaceAlias: ICPPASTNamespaceAlias): Ast = + val name = ASTStringUtil.getSimpleName(namespaceAlias.getAlias) + val fullname = fullName(namespaceAlias) + + if !isQualifiedName(name) then + usingDeclarationMappings.put(name, fullname) + + val code = nodeSignature(namespaceAlias) + val cpgNamespace = + newNamespaceBlockNode(namespaceAlias, name, fullname, code, fileName(namespaceAlias)) + Ast(cpgNamespace) + + protected def astForDeclarator( + declaration: IASTSimpleDeclaration, + declarator: IASTDeclarator, + index: Int + ): Ast = + val name = ASTStringUtil.getSimpleName(declarator.getName) + declaration match + case d if isTypeDef(d) && 
shortName(d.getDeclSpecifier).nonEmpty => + val filename = fileName(declaration) + val tpe = registerType(typeFor(declarator)) + Ast(typeDeclNode( + declarator, + name, + registerType(name), + filename, + nodeSignature(d), + alias = Option(tpe) + )) + case d if parentIsClassDef(d) => + val tpe = declarator match + case _: IASTArrayDeclarator => registerType(typeFor(declarator)) + case _ => registerType(typeForDeclSpecifier(declaration.getDeclSpecifier)) + Ast(memberNode(declarator, name, nodeSignature(declarator), tpe)) + case _ if declarator.isInstanceOf[IASTArrayDeclarator] => + val tpe = registerType(typeFor(declarator)) + val codeTpe = typeFor(declarator, stripKeywords = false) + val node = localNode(declarator, name, s"$codeTpe $name", tpe) + scope.addToScope(name, (node, tpe)) + Ast(node) + case _ => + val tpe = registerType( + cleanType(typeForDeclSpecifier( + declaration.getDeclSpecifier, + stripKeywords = true, + index + )) + ) + val codeTpe = + typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = false, index) + val node = localNode(declarator, name, s"$codeTpe $name", tpe) + scope.addToScope(name, (node, tpe)) + Ast(node) + end match + end astForDeclarator + + protected def astForInitializer(declarator: IASTDeclarator, init: IASTInitializer): Ast = + init match + case i: IASTEqualsInitializer => + val operatorName = Operators.assignment + val callNode_ = + callNode( + declarator, + nodeSignature(declarator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH + ) + val left = astForNode(declarator.getName) + val right = astForNode(i.getInitializerClause) + callAst(callNode_, List(left, right)) + case i: ICPPASTConstructorInitializer => + val name = ASTStringUtil.getSimpleName(declarator.getName) + val callNode_ = callNode( + declarator, + nodeSignature(declarator), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + val args = i.getArguments.toList.map(x => astForNode(x)) + callAst(callNode_, args) + case i: 
IASTInitializerList => + val operatorName = Operators.assignment + val callNode_ = + callNode( + declarator, + nodeSignature(declarator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH + ) + val left = astForNode(declarator.getName) + val right = astForNode(i) + callAst(callNode_, List(left, right)) + case _ => astForNode(init) + + protected def handleUsingDeclaration(usingDecl: ICPPASTUsingDeclaration): Seq[Ast] = + val simpleName = ASTStringUtil.getSimpleName(usingDecl.getName) + val mappedName = lastNameOfQualifiedName(simpleName) + // we only do the mapping if the declaration is not global because this is already handled by the parser itself + if !isQualifiedName(simpleName) then + usingDecl.getParent match + case ns: ICPPASTNamespaceDefinition => + usingDeclarationMappings.put( + s"${fullName(ns)}.$mappedName", + fixQualifiedName(simpleName) + ) + case _ => + usingDeclarationMappings.put(mappedName, fixQualifiedName(simpleName)) + Seq.empty + + protected def astForAliasDeclaration(aliasDeclaration: ICPPASTAliasDeclaration): Ast = + val name = aliasDeclaration.getAlias.toString + val mappedName = registerType(typeFor(aliasDeclaration.getMappingTypeId)) + val typeDeclNode_ = + typeDeclNode( + aliasDeclaration, + name, + registerType(name), + fileName(aliasDeclaration), + nodeSignature(aliasDeclaration), + alias = Option(mappedName) + ) + Ast(typeDeclNode_) + + protected def astForASMDeclaration(asm: IASTASMDeclaration): Ast = + Ast(unknownNode(asm, nodeSignature(asm))) + + private def astForStructuredBindingDeclaration(decl: ICPPASTStructuredBindingDeclaration): Ast = + val node = blockNode(decl, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val childAsts = decl.getNames.toList.map { name => + astForNode(name) + } + scope.popScope() + setArgumentIndices(childAsts) + blockAst(node, childAsts) + + protected def astsForDeclaration(decl: IASTDeclaration): Seq[Ast] = + val declAsts = decl match + case sb: 
ICPPASTStructuredBindingDeclaration => + Seq(astForStructuredBindingDeclaration(sb)) + case declaration: IASTSimpleDeclaration => + declaration.getDeclSpecifier match + case spec: IASTCompositeTypeSpecifier => + astsForCompositeType(spec, declaration.getDeclarators.toList) + case spec: IASTEnumerationSpecifier => + astsForEnum(spec, declaration.getDeclarators.toList) + case spec: IASTElaboratedTypeSpecifier => + astsForElaboratedType(spec, declaration.getDeclarators.toList) + case spec: IASTNamedTypeSpecifier if declaration.getDeclarators.isEmpty => + val filename = fileName(spec) + val name = ASTStringUtil.getSimpleName(spec.getName) + Seq(Ast(typeDeclNode( + spec, + name, + registerType(name), + filename, + nodeSignature(spec), + alias = Option(name) + ))) + case _ if declaration.getDeclarators.nonEmpty => + declaration.getDeclarators.toIndexedSeq.zipWithIndex.map { + case (d: IASTFunctionDeclarator, _) => + astForFunctionDeclarator(d) + case (d: IASTSimpleDeclaration, _) if d.getInitializer != null => + Ast() // we do the AST for this down below with initAsts + case (d, i) => + astForDeclarator(declaration, d, i) + } + case _ if nodeSignature(declaration) == ";" => + Seq.empty // dangling decls from unresolved macros; we ignore them + case _ + if declaration.getDeclarators.isEmpty && declaration.getParent.isInstanceOf[ + IASTTranslationUnit + ] => + Seq.empty // dangling decls from unresolved macros; we ignore them + case _ if declaration.getDeclarators.isEmpty => Seq(astForNode(declaration)) + case alias: CPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) + case functDef: IASTFunctionDefinition => Seq(astForFunctionDefinition(functDef)) + case namespaceAlias: ICPPASTNamespaceAlias => Seq(astForNamespaceAlias(namespaceAlias)) + case namespaceDefinition: ICPPASTNamespaceDefinition => + Seq(astForNamespaceDefinition(namespaceDefinition)) + case a: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(a)) + case asm: IASTASMDeclaration => 
Seq(astForASMDeclaration(asm)) + case t: ICPPASTTemplateDeclaration => astsForDeclaration(t.getDeclaration) + case l: ICPPASTLinkageSpecification => astsForLinkageSpecification(l) + case u: ICPPASTUsingDeclaration => handleUsingDeclaration(u) + case _: ICPPASTVisibilityLabel => Seq.empty + case _: ICPPASTUsingDirective => Seq.empty + case _: ICPPASTExplicitTemplateInstantiation => Seq.empty + case _ => Seq(astForNode(decl)) + + val initAsts = decl match + case declaration: IASTSimpleDeclaration if declaration.getDeclarators.nonEmpty => + declaration.getDeclarators.toList.map { + case d: IASTDeclarator if d.getInitializer != null => + astForInitializer(d, d.getInitializer) + case arrayDecl: IASTArrayDeclarator => + val op = Operators.arrayInitializer + val initCallNode = callNode( + arrayDecl, + nodeSignature(arrayDecl), + op, + op, + DispatchTypes.STATIC_DISPATCH + ) + val initArgs = + arrayDecl.getArrayModifiers.toList.filter(m => + m.getConstantExpression != null + ).map(astForNode) + callAst(initCallNode, initArgs) + case _ => Ast() + } + case _ => Nil + declAsts ++ initAsts + end astsForDeclaration + + private def astsForLinkageSpecification(l: ICPPASTLinkageSpecification): Seq[Ast] = + l.getDeclarations.toList.flatMap { d => + astsForDeclaration(d) } - Ast(memberNode(declarator, name, nodeSignature(declarator), tpe)) - case _ if declarator.isInstanceOf[IASTArrayDeclarator] => - val tpe = registerType(typeFor(declarator)) - val codeTpe = typeFor(declarator, stripKeywords = false) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) - scope.addToScope(name, (node, tpe)) - Ast(node) - case _ => - val tpe = registerType( - cleanType(typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = true, index)) + + private def astsForCompositeType( + typeSpecifier: IASTCompositeTypeSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = decls.zipWithIndex.map { case (d, i) => + 
astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) + } + + val lineNumber = line(typeSpecifier) + val columnNumber = column(typeSpecifier) + val fullname = registerType(cleanType(fullName(typeSpecifier))) + val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) match + case n if n.isEmpty => lastNameOfQualifiedName(fullname) + case other => other + val code = nodeSignature(typeSpecifier) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = + templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) + val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption + + val typeDecl = typeSpecifier match + case cppClass: ICPPASTCompositeTypeSpecifier => + val baseClassList = + cppClass.getBaseSpecifiers.toSeq.map(s => + registerType(s.getNameSpecifier.toString) + ) + typeDeclNode( + typeSpecifier, + name, + fullname, + filename, + code, + inherits = baseClassList, + alias = alias + ) + case _ => + typeDeclNode(typeSpecifier, name, fullname, filename, code, alias = alias) + + methodAstParentStack.push(typeDecl) + scope.pushNewScope(typeDecl) + + val memberAsts = typeSpecifier.getDeclarations(true).toList.flatMap(astsForDeclaration) + + methodAstParentStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast(typeDecl).withChildren(member) +: declAsts + else + val init = staticInitMethodAst( + calls, + s"$fullname:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.anyTypeName, + Some(filename), + lineNumber, + columnNumber + ) + Ast(typeDecl).withChildren(member).withChild(init) +: declAsts + end astsForCompositeType + + private def astsForElaboratedType( + typeSpecifier: IASTElaboratedTypeSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = 
decls.zipWithIndex.map { case (d, i) => + astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) + } + + val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) + val fullname = registerType(cleanType(fullName(typeSpecifier))) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = + templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) + val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption + + val typeDecl = + typeDeclNode( + typeSpecifier, + name, + fullname, + filename, + nodeSignature(typeSpecifier), + alias = alias + ) + + Ast(typeDecl) +: declAsts + end astsForElaboratedType + + private def astsForEnumerator(enumerator: IASTEnumerationSpecifier.IASTEnumerator): Seq[Ast] = + val tpe = enumerator.getParent match + case enumeration: ICPPASTEnumerationSpecifier if enumeration.getBaseType != null => + enumeration.getBaseType.toString + case _ => typeFor(enumerator) + val cpgMember = memberNode( + enumerator, + ASTStringUtil.getSimpleName(enumerator.getName), + nodeSignature(enumerator), + registerType(cleanType(tpe)) ) - val codeTpe = typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = false, index) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) - scope.addToScope(name, (node, tpe)) - Ast(node) - } - - } - - protected def astForInitializer(declarator: IASTDeclarator, init: IASTInitializer): Ast = init match { - case i: IASTEqualsInitializer => - val operatorName = Operators.assignment - val callNode_ = - callNode(declarator, nodeSignature(declarator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) - val left = astForNode(declarator.getName) - val right = astForNode(i.getInitializerClause) - callAst(callNode_, List(left, right)) - case i: ICPPASTConstructorInitializer => - val name = ASTStringUtil.getSimpleName(declarator.getName) - val callNode_ = callNode(declarator, 
nodeSignature(declarator), name, name, DispatchTypes.STATIC_DISPATCH) - val args = i.getArguments.toList.map(x => astForNode(x)) - callAst(callNode_, args) - case i: IASTInitializerList => - val operatorName = Operators.assignment - val callNode_ = - callNode(declarator, nodeSignature(declarator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) - val left = astForNode(declarator.getName) - val right = astForNode(i) - callAst(callNode_, List(left, right)) - case _ => astForNode(init) - } - - protected def handleUsingDeclaration(usingDecl: ICPPASTUsingDeclaration): Seq[Ast] = { - val simpleName = ASTStringUtil.getSimpleName(usingDecl.getName) - val mappedName = lastNameOfQualifiedName(simpleName) - // we only do the mapping if the declaration is not global because this is already handled by the parser itself - if (!isQualifiedName(simpleName)) { - usingDecl.getParent match { - case ns: ICPPASTNamespaceDefinition => - usingDeclarationMappings.put(s"${fullName(ns)}.$mappedName", fixQualifiedName(simpleName)) - case _ => - usingDeclarationMappings.put(mappedName, fixQualifiedName(simpleName)) - } - } - Seq.empty - } - - protected def astForAliasDeclaration(aliasDeclaration: ICPPASTAliasDeclaration): Ast = { - val name = aliasDeclaration.getAlias.toString - val mappedName = registerType(typeFor(aliasDeclaration.getMappingTypeId)) - val typeDeclNode_ = - typeDeclNode( - aliasDeclaration, - name, - registerType(name), - fileName(aliasDeclaration), - nodeSignature(aliasDeclaration), - alias = Option(mappedName) - ) - Ast(typeDeclNode_) - } - - protected def astForASMDeclaration(asm: IASTASMDeclaration): Ast = Ast(unknownNode(asm, nodeSignature(asm))) - - private def astForStructuredBindingDeclaration(decl: ICPPASTStructuredBindingDeclaration): Ast = { - val node = blockNode(decl, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val childAsts = decl.getNames.toList.map { name => - astForNode(name) - } - scope.popScope() - 
setArgumentIndices(childAsts) - blockAst(node, childAsts) - } - - protected def astsForDeclaration(decl: IASTDeclaration): Seq[Ast] = { - val declAsts = decl match { - case sb: ICPPASTStructuredBindingDeclaration => Seq(astForStructuredBindingDeclaration(sb)) - case declaration: IASTSimpleDeclaration => - declaration.getDeclSpecifier match { - case spec: IASTCompositeTypeSpecifier => - astsForCompositeType(spec, declaration.getDeclarators.toList) - case spec: IASTEnumerationSpecifier => - astsForEnum(spec, declaration.getDeclarators.toList) - case spec: IASTElaboratedTypeSpecifier => - astsForElaboratedType(spec, declaration.getDeclarators.toList) - case spec: IASTNamedTypeSpecifier if declaration.getDeclarators.isEmpty => - val filename = fileName(spec) - val name = ASTStringUtil.getSimpleName(spec.getName) - Seq(Ast(typeDeclNode(spec, name, registerType(name), filename, nodeSignature(spec), alias = Option(name)))) - case _ if declaration.getDeclarators.nonEmpty => - declaration.getDeclarators.toIndexedSeq.zipWithIndex.map { - case (d: IASTFunctionDeclarator, _) => - astForFunctionDeclarator(d) - case (d: IASTSimpleDeclaration, _) if d.getInitializer != null => - Ast() // we do the AST for this down below with initAsts - case (d, i) => - astForDeclarator(declaration, d, i) - } - case _ if nodeSignature(declaration) == ";" => - Seq.empty // dangling decls from unresolved macros; we ignore them - case _ if declaration.getDeclarators.isEmpty && declaration.getParent.isInstanceOf[IASTTranslationUnit] => - Seq.empty // dangling decls from unresolved macros; we ignore them - case _ if declaration.getDeclarators.isEmpty => Seq(astForNode(declaration)) + + if enumerator.getValue != null then + val operatorName = Operators.assignment + val callNode_ = + callNode( + enumerator, + nodeSignature(enumerator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH + ) + val left = astForNode(enumerator.getName) + val right = astForNode(enumerator.getValue) + val ast = 
callAst(callNode_, List(left, right)) + Seq(Ast(cpgMember), ast) + else + Seq(Ast(cpgMember)) + end astsForEnumerator + + private def astsForEnum( + typeSpecifier: IASTEnumerationSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = decls.zipWithIndex.map { case (d, i) => + astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) } - case alias: CPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case functDef: IASTFunctionDefinition => Seq(astForFunctionDefinition(functDef)) - case namespaceAlias: ICPPASTNamespaceAlias => Seq(astForNamespaceAlias(namespaceAlias)) - case namespaceDefinition: ICPPASTNamespaceDefinition => Seq(astForNamespaceDefinition(namespaceDefinition)) - case a: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(a)) - case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) - case t: ICPPASTTemplateDeclaration => astsForDeclaration(t.getDeclaration) - case l: ICPPASTLinkageSpecification => astsForLinkageSpecification(l) - case u: ICPPASTUsingDeclaration => handleUsingDeclaration(u) - case _: ICPPASTVisibilityLabel => Seq.empty - case _: ICPPASTUsingDirective => Seq.empty - case _: ICPPASTExplicitTemplateInstantiation => Seq.empty - case _ => Seq(astForNode(decl)) - } - - val initAsts = decl match { - case declaration: IASTSimpleDeclaration if declaration.getDeclarators.nonEmpty => - declaration.getDeclarators.toList.map { - case d: IASTDeclarator if d.getInitializer != null => - astForInitializer(d, d.getInitializer) - case arrayDecl: IASTArrayDeclarator => - val op = Operators.arrayInitializer - val initCallNode = callNode(arrayDecl, nodeSignature(arrayDecl), op, op, DispatchTypes.STATIC_DISPATCH) - val initArgs = - arrayDecl.getArrayModifiers.toList.filter(m => m.getConstantExpression != null).map(astForNode) - callAst(initCallNode, initArgs) - case _ => Ast() + + val lineNumber = line(typeSpecifier) + val columnNumber = 
column(typeSpecifier) + val (name, fullname) = + uniqueName( + "enum", + ASTStringUtil.getSimpleName(typeSpecifier.getName), + fullName(typeSpecifier) + ) + val alias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + + val (deAliasedName, deAliasedFullName, newAlias) = + if name.contains("anonymous_enum") && alias.isDefined then + ( + alias.get, + fullname.substring(0, fullname.indexOf("anonymous_enum")) + alias.get, + None + ) + else (name, fullname, alias) + + val typeDecl = + typeDeclNode( + typeSpecifier, + deAliasedName, + registerType(deAliasedFullName), + filename, + nodeSignature(typeSpecifier), + alias = newAlias + ) + methodAstParentStack.push(typeDecl) + scope.pushNewScope(typeDecl) + + val memberAsts = typeSpecifier.getEnumerators.toList.flatMap { e => + astsForEnumerator(e) } - case _ => Nil - } - declAsts ++ initAsts - } - - private def astsForLinkageSpecification(l: ICPPASTLinkageSpecification): Seq[Ast] = - l.getDeclarations.toList.flatMap { d => - astsForDeclaration(d) - } - - private def astsForCompositeType(typeSpecifier: IASTCompositeTypeSpecifier, decls: List[IASTDeclarator]): Seq[Ast] = { - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) match { - case n if n.isEmpty => lastNameOfQualifiedName(fullname) - case other => other - } - val code = nodeSignature(typeSpecifier) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption - - val typeDecl = typeSpecifier 
match { - case cppClass: ICPPASTCompositeTypeSpecifier => - val baseClassList = - cppClass.getBaseSpecifiers.toSeq.map(s => registerType(s.getNameSpecifier.toString)) - typeDeclNode(typeSpecifier, name, fullname, filename, code, inherits = baseClassList, alias = alias) - case _ => - typeDeclNode(typeSpecifier, name, fullname, filename, code, alias = alias) - } - - methodAstParentStack.push(typeDecl) - scope.pushNewScope(typeDecl) - - val memberAsts = typeSpecifier.getDeclarations(true).toList.flatMap(astsForDeclaration) - - methodAstParentStack.pop() - scope.popScope() - - val (calls, member) = memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if (calls.isEmpty) { - Ast(typeDecl).withChildren(member) +: declAsts - } else { - val init = staticInitMethodAst( - calls, - s"$fullname:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.anyTypeName, - Some(filename), - lineNumber, - columnNumber - ) - Ast(typeDecl).withChildren(member).withChild(init) +: declAsts - } - } - - private def astsForElaboratedType( - typeSpecifier: IASTElaboratedTypeSpecifier, - decls: List[IASTDeclarator] - ): Seq[Ast] = { - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption - - val typeDecl = - typeDeclNode(typeSpecifier, name, fullname, filename, nodeSignature(typeSpecifier), alias = alias) - - Ast(typeDecl) +: declAsts - } - - private def astsForEnumerator(enumerator: IASTEnumerationSpecifier.IASTEnumerator): Seq[Ast] 
= { - val tpe = enumerator.getParent match { - case enumeration: ICPPASTEnumerationSpecifier if enumeration.getBaseType != null => - enumeration.getBaseType.toString - case _ => typeFor(enumerator) - } - val cpgMember = memberNode( - enumerator, - ASTStringUtil.getSimpleName(enumerator.getName), - nodeSignature(enumerator), - registerType(cleanType(tpe)) - ) - - if (enumerator.getValue != null) { - val operatorName = Operators.assignment - val callNode_ = - callNode(enumerator, nodeSignature(enumerator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH) - val left = astForNode(enumerator.getName) - val right = astForNode(enumerator.getValue) - val ast = callAst(callNode_, List(left, right)) - Seq(Ast(cpgMember), ast) - } else { - Seq(Ast(cpgMember)) - } - } - - private def astsForEnum(typeSpecifier: IASTEnumerationSpecifier, decls: List[IASTDeclarator]): Seq[Ast] = { - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val (name, fullname) = - uniqueName("enum", ASTStringUtil.getSimpleName(typeSpecifier.getName), fullName(typeSpecifier)) - val alias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - - val (deAliasedName, deAliasedFullName, newAlias) = if (name.contains("anonymous_enum") && alias.isDefined) { - (alias.get, fullname.substring(0, fullname.indexOf("anonymous_enum")) + alias.get, None) - } else { (name, fullname, alias) } - - val typeDecl = - typeDeclNode( - typeSpecifier, - deAliasedName, - registerType(deAliasedFullName), - filename, - nodeSignature(typeSpecifier), - alias = newAlias - ) - methodAstParentStack.push(typeDecl) - scope.pushNewScope(typeDecl) - - val memberAsts = typeSpecifier.getEnumerators.toList.flatMap { e => - astsForEnumerator(e) - } - methodAstParentStack.pop() - scope.popScope() 
- - val (calls, member) = memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if (calls.isEmpty) { - Ast(typeDecl).withChildren(member) +: declAsts - } else { - val init = staticInitMethodAst( - calls, - s"$deAliasedFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.anyTypeName, - Some(filename), - lineNumber, - columnNumber - ) - Ast(typeDecl).withChildren(member).withChild(init) +: declAsts - } - } - -} + methodAstParentStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast(typeDecl).withChildren(member) +: declAsts + else + val init = staticInitMethodAst( + calls, + s"$deAliasedFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.anyTypeName, + Some(filename), + lineNumber, + columnNumber + ) + Ast(typeDecl).withChildren(member).withChild(init) +: declAsts + end astsForEnum +end AstForTypesCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala index eb4930da..c7be9669 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala @@ -1,49 +1,45 @@ package io.appthreat.c2cpg.astcreation -import io.appthreat.x2cpg.utils.NodeBuilders.{newMethodReturnNode => newMethodReturnNode_} -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.appthreat.x2cpg.utils.NodeBuilders.{newMethodReturnNode as newMethodReturnNode_} +import io.shiftleft.codepropertygraph.generated.nodes.* import org.eclipse.cdt.core.dom.ast.{IASTLabelStatement, IASTNode} import org.eclipse.cdt.core.dom.ast.IASTPreprocessorIncludeStatement import org.eclipse.cdt.internal.core.model.ASTStringUtil -trait AstNodeBuilder { 
this: AstCreator => - protected def newCommentNode(node: IASTNode, code: String, filename: String): NewComment = { - NewComment().code(code).filename(filename).lineNumber(line(node)).columnNumber(column(node)) - } +trait AstNodeBuilder: + this: AstCreator => + protected def newCommentNode(node: IASTNode, code: String, filename: String): NewComment = + NewComment().code(code).filename(filename).lineNumber(line(node)).columnNumber(column(node)) - protected def newNamespaceBlockNode( - node: IASTNode, - name: String, - fullname: String, - code: String, - filename: String - ): NewNamespaceBlock = { - NewNamespaceBlock() - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - .filename(filename) - .name(name) - .fullName(fullname) - } + protected def newNamespaceBlockNode( + node: IASTNode, + name: String, + fullname: String, + code: String, + filename: String + ): NewNamespaceBlock = + NewNamespaceBlock() + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) + .filename(filename) + .name(name) + .fullName(fullname) - // TODO: We should get rid of this method as its being used at multiple places and use it from x2cpg/AstNodeBuilder "methodReturnNode" - protected def newMethodReturnNode(node: IASTNode, typeFullName: String): NewMethodReturn = { - newMethodReturnNode_(typeFullName, None, line(node), column(node)) - } + // TODO: We should get rid of this method as its being used at multiple places and use it from x2cpg/AstNodeBuilder "methodReturnNode" + protected def newMethodReturnNode(node: IASTNode, typeFullName: String): NewMethodReturn = + newMethodReturnNode_(typeFullName, None, line(node), column(node)) - protected def newJumpTargetNode(node: IASTNode): NewJumpTarget = { - val code = nodeSignature(node) - val name = node match { - case label: IASTLabelStatement => ASTStringUtil.getSimpleName(label.getName) - case _ if code.startsWith("case") => "case" - case _ => "default" - } - NewJumpTarget() - 
.parserTypeName(node.getClass.getSimpleName) - .name(name) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } -} + protected def newJumpTargetNode(node: IASTNode): NewJumpTarget = + val code = nodeSignature(node) + val name = node match + case label: IASTLabelStatement => ASTStringUtil.getSimpleName(label.getName) + case _ if code.startsWith("case") => "case" + case _ => "default" + NewJumpTarget() + .parserTypeName(node.getClass.getSimpleName) + .name(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) +end AstNodeBuilder diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala index e63cad99..32705c20 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala @@ -1,8 +1,7 @@ package io.appthreat.c2cpg.astcreation -object Defines { - val anyTypeName: String = "ANY" - val voidTypeName: String = "void" - val qualifiedNameSeparator: String = "::" - val empty = "" -} +object Defines: + val anyTypeName: String = "ANY" + val voidTypeName: String = "void" + val qualifiedNameSeparator: String = "::" + val empty = "" diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala index 377f3781..25a0aacf 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala @@ -2,184 +2,193 @@ package io.appthreat.c2cpg.astcreation import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.{ - AstNodeNew, - ExpressionNew, - NewBlock, - NewCall, - 
NewFieldIdentifier, - NewNode + AstNodeNew, + ExpressionNew, + NewBlock, + NewCall, + NewFieldIdentifier, + NewNode } import io.appthreat.x2cpg.{Ast, AstEdge, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.NewLocal import org.apache.commons.lang.StringUtils -import org.eclipse.cdt.core.dom.ast.{IASTMacroExpansionLocation, IASTNode, IASTPreprocessorMacroDefinition} +import org.eclipse.cdt.core.dom.ast.{ + IASTMacroExpansionLocation, + IASTNode, + IASTPreprocessorMacroDefinition +} import org.eclipse.cdt.core.dom.ast.IASTBinaryExpression import org.eclipse.cdt.internal.core.model.ASTStringUtil import scala.annotation.nowarn import scala.collection.mutable -trait MacroHandler(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private val nodeOffsetMacroPairs: mutable.Stack[(Int, IASTPreprocessorMacroDefinition)] = { - mutable.Stack.from( - cdtAst.getNodeLocations.toList - .collect { case exp: IASTMacroExpansionLocation => - (exp.asFileLocation().getNodeOffset, exp.getExpansion.getMacroDefinition) +trait MacroHandler(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private val nodeOffsetMacroPairs: mutable.Stack[(Int, IASTPreprocessorMacroDefinition)] = + mutable.Stack.from( + cdtAst.getNodeLocations.toList + .collect { case exp: IASTMacroExpansionLocation => + (exp.asFileLocation().getNodeOffset, exp.getExpansion.getMacroDefinition) + } + .sortBy(_._1) + ) + + /** For the given node, determine if it is expanded from a macro, and if so, create a Call node + * to represent the macro invocation and attach `ast` as its child. + */ + def asChildOfMacroCall(node: IASTNode, ast: Ast): Ast = + // If a macro in a header file contained a method definition already seen in some + // source file we skipped that during the previous AST creation and returned an empty AST. + if ast.root.isEmpty && isExpandedFromMacro(node) then return ast + // We do nothing for locals only. 
+ if ast.nodes.size == 1 && ast.root.exists(_.isInstanceOf[NewLocal]) then return ast + // Otherwise, we create the synthetic call AST. + val matchingMacro = extractMatchingMacro(node) + val macroCallAst = matchingMacro.map { case (mac, args) => + createMacroCallAst(ast, node, mac, args) } - .sortBy(_._1) - ) - } - - /** For the given node, determine if it is expanded from a macro, and if so, create a Call node to represent the macro - * invocation and attach `ast` as its child. - */ - def asChildOfMacroCall(node: IASTNode, ast: Ast): Ast = { - // If a macro in a header file contained a method definition already seen in some - // source file we skipped that during the previous AST creation and returned an empty AST. - if (ast.root.isEmpty && isExpandedFromMacro(node)) return ast - // We do nothing for locals only. - if (ast.nodes.size == 1 && ast.root.exists(_.isInstanceOf[NewLocal])) return ast - // Otherwise, we create the synthetic call AST. - val matchingMacro = extractMatchingMacro(node) - val macroCallAst = matchingMacro.map { case (mac, args) => createMacroCallAst(ast, node, mac, args) } - macroCallAst match { - case Some(callAst) => - val lostLocals = ast.refEdges.collect { case AstEdge(_, dst: NewLocal) => Ast(dst) }.toList - val newAst = ast.subTreeCopy(ast.root.get.asInstanceOf[AstNodeNew], argIndex = 1) - // We need to wrap the copied AST as it may contain CPG nodes not being allowed - // to be connected via AST edges under a CALL. E.g., LOCALs but only if its not already a BLOCK. 
- val childAst = newAst.root match { - case Some(_: NewBlock) => - newAst - case _ => - val b = NewBlock().argumentIndex(1).typeFullName(registerType(Defines.voidTypeName)) - blockAst(b, List(newAst)) + macroCallAst match + case Some(callAst) => + val lostLocals = ast.refEdges.collect { case AstEdge(_, dst: NewLocal) => + Ast(dst) + }.toList + val newAst = ast.subTreeCopy(ast.root.get.asInstanceOf[AstNodeNew], argIndex = 1) + // We need to wrap the copied AST as it may contain CPG nodes not being allowed + // to be connected via AST edges under a CALL. E.g., LOCALs but only if its not already a BLOCK. + val childAst = newAst.root match + case Some(_: NewBlock) => + newAst + case _ => + val b = NewBlock().argumentIndex(1).typeFullName( + registerType(Defines.voidTypeName) + ) + blockAst(b, List(newAst)) + callAst.withChildren(lostLocals).withChild(childAst) + case None => ast + end asChildOfMacroCall + + /** For the given node, determine if it is expanded from a macro, and if so, find the first + * matching (offset, macro) pair in nodeOffsetMacroPairs, removing non-matching elements from + * the start of nodeOffsetMacroPairs. Returns (Some(macroDefinition, arguments)) if a macro + * definition matches and None otherwise. 
+ */ + private def extractMatchingMacro(node: IASTNode) + : Option[(IASTPreprocessorMacroDefinition, List[String])] = + val expansionLocations = + expandedFromMacro(node).filterNot(isExpandedFrom(node.getParent, _)) + val nodeOffset = node.getFileLocation.getNodeOffset + var matchingMacro = Option.empty[(IASTPreprocessorMacroDefinition, List[String])] + + expansionLocations.foreach { macroLocation => + while matchingMacro.isEmpty && nodeOffsetMacroPairs.headOption.exists( + _._1 <= nodeOffset + ) + do + val (_, macroDefinition) = nodeOffsetMacroPairs.pop() + val macroExpansionName = ASTStringUtil.getSimpleName( + macroLocation.getExpansion.getMacroDefinition.getName + ) + val macroDefinitionName = ASTStringUtil.getSimpleName(macroDefinition.getName) + if macroExpansionName == macroDefinitionName then + matchingMacro = Option((macroDefinition, List[String]())) } - callAst.withChildren(lostLocals).withChild(childAst) - case None => ast - } - } - - /** For the given node, determine if it is expanded from a macro, and if so, find the first matching (offset, macro) - * pair in nodeOffsetMacroPairs, removing non-matching elements from the start of nodeOffsetMacroPairs. Returns - * (Some(macroDefinition, arguments)) if a macro definition matches and None otherwise. 
- */ - private def extractMatchingMacro(node: IASTNode): Option[(IASTPreprocessorMacroDefinition, List[String])] = { - val expansionLocations = expandedFromMacro(node).filterNot(isExpandedFrom(node.getParent, _)) - val nodeOffset = node.getFileLocation.getNodeOffset - var matchingMacro = Option.empty[(IASTPreprocessorMacroDefinition, List[String])] - - expansionLocations.foreach { macroLocation => - while (matchingMacro.isEmpty && nodeOffsetMacroPairs.headOption.exists(_._1 <= nodeOffset)) { - val (_, macroDefinition) = nodeOffsetMacroPairs.pop() - val macroExpansionName = ASTStringUtil.getSimpleName(macroLocation.getExpansion.getMacroDefinition.getName) - val macroDefinitionName = ASTStringUtil.getSimpleName(macroDefinition.getName) - if (macroExpansionName == macroDefinitionName) { - matchingMacro = Option((macroDefinition, List[String]())) + + matchingMacro + end extractMatchingMacro + + /** Determine whether `node` is expanded from the macro expansion at `loc`. + */ + private def isExpandedFrom(node: IASTNode, loc: IASTMacroExpansionLocation): Boolean = + expandedFromMacro(node).map(_.getExpansion.getMacroDefinition).contains( + loc.getExpansion.getMacroDefinition + ) + + private def argumentTrees(arguments: List[String], ast: Ast): List[Option[Ast]] = + arguments.zipWithIndex.map { case (arg, i) => + val rootNode = argForCode(arg, ast) + rootNode.map(x => ast.subTreeCopy(x.asInstanceOf[AstNodeNew], i + 1)) } - } - } - - matchingMacro - } - - /** Determine whether `node` is expanded from the macro expansion at `loc`. 
- */ - private def isExpandedFrom(node: IASTNode, loc: IASTMacroExpansionLocation): Boolean = - expandedFromMacro(node).map(_.getExpansion.getMacroDefinition).contains(loc.getExpansion.getMacroDefinition) - - private def argumentTrees(arguments: List[String], ast: Ast): List[Option[Ast]] = { - arguments.zipWithIndex.map { case (arg, i) => - val rootNode = argForCode(arg, ast) - rootNode.map(x => ast.subTreeCopy(x.asInstanceOf[AstNodeNew], i + 1)) - } - } - - private def argForCode(code: String, ast: Ast): Option[NewNode] = { - val normalizedCode = code.replace(" ", "") - if (normalizedCode == "") { - None - } else { - ast.nodes.collectFirst { - case x: ExpressionNew if !x.isInstanceOf[NewFieldIdentifier] && x.code == normalizedCode => x - } - } - } - - /** Create an AST that represents a macro expansion as a call. The AST is rooted in a CALL node and contains sub trees - * for arguments. These are also connected to the AST via ARGUMENT edges. We include line number information in the - * CALL node that is picked up by the MethodStubCreator. - */ - private def createMacroCallAst( - ast: Ast, - node: IASTNode, - macroDef: IASTPreprocessorMacroDefinition, - arguments: List[String] - ): Ast = { - val name = ASTStringUtil.getSimpleName(macroDef.getName) - val code = node.getRawSignature.stripSuffix(";") - val argAsts = argumentTrees(arguments, ast).map(_.getOrElse(Ast())) - - val callName = StringUtils.normalizeSpace(name) - val callFullName = StringUtils.normalizeSpace(fullName(macroDef, argAsts)) - val callNode = - NewCall() - .name(callName) - .dispatchType(DispatchTypes.INLINED) - .methodFullName(callFullName) - .code(code) - .typeFullName(typeFor(node)) - .lineNumber(line(node)) - .columnNumber(column(node)) - callAst(callNode, argAsts) - } - - /** Create a full name field that encodes line information that can be picked up by the MethodStubCreator in order to - * create a METHOD node with the correct location information. 
- */ - private def fullName(macroDef: IASTPreprocessorMacroDefinition, argAsts: List[Ast]) = { - val name = ASTStringUtil.getSimpleName(macroDef.getName) - val filename = fileName(macroDef) - val lineNo: Integer = line(macroDef).getOrElse(-1) - val lineNoEnd: Integer = lineEnd(macroDef).getOrElse(-1) - s"$filename:$lineNo:$lineNoEnd:$name:${argAsts.size}" - } - - /** The CDT utility method is unfortunately in a class that is marked as deprecated, however, this is because the CDT - * team would like to discourage its use but at the same time does not plan to remove this code. - */ - @nowarn - def nodeSignature(node: IASTNode): String = { - import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature - val sig = if (isExpandedFromMacro(node)) { - val sig = getNodeSignature(node) - if (sig.isEmpty) { - node.getRawSignature - } else { - sig - } - } else { - node.getRawSignature - } - shortenCode(sig) - } - - private def isExpandedFromMacro(node: IASTNode): Boolean = expandedFromMacro(node).nonEmpty - - private def expandedFromMacro(node: IASTNode): Option[IASTMacroExpansionLocation] = { - val locations = node.getNodeLocations.toList - val locationsSorted = node match { - // For binary expressions the expansion locations may occur in any order. - // We manually sort them here to ignore this. - // TODO: This may also happen with other expressions that allow for multiple sub elements. 
- case _: IASTBinaryExpression => locations.sortBy(_.isInstanceOf[IASTMacroExpansionLocation]) - case _ => locations - } - locationsSorted match { - case (head: IASTMacroExpansionLocation) :: _ => Option(head) - case _ => None - } - } -} + private def argForCode(code: String, ast: Ast): Option[NewNode] = + val normalizedCode = code.replace(" ", "") + if normalizedCode == "" then + None + else + ast.nodes.collectFirst { + case x: ExpressionNew + if !x.isInstanceOf[NewFieldIdentifier] && x.code == normalizedCode => x + } + + /** Create an AST that represents a macro expansion as a call. The AST is rooted in a CALL node + * and contains sub trees for arguments. These are also connected to the AST via ARGUMENT + * edges. We include line number information in the CALL node that is picked up by the + * MethodStubCreator. + */ + private def createMacroCallAst( + ast: Ast, + node: IASTNode, + macroDef: IASTPreprocessorMacroDefinition, + arguments: List[String] + ): Ast = + val name = ASTStringUtil.getSimpleName(macroDef.getName) + val code = node.getRawSignature.stripSuffix(";") + val argAsts = argumentTrees(arguments, ast).map(_.getOrElse(Ast())) + + val callName = StringUtils.normalizeSpace(name) + val callFullName = StringUtils.normalizeSpace(fullName(macroDef, argAsts)) + val callNode = + NewCall() + .name(callName) + .dispatchType(DispatchTypes.INLINED) + .methodFullName(callFullName) + .code(code) + .typeFullName(typeFor(node)) + .lineNumber(line(node)) + .columnNumber(column(node)) + callAst(callNode, argAsts) + end createMacroCallAst + + /** Create a full name field that encodes line information that can be picked up by the + * MethodStubCreator in order to create a METHOD node with the correct location information. 
+ */ + private def fullName(macroDef: IASTPreprocessorMacroDefinition, argAsts: List[Ast]) = + val name = ASTStringUtil.getSimpleName(macroDef.getName) + val filename = fileName(macroDef) + val lineNo: Integer = line(macroDef).getOrElse(-1) + val lineNoEnd: Integer = lineEnd(macroDef).getOrElse(-1) + s"$filename:$lineNo:$lineNoEnd:$name:${argAsts.size}" + + /** The CDT utility method is unfortunately in a class that is marked as deprecated, however, + * this is because the CDT team would like to discourage its use but at the same time does not + * plan to remove this code. + */ + @nowarn + def nodeSignature(node: IASTNode): String = + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + val sig = if isExpandedFromMacro(node) then + val sig = getNodeSignature(node) + if sig.isEmpty then + node.getRawSignature + else + sig + else + node.getRawSignature + shortenCode(sig) + + private def isExpandedFromMacro(node: IASTNode): Boolean = expandedFromMacro(node).nonEmpty + + private def expandedFromMacro(node: IASTNode): Option[IASTMacroExpansionLocation] = + val locations = node.getNodeLocations.toList + val locationsSorted = node match + // For binary expressions the expansion locations may occur in any order. + // We manually sort them here to ignore this. + // TODO: This may also happen with other expressions that allow for multiple sub elements. 
+ case _: IASTBinaryExpression => + locations.sortBy(_.isInstanceOf[IASTMacroExpansionLocation]) + case _ => locations + locationsSorted match + case (head: IASTMacroExpansionLocation) :: _ => Option(head) + case _ => None +end MacroHandler diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala index a3e00391..b28e4b82 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala @@ -3,14 +3,11 @@ package io.appthreat.c2cpg.datastructures import io.appthreat.c2cpg.astcreation.Defines import io.appthreat.x2cpg.datastructures.Global -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -object CGlobal extends Global { +object CGlobal extends Global: - def typesSeen(): List[String] = { - val types = usedTypes.keys().asScala.filterNot(_ == Defines.anyTypeName).toList - usedTypes.clear() - types - } - -} + def typesSeen(): List[String] = + val types = usedTypes.keys().asScala.filterNot(_ == Defines.anyTypeName).toList + usedTypes.clear() + types diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala index 5fe3f86f..1d5b5d0d 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala @@ -17,140 +17,141 @@ import org.slf4j.LoggerFactory import java.nio.file.{NoSuchFileException, Path} import scala.jdk.CollectionConverters.* -object CdtParser { - - private val logger = LoggerFactory.getLogger(classOf[CdtParser]) - - private case class ParseResult( - translationUnit: Option[IASTTranslationUnit], - preprocessorErrorCount: Int = 0, - 
problems: Int = 0, - failure: Option[Throwable] = None - ) - - def readFileAsFileContent(path: Path): FileContent = { - val lines = IOUtils.readLinesInFile(path).mkString("\n").toArray - FileContent.create(path.toString, true, lines) - } - -} - -class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorStatementsLogger { - - import CdtParser._ - - private val headerFileFinder = new HeaderFileFinder(config.inputPath) - private val parserConfig = ParserConfig.fromConfig(config) - private val definedSymbols = parserConfig.definedSymbols.asJava - private val includePaths = parserConfig.userIncludePaths - private val log = new DefaultLogService - - private var stayCpp: Boolean = false; - - private val cScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( - definedSymbols, - (includePaths ++ parserConfig.systemIncludePathsC).map(_.toString).toArray, - parserConfig.macroFiles.map(_.toString).toArray, - parserConfig.includeFiles.map(_.toString).toArray - ) - - private val cppScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( - definedSymbols, - (includePaths ++ parserConfig.systemIncludePathsCPP).map(_.toString).toArray, - parserConfig.macroFiles.map(_.toString).toArray, - parserConfig.includeFiles.map(_.toString).toArray - ) - - // Setup indexing - var index: Option[IIndex] = Option(EmptyCIndex.INSTANCE) - if (config.useProjectIndex) { - try { - val allProjects: Array[ICProject] = CoreModel.getDefault.getCModel.getCProjects - index = Option(CCorePlugin.getIndexManager.getIndex(allProjects)) - } catch { - case e: Throwable => - } - } - - // enables parsing of code behind disabled preprocessor defines: - private var opts: Int = ILanguage.OPTION_PARSE_INACTIVE_CODE - // instructs the parser to skip function and method bodies - if (!config.includeFunctionBodies) opts |= ILanguage.OPTION_SKIP_FUNCTION_BODIES - // performance optimization, allows the parser not to create image-locations - if (!config.includeImageLocations) opts |= 
ILanguage.OPTION_NO_IMAGE_LOCATIONS - - private def createParseLanguage(file: Path): ILanguage = { - if (FileDefaults.isCPPFile(file.toString)) { - GPPLanguage.getDefault - } else { - GCCLanguage.getDefault - } - } - - private def createScannerInfo(file: Path): ExtendedScannerInfo = - if (stayCpp || FileDefaults.isCPPFile(file.toString)) { - stayCpp = true - cppScannerInfo - } else cScannerInfo - - private def parseInternal(file: Path): ParseResult = { - val realPath = File(file) - if (realPath.isRegularFile) { // handling potentially broken symlinks - try { - val fileContent = readFileAsFileContent(realPath.path) - val fileContentProvider = new CustomFileContentProvider(headerFileFinder) - val lang = createParseLanguage(realPath.path) - val scannerInfo = createScannerInfo(realPath.path) - index match { - case Some(x) => if (x.isFullyInitialized) x.acquireReadLock() - case _ => - } - val translationUnit = - lang.getASTTranslationUnit(fileContent, scannerInfo, fileContentProvider, index.get, opts, log) - val problems = CPPVisitor.getProblems(translationUnit) - if (parserConfig.logProblems) logProblems(problems.toList) - if (parserConfig.logPreprocessor) logPreprocessorStatements(translationUnit) - ParseResult( - Option(translationUnit), - preprocessorErrorCount = translationUnit.getPreprocessorProblemsCount, - problems = problems.length - ) - } catch { - case u: UnsupportedClassVersionError => - logger.debug("c2cpg requires at least JRE-17 to run. Please check your Java Runtime Environment!", u) - System.exit(1) - ParseResult(None, failure = Option(u)) // return value to make the compiler happy - case e: Throwable => - ParseResult(None, failure = Option(e)) - } finally { - index match { - case Some(x) => x.releaseReadLock() - case _ => - } - } - } else { - ParseResult( - None, - failure = Option(new NoSuchFileException(s"File '$realPath' does not exist. 
Check for broken symlinks!")) - ) - } - } - - def preprocessorStatements(file: Path): Iterable[IASTPreprocessorStatement] = { - parse(file).map(t => preprocessorStatements(t)).getOrElse(Iterable.empty) - } - - def parse(file: Path): Option[IASTTranslationUnit] = { - val parseResult = parseInternal(file) - parseResult match { - case ParseResult(Some(t), c, p, _) => - Option(t) - case ParseResult(_, _, _, maybeThrowable) => - logger.warn( - s"Failed to parse '$file': ${maybeThrowable.map(extractParseException).getOrElse("Unknown parse error!")}" - ) - None - } - } - -} +object CdtParser: + + private val logger = LoggerFactory.getLogger(classOf[CdtParser]) + + private case class ParseResult( + translationUnit: Option[IASTTranslationUnit], + preprocessorErrorCount: Int = 0, + problems: Int = 0, + failure: Option[Throwable] = None + ) + + def readFileAsFileContent(path: Path): FileContent = + val lines = IOUtils.readLinesInFile(path).mkString("\n").toArray + FileContent.create(path.toString, true, lines) + +class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorStatementsLogger: + + import CdtParser.* + + private val headerFileFinder = new HeaderFileFinder(config.inputPath) + private val parserConfig = ParserConfig.fromConfig(config) + private val definedSymbols = parserConfig.definedSymbols.asJava + private val includePaths = parserConfig.userIncludePaths + private val log = new DefaultLogService + + private var stayCpp: Boolean = false; + + private val cScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( + definedSymbols, + (includePaths ++ parserConfig.systemIncludePathsC).map(_.toString).toArray, + parserConfig.macroFiles.map(_.toString).toArray, + parserConfig.includeFiles.map(_.toString).toArray + ) + + private val cppScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( + definedSymbols, + (includePaths ++ parserConfig.systemIncludePathsCPP).map(_.toString).toArray, + parserConfig.macroFiles.map(_.toString).toArray, + 
parserConfig.includeFiles.map(_.toString).toArray + ) + + // Setup indexing + var index: Option[IIndex] = Option(EmptyCIndex.INSTANCE) + if config.useProjectIndex then + try + val allProjects: Array[ICProject] = CoreModel.getDefault.getCModel.getCProjects + index = Option(CCorePlugin.getIndexManager.getIndex(allProjects)) + catch + case e: Throwable => + + // enables parsing of code behind disabled preprocessor defines: + private var opts: Int = ILanguage.OPTION_PARSE_INACTIVE_CODE + // instructs the parser to skip function and method bodies + if !config.includeFunctionBodies then opts |= ILanguage.OPTION_SKIP_FUNCTION_BODIES + // performance optimization, allows the parser not to create image-locations + if !config.includeImageLocations then opts |= ILanguage.OPTION_NO_IMAGE_LOCATIONS + + private def createParseLanguage(file: Path): ILanguage = + if FileDefaults.isCPPFile(file.toString) then + GPPLanguage.getDefault + else + GCCLanguage.getDefault + + private def createScannerInfo(file: Path): ExtendedScannerInfo = + if stayCpp || FileDefaults.isCPPFile(file.toString) then + stayCpp = true + cppScannerInfo + else cScannerInfo + + private def parseInternal(file: Path): ParseResult = + val realPath = File(file) + if realPath.isRegularFile then // handling potentially broken symlinks + try + val fileContent = readFileAsFileContent(realPath.path) + val fileContentProvider = new CustomFileContentProvider(headerFileFinder) + val lang = createParseLanguage(realPath.path) + val scannerInfo = createScannerInfo(realPath.path) + index match + case Some(x) => if x.isFullyInitialized then x.acquireReadLock() + case _ => + val translationUnit = + lang.getASTTranslationUnit( + fileContent, + scannerInfo, + fileContentProvider, + index.get, + opts, + log + ) + val problems = CPPVisitor.getProblems(translationUnit) + if parserConfig.logProblems then logProblems(problems.toList) + if parserConfig.logPreprocessor then logPreprocessorStatements(translationUnit) + ParseResult( + 
Option(translationUnit), + preprocessorErrorCount = translationUnit.getPreprocessorProblemsCount, + problems = problems.length + ) + catch + case u: UnsupportedClassVersionError => + logger.debug( + "c2cpg requires at least JRE-17 to run. Please check your Java Runtime Environment!", + u + ) + System.exit(1) + ParseResult( + None, + failure = Option(u) + ) // return value to make the compiler happy + case e: Throwable => + ParseResult(None, failure = Option(e)) + finally + index match + case Some(x) => x.releaseReadLock() + case _ => + else + ParseResult( + None, + failure = Option(new NoSuchFileException( + s"File '$realPath' does not exist. Check for broken symlinks!" + )) + ) + end if + end parseInternal + + def preprocessorStatements(file: Path): Iterable[IASTPreprocessorStatement] = + parse(file).map(t => preprocessorStatements(t)).getOrElse(Iterable.empty) + + def parse(file: Path): Option[IASTTranslationUnit] = + val parseResult = parseInternal(file) + parseResult match + case ParseResult(Some(t), c, p, _) => + Option(t) + case ParseResult(_, _, _, maybeThrowable) => + logger.warn( + s"Failed to parse '$file': ${maybeThrowable.map(extractParseException).getOrElse("Unknown parse error!")}" + ) + None +end CdtParser diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala index 4a7967e2..db3bcb18 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala @@ -2,37 +2,45 @@ package io.appthreat.c2cpg.parser import org.eclipse.cdt.core.index.IIndexFileLocation import org.eclipse.cdt.internal.core.parser.IMacroDictionary -import org.eclipse.cdt.internal.core.parser.scanner.{InternalFileContent, InternalFileContentProvider} +import 
org.eclipse.cdt.internal.core.parser.scanner.{ + InternalFileContent, + InternalFileContentProvider +} import org.slf4j.LoggerFactory import java.nio.file.Paths -class CustomFileContentProvider(headerFileFinder: HeaderFileFinder) extends InternalFileContentProvider { - - private val logger = LoggerFactory.getLogger(classOf[CustomFileContentProvider]) - - private def loadContent(path: String): InternalFileContent = { - val maybeFullPath = if (!getInclusionExists(path)) { - headerFileFinder.find(path) - } else { - Option(path) - } - maybeFullPath - .map { foundPath => - logger.debug(s"Loading header file '$foundPath'") - CdtParser.readFileAsFileContent(Paths.get(foundPath)).asInstanceOf[InternalFileContent] - } - .getOrElse { - logger.debug(s"Cannot find header file for '$path'") - null - } - - } - - override def getContentForInclusion(path: String, macroDictionary: IMacroDictionary): InternalFileContent = - loadContent(path) - - override def getContentForInclusion(ifl: IIndexFileLocation, astPath: String): InternalFileContent = - loadContent(astPath) - -} +class CustomFileContentProvider(headerFileFinder: HeaderFileFinder) + extends InternalFileContentProvider: + + private val logger = LoggerFactory.getLogger(classOf[CustomFileContentProvider]) + + private def loadContent(path: String): InternalFileContent = + val maybeFullPath = if !getInclusionExists(path) then + headerFileFinder.find(path) + else + Option(path) + maybeFullPath + .map { foundPath => + logger.debug(s"Loading header file '$foundPath'") + CdtParser.readFileAsFileContent(Paths.get(foundPath)).asInstanceOf[ + InternalFileContent + ] + } + .getOrElse { + logger.debug(s"Cannot find header file for '$path'") + null + } + + override def getContentForInclusion( + path: String, + macroDictionary: IMacroDictionary + ): InternalFileContent = + loadContent(path) + + override def getContentForInclusion( + ifl: IIndexFileLocation, + astPath: String + ): InternalFileContent = + loadContent(astPath) +end 
CustomFileContentProvider diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala index 593900fa..c484a5ea 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala @@ -1,35 +1,35 @@ package io.appthreat.c2cpg.parser -object DefaultDefines { - val DEFAULT_CALL_CONVENTIONS: Map[String, String] = Map( - "__fastcall" -> "__attribute((fastcall))", - "__cdecl" -> "__attribute((cdecl))", - "__pascal" -> "__attribute((pascal))", - "__vectorcall" -> "__attribute((vectorcall))", - "__clrcall" -> "__attribute((clrcall))", - "__stdcall" -> "__attribute((stdcall))", - "__thiscall" -> "__attribute((thiscall))", - "__declspec" -> "__attribute((declspec))", - "__restrict" -> "__attribute((restrict))", - "__sptr" -> "__attribute((sptr))", - "__uptr" -> "__attribute((uptr))", - "__syscall" -> "__attribute((syscall))", - "__oldcall" -> "__attribute((oldcall))", - "__unaligned" -> "__attribute((unaligned))", - "__w64" -> "__attribute((w64))", - "__asm" -> "__attribute((asm))", - "__based" -> "__attribute((based))", - "__interface" -> "__attribute((interface))", - "__event" -> "__attribute((event))", - "__hook" -> "__attribute((hook))", - "__unhook" -> "__attribute((unhook))", - "__raise" -> "__attribute((raise))", - "__try" -> "__attribute((try))", - "__except" -> "__attribute((except))", - "__finally" -> "__attribute((finally))", - "__m128" -> "__attribute((m128))", - "__m128d" -> "__attribute((m128d))", - "__m128i" -> "__attribute((m128i))", - "__m64" -> "__attribute((m64))" - ) -} +object DefaultDefines: + val DEFAULT_CALL_CONVENTIONS: Map[String, String] = Map( + "__fastcall" -> "__attribute((fastcall))", + "__cdecl" -> "__attribute((cdecl))", + "__pascal" -> "__attribute((pascal))", + "__vectorcall" -> 
"__attribute((vectorcall))", + "__clrcall" -> "__attribute((clrcall))", + "__stdcall" -> "__attribute((stdcall))", + "__thiscall" -> "__attribute((thiscall))", + "__declspec" -> "__attribute((declspec))", + "__restrict" -> "__attribute((restrict))", + "__sptr" -> "__attribute((sptr))", + "__uptr" -> "__attribute((uptr))", + "__syscall" -> "__attribute((syscall))", + "__oldcall" -> "__attribute((oldcall))", + "__unaligned" -> "__attribute((unaligned))", + "__w64" -> "__attribute((w64))", + "__asm" -> "__attribute((asm))", + "__based" -> "__attribute((based))", + "__interface" -> "__attribute((interface))", + "__event" -> "__attribute((event))", + "__hook" -> "__attribute((hook))", + "__unhook" -> "__attribute((unhook))", + "__raise" -> "__attribute((raise))", + "__try" -> "__attribute((try))", + "__except" -> "__attribute((except))", + "__finally" -> "__attribute((finally))", + "__m128" -> "__attribute((m128))", + "__m128d" -> "__attribute((m128d))", + "__m128i" -> "__attribute((m128i))", + "__m64" -> "__attribute((m64))" + ) +end DefaultDefines diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala index 85f09786..c45ff053 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala @@ -1,24 +1,24 @@ package io.appthreat.c2cpg.parser -object FileDefaults { +object FileDefaults: - val C_EXT: String = ".c" - val CPP_EXT: String = ".cpp" + val C_EXT: String = ".c" + val CPP_EXT: String = ".cpp" - private val CC_EXT = ".cc" - private val C_HEADER_EXT = ".h" - private val CPP_HEADER_EXT = ".hpp" - private val OTHER_HEADER_EXT = ".hh" + private val CC_EXT = ".cc" + private val C_HEADER_EXT = ".h" + private val CPP_HEADER_EXT = ".hpp" + private val OTHER_HEADER_EXT = ".hh" - val SOURCE_FILE_EXTENSIONS: Set[String] = 
Set(C_EXT, CC_EXT, CPP_EXT) + val SOURCE_FILE_EXTENSIONS: Set[String] = Set(C_EXT, CC_EXT, CPP_EXT) - val HEADER_FILE_EXTENSIONS: Set[String] = Set(C_HEADER_EXT, CPP_HEADER_EXT, OTHER_HEADER_EXT) + val HEADER_FILE_EXTENSIONS: Set[String] = Set(C_HEADER_EXT, CPP_HEADER_EXT, OTHER_HEADER_EXT) - private val CPP_FILE_EXTENSIONS = Set(CC_EXT, CPP_EXT, CPP_HEADER_EXT, ".ccm", ".cxxm", ".c++m") + private val CPP_FILE_EXTENSIONS = Set(CC_EXT, CPP_EXT, CPP_HEADER_EXT, ".ccm", ".cxxm", ".c++m") - def isHeaderFile(filePath: String): Boolean = - HEADER_FILE_EXTENSIONS.exists(filePath.endsWith) + def isHeaderFile(filePath: String): Boolean = + HEADER_FILE_EXTENSIONS.exists(filePath.endsWith) - def isCPPFile(filePath: String): Boolean = - CPP_FILE_EXTENSIONS.exists(filePath.endsWith) -} + def isCPPFile(filePath: String): Boolean = + CPP_FILE_EXTENSIONS.exists(filePath.endsWith) +end FileDefaults diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala index b5859758..5572d3aa 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala @@ -1,28 +1,26 @@ package io.appthreat.c2cpg.parser -import better.files._ +import better.files.* import io.appthreat.x2cpg.SourceFiles import org.jline.utils.Levenshtein import java.nio.file.Path -class HeaderFileFinder(root: String) { +class HeaderFileFinder(root: String): - private val nameToPathMap: Map[String, List[Path]] = SourceFiles - .determine(root, FileDefaults.HEADER_FILE_EXTENSIONS) - .map { p => - val file = File(p) - (file.name, file.path) - } - .groupBy(_._1) - .map(x => (x._1, x._2.map(_._2))) - - /** Given an unresolved header file, given as a non-existing absolute path, determine whether a header file with the - * same name can be found anywhere in the code 
base. - */ - def find(path: String): Option[String] = File(path).nameOption.flatMap { name => - val matches = nameToPathMap.getOrElse(name, List()) - matches.map(_.toString).sortBy(x => Levenshtein.distance(x, path)).headOption - } + private val nameToPathMap: Map[String, List[Path]] = SourceFiles + .determine(root, FileDefaults.HEADER_FILE_EXTENSIONS) + .map { p => + val file = File(p) + (file.name, file.path) + } + .groupBy(_._1) + .map(x => (x._1, x._2.map(_._2))) -} + /** Given an unresolved header file, given as a non-existing absolute path, determine whether a + * header file with the same name can be found anywhere in the code base. + */ + def find(path: String): Option[String] = File(path).nameOption.flatMap { name => + val matches = nameToPathMap.getOrElse(name, List()) + matches.map(_.toString).sortBy(x => Levenshtein.distance(x, path)).headOption + } diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala index 248d2558..b8b07d42 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala @@ -3,37 +3,35 @@ package io.appthreat.c2cpg.parser import org.eclipse.cdt.core.dom.ast.IASTProblem import org.slf4j.LoggerFactory -trait ParseProblemsLogger { +trait ParseProblemsLogger: - this: CdtParser => + this: CdtParser => - private val logger = LoggerFactory.getLogger(classOf[ParseProblemsLogger]) + private val logger = LoggerFactory.getLogger(classOf[ParseProblemsLogger]) - private def logProblemNode(node: IASTProblem): Unit = { - val text = s"""Parse problem '${node.getClass.getSimpleName}' occurred! + private def logProblemNode(node: IASTProblem): Unit = + val text = s"""Parse problem '${node.getClass.getSimpleName}' occurred! 
| Code: '${node.getRawSignature}' | File: '${node.getFileLocation.getFileName}' | Line: ${node.getFileLocation.getStartingLineNumber} | """.stripMargin - logger.debug(text) - } - - protected def logProblems(problems: List[IASTProblem]): Unit = { - problems.foreach(logProblemNode) - } - - /** The exception message might be null for parse failures due to ambiguous nodes that can't be resolved successfully. - * We extract better log messages for this case here. - */ - protected def extractParseException(exception: Throwable): String = Option(exception.getMessage) match { - case Some(message) => message - case None => - exception.getStackTrace - .collectFirst { - case stackTraceElement: StackTraceElement if stackTraceElement.getClassName.endsWith("ASTAmbiguousNode") => - "Could not resolve ambiguous node!" - } - .getOrElse(exception.getStackTrace.mkString(System.lineSeparator())) - } - -} + logger.debug(text) + + protected def logProblems(problems: List[IASTProblem]): Unit = + problems.foreach(logProblemNode) + + /** The exception message might be null for parse failures due to ambiguous nodes that can't be + * resolved successfully. We extract better log messages for this case here. + */ + protected def extractParseException(exception: Throwable): String = + Option(exception.getMessage) match + case Some(message) => message + case None => + exception.getStackTrace + .collectFirst { + case stackTraceElement: StackTraceElement + if stackTraceElement.getClassName.endsWith("ASTAmbiguousNode") => + "Could not resolve ambiguous node!" 
+ } + .getOrElse(exception.getStackTrace.mkString(System.lineSeparator())) +end ParseProblemsLogger diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala index f25ce89e..732a38bf 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala @@ -5,37 +5,36 @@ import io.appthreat.c2cpg.utils.IncludeAutoDiscovery import java.nio.file.{Path, Paths} -object ParserConfig { +object ParserConfig: - def empty: ParserConfig = - ParserConfig( - Set.empty, - Set.empty, - Set.empty, - Set.empty, - Map.empty, - Set.empty, - logProblems = false, - logPreprocessor = false - ) - - def fromConfig(config: Config): ParserConfig = ParserConfig( - config.includeFiles.map(Paths.get(_).toAbsolutePath), - config.includePaths.map(Paths.get(_).toAbsolutePath), - IncludeAutoDiscovery.discoverIncludePathsC(config), - IncludeAutoDiscovery.discoverIncludePathsCPP(config), - config.defines.map { - case define if define.contains("=") => - val s = define.split("=") - s.head -> s(1) - case define => define -> "true" - }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, - config.macroFiles.map(Paths.get(_).toAbsolutePath), - config.logProblems, - config.logPreprocessor - ) + def empty: ParserConfig = + ParserConfig( + Set.empty, + Set.empty, + Set.empty, + Set.empty, + Map.empty, + Set.empty, + logProblems = false, + logPreprocessor = false + ) -} + def fromConfig(config: Config): ParserConfig = ParserConfig( + config.includeFiles.map(Paths.get(_).toAbsolutePath), + config.includePaths.map(Paths.get(_).toAbsolutePath), + IncludeAutoDiscovery.discoverIncludePathsC(config), + IncludeAutoDiscovery.discoverIncludePathsCPP(config), + config.defines.map { + case define if define.contains("=") => + val s = define.split("=") + s.head -> s(1) + case define 
=> define -> "true" + }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, + config.macroFiles.map(Paths.get(_).toAbsolutePath), + config.logProblems, + config.logPreprocessor + ) +end ParserConfig case class ParserConfig( includeFiles: Set[Path], diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala index b9d96060..2acc1bc6 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala @@ -1,46 +1,43 @@ package io.appthreat.c2cpg.parser import org.eclipse.cdt.core.dom.ast.{ - IASTPreprocessorFunctionStyleMacroDefinition, - IASTPreprocessorIfStatement, - IASTPreprocessorIfdefStatement, - IASTPreprocessorStatement, - IASTTranslationUnit + IASTPreprocessorFunctionStyleMacroDefinition, + IASTPreprocessorIfStatement, + IASTPreprocessorIfdefStatement, + IASTPreprocessorStatement, + IASTTranslationUnit } import org.slf4j.LoggerFactory -trait PreprocessorStatementsLogger { +trait PreprocessorStatementsLogger: - this: CdtParser => + this: CdtParser => - private val logger = LoggerFactory.getLogger(classOf[PreprocessorStatementsLogger]) + private val logger = LoggerFactory.getLogger(classOf[PreprocessorStatementsLogger]) - private def logPreprocessorStatement(node: IASTPreprocessorStatement): Unit = { - val text = s"""Preprocessor statement '${node.getClass.getSimpleName}' found! + private def logPreprocessorStatement(node: IASTPreprocessorStatement): Unit = + val text = s"""Preprocessor statement '${node.getClass.getSimpleName}' found! 
| Code: '${node.getRawSignature}' | File: '${node.getFileLocation.getFileName}' | Line: ${node.getFileLocation.getStartingLineNumber}""".stripMargin - val additionalInfo = node match { - case s: IASTPreprocessorFunctionStyleMacroDefinition => - s""" + val additionalInfo = node match + case s: IASTPreprocessorFunctionStyleMacroDefinition => + s""" | Parameter: ${s.getParameters.map(_.getRawSignature).mkString(", ")} | Expansion: ${s.getExpansion}""".stripMargin - case s: IASTPreprocessorIfStatement => - s""" + case s: IASTPreprocessorIfStatement => + s""" | Defined: ${s.taken()}""".stripMargin - case s: IASTPreprocessorIfdefStatement => - s""" + case s: IASTPreprocessorIfdefStatement => + s""" | Defined: ${s.taken()}""".stripMargin - case _ => "" - } - logger.debug(s"$text$additionalInfo") - } + case _ => "" + logger.debug(s"$text$additionalInfo") - protected def preprocessorStatements(translationUnit: IASTTranslationUnit): Iterable[IASTPreprocessorStatement] = - translationUnit.getAllPreprocessorStatements + protected def preprocessorStatements(translationUnit: IASTTranslationUnit) + : Iterable[IASTPreprocessorStatement] = + translationUnit.getAllPreprocessorStatements - protected def logPreprocessorStatements(translationUnit: IASTTranslationUnit): Unit = { - preprocessorStatements(translationUnit).foreach(logPreprocessorStatement) - } - -} + protected def logPreprocessorStatements(translationUnit: IASTTranslationUnit): Unit = + preprocessorStatements(translationUnit).foreach(logPreprocessorStatement) +end PreprocessorStatementsLogger diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala index 0dd31da1..c97dc17d 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala @@ -14,43 +14,41 @@ import 
java.util.regex.Pattern import scala.util.matching.Regex class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) - extends ConcurrentWriterCpgPass[String](cpg) { - - private val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] = new ConcurrentHashMap() - private val parser: CdtParser = new CdtParser(config) - - private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) - private val DefaultIgnoredFolders: List[Regex] = List( - "\\..*".r, - s"(.*[$EscapedFileSeparator])?tests?[$EscapedFileSeparator].*".r, - s"(.*[$EscapedFileSeparator])?CMakeFiles[$EscapedFileSeparator].*".r - ) - - override def generateParts(): Array[String] = - SourceFiles - .determine( - config.inputPath, - FileDefaults.SOURCE_FILE_EXTENSIONS ++ FileDefaults.HEADER_FILE_EXTENSIONS, - config.withDefaultIgnoredFilesRegex(DefaultIgnoredFolders) - ) - .sortWith(_.compareToIgnoreCase(_) > 0) - .toArray - - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = { - val path = Paths.get(filename).toAbsolutePath - val relPath = SourceFiles.toRelativePath(path.toString, config.inputPath) - try { - val parseResult = parser.parse(path) - parseResult match { - case Some(translationUnit) => - val localDiff = - new AstCreator(relPath, config, translationUnit, file2OffsetTable)(config.schemaValidation).createAst() - diffGraph.absorb(localDiff) - case None => - } - } catch { - case e: Throwable => - } - } - -} + extends ConcurrentWriterCpgPass[String](cpg): + + private val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] = new ConcurrentHashMap() + private val parser: CdtParser = new CdtParser(config) + + private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) + private val DefaultIgnoredFolders: List[Regex] = List( + "\\..*".r, + s"(.*[$EscapedFileSeparator])?tests?[$EscapedFileSeparator].*".r, + s"(.*[$EscapedFileSeparator])?CMakeFiles[$EscapedFileSeparator].*".r + ) + + override def generateParts(): Array[String] = + 
SourceFiles + .determine( + config.inputPath, + FileDefaults.SOURCE_FILE_EXTENSIONS ++ FileDefaults.HEADER_FILE_EXTENSIONS, + config.withDefaultIgnoredFilesRegex(DefaultIgnoredFolders) + ) + .sortWith(_.compareToIgnoreCase(_) > 0) + .toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = + val path = Paths.get(filename).toAbsolutePath + val relPath = SourceFiles.toRelativePath(path.toString, config.inputPath) + try + val parseResult = parser.parse(path) + parseResult match + case Some(translationUnit) => + val localDiff = + new AstCreator(relPath, config, translationUnit, file2OffsetTable)( + config.schemaValidation + ).createAst() + diffGraph.absorb(localDiff) + case None => + catch + case e: Throwable => +end AstCreationPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala index c2c9d232..f8bfdef3 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala @@ -4,25 +4,24 @@ import better.files.File import io.appthreat.x2cpg.passes.frontend.XConfigFileCreationPass import io.shiftleft.codepropertygraph.Cpg -class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg) { +class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // TOML files - extensionFilter(".toml"), - // INI files - extensionFilter(".ini"), - // YAML files - extensionFilter(".yaml"), - extensionFilter(".lock"), - pathEndFilter("conanfile.txt"), - extensionFilter(".cmake"), - extensionFilter(".build"), - pathEndFilter("CMakeLists.txt"), - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json"), - pathEndFilter("setup.cfg"), - 
pathEndFilter("setup.py") - ) - -} + override val configFileFilters: List[File => Boolean] = List( + // TOML files + extensionFilter(".toml"), + // INI files + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".lock"), + pathEndFilter("conanfile.txt"), + extensionFilter(".cmake"), + extensionFilter(".build"), + pathEndFilter("CMakeLists.txt"), + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json"), + pathEndFilter("setup.cfg"), + pathEndFilter("setup.py") + ) +end ConfigFileCreationPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala index 11c57f69..e66ef80b 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala @@ -4,31 +4,34 @@ import io.appthreat.c2cpg.Config import io.appthreat.c2cpg.parser.{CdtParser, FileDefaults} import io.appthreat.x2cpg.SourceFiles import org.eclipse.cdt.core.dom.ast.{ - IASTPreprocessorIfStatement, - IASTPreprocessorIfdefStatement, - IASTPreprocessorStatement + IASTPreprocessorIfStatement, + IASTPreprocessorIfdefStatement, + IASTPreprocessorStatement } import java.nio.file.Paths import scala.collection.parallel.CollectionConverters.ImmutableIterableIsParallelizable import scala.collection.parallel.immutable.ParIterable -class PreprocessorPass(config: Config) { +class PreprocessorPass(config: Config): - private val parser = new CdtParser(config) + private val parser = new CdtParser(config) - def run(): ParIterable[String] = - SourceFiles.determine(config.inputPath, FileDefaults.SOURCE_FILE_EXTENSIONS).par.flatMap(runOnPart) + def run(): ParIterable[String] = + SourceFiles.determine(config.inputPath, FileDefaults.SOURCE_FILE_EXTENSIONS).par.flatMap( + runOnPart + ) - private def 
preprocessorStatement2String(stmt: IASTPreprocessorStatement): Option[String] = stmt match { - case s: IASTPreprocessorIfStatement => - Option(s"${s.getCondition.mkString}${if (s.taken()) "=true" else ""}") - case s: IASTPreprocessorIfdefStatement => - Option(s"${s.getCondition.mkString}${if (s.taken()) "=true" else ""}") - case _ => None - } + private def preprocessorStatement2String(stmt: IASTPreprocessorStatement): Option[String] = + stmt match + case s: IASTPreprocessorIfStatement => + Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") + case s: IASTPreprocessorIfdefStatement => + Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") + case _ => None - private def runOnPart(filename: String): Iterable[String] = - parser.preprocessorStatements(Paths.get(filename)).flatMap(preprocessorStatement2String).toSet - -} + private def runOnPart(filename: String): Iterable[String] = + parser.preprocessorStatements(Paths.get(filename)).flatMap( + preprocessorStatement2String + ).toSet +end PreprocessorPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala index f5308f27..6a50d8f7 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala @@ -11,55 +11,56 @@ import io.appthreat.x2cpg.{Ast, ValidationMode} import io.appthreat.x2cpg.utils.NodeBuilders.newMethodReturnNode import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -class TypeDeclNodePass(cpg: Cpg)(implicit withSchemaValidation: ValidationMode) extends CpgPass(cpg) { +class TypeDeclNodePass(cpg: Cpg)(implicit withSchemaValidation: ValidationMode) + extends CpgPass(cpg): - private val filename: String = "" - private val globalName: String = NamespaceTraversal.globalNamespaceName 
- private val fullName: String = MetaDataPass.getGlobalNamespaceBlockFullName(Option(filename)) + private val filename: String = "" + private val globalName: String = NamespaceTraversal.globalNamespaceName + private val fullName: String = MetaDataPass.getGlobalNamespaceBlockFullName(Option(filename)) - private val typeDeclFullNames: Set[String] = cpg.typeDecl.fullName.toSetImmutable + private val typeDeclFullNames: Set[String] = cpg.typeDecl.fullName.toSetImmutable - private def createGlobalAst(): Ast = { - val includesFile = NewFile().name(filename) - val namespaceBlock = NewNamespaceBlock() - .name(globalName) - .fullName(fullName) - .filename(filename) - val fakeGlobalIncludesMethod = - NewMethod() - .name(globalName) - .code(globalName) - .fullName(fullName) - .filename(filename) - .lineNumber(1) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(fullName) - val blockNode = NewBlock().typeFullName(Defines.anyTypeName) - val methodReturn = newMethodReturnNode(Defines.anyTypeName, line = None, column = None) - Ast(includesFile).withChild( - Ast(namespaceBlock) - .withChild(Ast(fakeGlobalIncludesMethod).withChild(Ast(blockNode)).withChild(Ast(methodReturn))) - ) - } + private def createGlobalAst(): Ast = + val includesFile = NewFile().name(filename) + val namespaceBlock = NewNamespaceBlock() + .name(globalName) + .fullName(fullName) + .filename(filename) + val fakeGlobalIncludesMethod = + NewMethod() + .name(globalName) + .code(globalName) + .fullName(fullName) + .filename(filename) + .lineNumber(1) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(fullName) + val blockNode = NewBlock().typeFullName(Defines.anyTypeName) + val methodReturn = newMethodReturnNode(Defines.anyTypeName, line = None, column = None) + Ast(includesFile).withChild( + Ast(namespaceBlock) + .withChild( + Ast(fakeGlobalIncludesMethod).withChild(Ast(blockNode)).withChild(Ast(methodReturn)) + ) + ) + end createGlobalAst - private def typeNeedsTypeDeclStub(t: Type): 
Boolean = - !typeDeclFullNames.contains(t.typeDeclFullName) + private def typeNeedsTypeDeclStub(t: Type): Boolean = + !typeDeclFullNames.contains(t.typeDeclFullName) - override def run(dstGraph: DiffGraphBuilder): Unit = { - var hadMissingTypeDecl = false - cpg.typ.filter(typeNeedsTypeDeclStub).foreach { t => - val newTypeDecl = NewTypeDecl() - .name(t.name) - .fullName(t.typeDeclFullName) - .code(t.name) - .isExternal(true) - .filename(filename) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(fullName) - dstGraph.addNode(newTypeDecl) - hadMissingTypeDecl = true - } - if (hadMissingTypeDecl) Ast.storeInDiffGraph(createGlobalAst(), dstGraph) - } - -} + override def run(dstGraph: DiffGraphBuilder): Unit = + var hadMissingTypeDecl = false + cpg.typ.filter(typeNeedsTypeDeclStub).foreach { t => + val newTypeDecl = NewTypeDecl() + .name(t.name) + .fullName(t.typeDeclFullName) + .code(t.name) + .isExternal(true) + .filename(filename) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(fullName) + dstGraph.addNode(newTypeDecl) + hadMissingTypeDecl = true + } + if hadMissingTypeDecl then Ast.storeInDiffGraph(createGlobalAst(), dstGraph) +end TypeDeclNodePass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala index 53bfbc55..4fffc5fa 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala @@ -4,27 +4,25 @@ import scala.collection.mutable import scala.sys.process.{Process, ProcessLogger} import scala.util.{Failure, Success, Try} -object ExternalCommand { +object ExternalCommand: - private val IS_WIN: Boolean = - scala.util.Properties.isWin + private val IS_WIN: Boolean = + scala.util.Properties.isWin - private val shellPrefix: Seq[String] = - if (IS_WIN) "cmd" :: "/c" :: Nil 
else "sh" :: "-c" :: Nil + private val shellPrefix: Seq[String] = + if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil - def run(command: String): Try[Seq[String]] = { - val result = mutable.ArrayBuffer.empty[String] - val lineHandler: String => Unit = result.addOne - Process(shellPrefix :+ command).!(ProcessLogger(lineHandler, lineHandler)) match { - case 0 => - Success(result.toSeq) - case 1 - if IS_WIN && - command != IncludeAutoDiscovery.GCC_VERSION_COMMAND && - IncludeAutoDiscovery.gccAvailable() => - Success(result.toSeq) - case _ => - Failure(new RuntimeException(result.mkString(System.lineSeparator()))) - } - } -} + def run(command: String): Try[Seq[String]] = + val result = mutable.ArrayBuffer.empty[String] + val lineHandler: String => Unit = result.addOne + Process(shellPrefix :+ command).!(ProcessLogger(lineHandler, lineHandler)) match + case 0 => + Success(result.toSeq) + case 1 + if IS_WIN && + command != IncludeAutoDiscovery.GCC_VERSION_COMMAND && + IncludeAutoDiscovery.gccAvailable() => + Success(result.toSeq) + case _ => + Failure(new RuntimeException(result.mkString(System.lineSeparator()))) +end ExternalCommand diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala index 5a80b63a..4856d6c0 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala @@ -7,93 +7,92 @@ import java.nio.file.{Path, Paths} import scala.util.Failure import scala.util.Success -object IncludeAutoDiscovery { - - private val logger = LoggerFactory.getLogger(IncludeAutoDiscovery.getClass) - - private val IS_WIN = scala.util.Properties.isWin - - val GCC_VERSION_COMMAND = "gcc --version" - - private val CPP_INCLUDE_COMMAND = - if (IS_WIN) "gcc -xc++ -E -v . 
-o nul" else "gcc -xc++ -E -v /dev/null -o /dev/null" - - private val C_INCLUDE_COMMAND = - if (IS_WIN) "gcc -xc -E -v . -o nul" else "gcc -xc -E -v /dev/null -o /dev/null" - - // Only check once - private var isGccAvailable: Option[Boolean] = None - - // Only discover them once - private var systemIncludePathsC: Set[Path] = Set.empty - private var systemIncludePathsCPP: Set[Path] = Set.empty - - private def checkForGcc(): Boolean = { - logger.debug("Checking gcc ...") - ExternalCommand.run(GCC_VERSION_COMMAND) match { - case Success(result) => - logger.debug(s"GCC is available: ${result.mkString(System.lineSeparator())}") - true - case _ => - logger.warn("GCC is not installed. Discovery of system include paths will not be available.") - false - } - } - - def gccAvailable(): Boolean = isGccAvailable match { - case Some(value) => - value - case None => - isGccAvailable = Option(checkForGcc()) - isGccAvailable.get - } - - private def extractPaths(output: Seq[String]): Set[Path] = { - val startIndex = - output.indexWhere(_.contains("#include")) + 2 - val endIndex = - if (IS_WIN) output.indexWhere(_.startsWith("End of search list.")) - 1 - else output.indexWhere(_.startsWith("COMPILER_PATH")) - 1 - output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath()).toSet - } - - private def discoverPaths(command: String): Set[Path] = ExternalCommand.run(command) match { - case Success(output) => extractPaths(output) - case Failure(exception) => - logger.warn(s"Unable to discover system include paths. 
Running '$command' failed.", exception) - Set.empty - } - - def discoverIncludePathsC(config: Config): Set[Path] = { - if (config.includePathsAutoDiscovery && systemIncludePathsC.nonEmpty) { - systemIncludePathsC - } else if (config.includePathsAutoDiscovery && systemIncludePathsC.isEmpty && gccAvailable()) { - val includePathsC = discoverPaths(C_INCLUDE_COMMAND) - if (includePathsC.nonEmpty) { - logger.debug(s"Using the following C system include paths:${includePathsC - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}") - } - systemIncludePathsC = includePathsC - includePathsC - } else { - Set.empty - } - } - - def discoverIncludePathsCPP(config: Config): Set[Path] = { - if (config.includePathsAutoDiscovery && systemIncludePathsCPP.nonEmpty) { - systemIncludePathsCPP - } else if (config.includePathsAutoDiscovery && systemIncludePathsCPP.isEmpty && gccAvailable()) { - val includePathsCPP = discoverPaths(CPP_INCLUDE_COMMAND) - if (includePathsCPP.nonEmpty) { - logger.debug(s"Using the following CPP system include paths:${includePathsCPP - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}") - } - systemIncludePathsCPP = includePathsCPP - includePathsCPP - } else { - Set.empty - } - } - -} +object IncludeAutoDiscovery: + + private val logger = LoggerFactory.getLogger(IncludeAutoDiscovery.getClass) + + private val IS_WIN = scala.util.Properties.isWin + + val GCC_VERSION_COMMAND = "gcc --version" + + private val CPP_INCLUDE_COMMAND = + if IS_WIN then "gcc -xc++ -E -v . -o nul" else "gcc -xc++ -E -v /dev/null -o /dev/null" + + private val C_INCLUDE_COMMAND = + if IS_WIN then "gcc -xc -E -v . 
-o nul" else "gcc -xc -E -v /dev/null -o /dev/null" + + // Only check once + private var isGccAvailable: Option[Boolean] = None + + // Only discover them once + private var systemIncludePathsC: Set[Path] = Set.empty + private var systemIncludePathsCPP: Set[Path] = Set.empty + + private def checkForGcc(): Boolean = + logger.debug("Checking gcc ...") + ExternalCommand.run(GCC_VERSION_COMMAND) match + case Success(result) => + logger.debug(s"GCC is available: ${result.mkString(System.lineSeparator())}") + true + case _ => + logger.warn( + "GCC is not installed. Discovery of system include paths will not be available." + ) + false + + def gccAvailable(): Boolean = isGccAvailable match + case Some(value) => + value + case None => + isGccAvailable = Option(checkForGcc()) + isGccAvailable.get + + private def extractPaths(output: Seq[String]): Set[Path] = + val startIndex = + output.indexWhere(_.contains("#include")) + 2 + val endIndex = + if IS_WIN then output.indexWhere(_.startsWith("End of search list.")) - 1 + else output.indexWhere(_.startsWith("COMPILER_PATH")) - 1 + output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath()).toSet + + private def discoverPaths(command: String): Set[Path] = ExternalCommand.run(command) match + case Success(output) => extractPaths(output) + case Failure(exception) => + logger.warn( + s"Unable to discover system include paths. 
Running '$command' failed.", + exception + ) + Set.empty + + def discoverIncludePathsC(config: Config): Set[Path] = + if config.includePathsAutoDiscovery && systemIncludePathsC.nonEmpty then + systemIncludePathsC + else if config.includePathsAutoDiscovery && systemIncludePathsC.isEmpty && gccAvailable() + then + val includePathsC = discoverPaths(C_INCLUDE_COMMAND) + if includePathsC.nonEmpty then + logger.debug( + s"Using the following C system include paths:${includePathsC + .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" + ) + systemIncludePathsC = includePathsC + includePathsC + else + Set.empty + + def discoverIncludePathsCPP(config: Config): Set[Path] = + if config.includePathsAutoDiscovery && systemIncludePathsCPP.nonEmpty then + systemIncludePathsCPP + else if config.includePathsAutoDiscovery && systemIncludePathsCPP.isEmpty && gccAvailable() + then + val includePathsCPP = discoverPaths(CPP_INCLUDE_COMMAND) + if includePathsCPP.nonEmpty then + logger.debug( + s"Using the following CPP system include paths:${includePathsCPP + .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" + ) + systemIncludePathsCPP = includePathsCPP + includePathsCPP + else + Set.empty +end IncludeAutoDiscovery diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala index 0523c63f..d628c6cc 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala @@ -4,89 +4,82 @@ import org.slf4j.LoggerFactory import scala.collection.concurrent.TrieMap -object Report { - - private val logger = LoggerFactory.getLogger(Report.getClass) - - private type FileName = String - - private type Reports = TrieMap[FileName, ReportEntry] - - private case class ReportEntry(loc: Int, 
parsed: Boolean, cpgGen: Boolean, duration: Long) { - def toSeq: Seq[String] = { - val lines = loc.toString - val dur = if (duration == 0) "-" else TimeUtils.pretty(duration) - val wasParsed = if (parsed) "yes" else "no" - val gotCpg = if (cpgGen) "yes" else "no" - Seq(lines, wasParsed, gotCpg, dur) - } - } - -} - -class Report { - - import Report._ - - private val reports: Reports = TrieMap.empty - - private def formatTable(table: Seq[Seq[String]]): String = { - if (table.isEmpty) "" - else { - // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) - val colWidths = - table.transpose.map(_.map(cell => if (cell == null) 0 else cell.length).max + 2) - // Format each row - val rows = table.map( - _.zip(colWidths) - .map { case (item, size) => s" %-${size - 1}s".format(item) } - .mkString("|", "|", "|") - ) - // Formatted separator row, used to separate the header and draw table borders - val separator = colWidths.map("-" * _).mkString("+", "+", "+") - // Put the table together and return - val header = rows.head - val content = rows.tail.take(rows.tail.size - 1) - val footer = rows.tail.last - (separator +: header +: separator +: content :+ separator :+ footer :+ separator) - .mkString("\n") - } - } - - def print(): Unit = { - val rows = reports.toSeq - .sortBy(_._1) - .zipWithIndex - .view - .map { case ((file, sum), index) => - s"${index + 1}" +: file +: sum.toSeq - } - .toSeq - val numOfReports = reports.size - val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) - val footer = Seq( - Seq( - "Total", - "", - s"${reports.map(_._2.loc).sum}", - s"${reports.count(_._2.parsed)}/$numOfReports", - s"${reports.count(_._2.cpgGen)}/$numOfReports", - "" - ) - ) - val table = header ++ rows ++ footer - logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") - } - - def addReportInfo( - fileName: FileName, - loc: Int, - parsed: Boolean = false, - cpgGen: Boolean = false, - 
duration: Long = 0 - ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) - - def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = - reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) - -} +object Report: + + private val logger = LoggerFactory.getLogger(Report.getClass) + + private type FileName = String + + private type Reports = TrieMap[FileName, ReportEntry] + + private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): + def toSeq: Seq[String] = + val lines = loc.toString + val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) + val wasParsed = if parsed then "yes" else "no" + val gotCpg = if cpgGen then "yes" else "no" + Seq(lines, wasParsed, gotCpg, dur) + +class Report: + + import Report.* + + private val reports: Reports = TrieMap.empty + + private def formatTable(table: Seq[Seq[String]]): String = + if table.isEmpty then "" + else + // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) + val colWidths = + table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) + // Format each row + val rows = table.map( + _.zip(colWidths) + .map { case (item, size) => s" %-${size - 1}s".format(item) } + .mkString("|", "|", "|") + ) + // Formatted separator row, used to separate the header and draw table borders + val separator = colWidths.map("-" * _).mkString("+", "+", "+") + // Put the table together and return + val header = rows.head + val content = rows.tail.take(rows.tail.size - 1) + val footer = rows.tail.last + (separator +: header +: separator +: content :+ separator :+ footer :+ separator) + .mkString("\n") + + def print(): Unit = + val rows = reports.toSeq + .sortBy(_._1) + .zipWithIndex + .view + .map { case ((file, sum), index) => + s"${index + 1}" +: file +: sum.toSeq + } + .toSeq + val numOfReports = reports.size + val header = Seq(Seq("#", "File", "LOC", 
"Parsed", "Got a CPG", "Duration")) + val footer = Seq( + Seq( + "Total", + "", + s"${reports.map(_._2.loc).sum}", + s"${reports.count(_._2.parsed)}/$numOfReports", + s"${reports.count(_._2.cpgGen)}/$numOfReports", + "" + ) + ) + val table = header ++ rows ++ footer + logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") + end print + + def addReportInfo( + fileName: FileName, + loc: Int, + parsed: Boolean = false, + cpgGen: Boolean = false, + duration: Long = 0 + ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) + + def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = + reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) +end Report diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala index 1d2492e1..86422f26 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala @@ -1,57 +1,52 @@ package io.appthreat.c2cpg.utils import java.util.Locale -import scala.concurrent.duration._ - -object TimeUtils { - - /** Measures elapsed time for executing a block in nanoseconds */ - def time[R](block: => R): (R, Long) = { - val t0 = System.nanoTime() - val result = block - val t1 = System.nanoTime() - val elapsed = t1 - t0 - (result, elapsed) - } - - /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ - def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) - - private def pretty(duration: Duration): String = - duration match { - case d: FiniteDuration => - val nanos = d.toNanos - val unit = chooseUnit(nanos) - val value = nanos.toDouble / NANOSECONDS.convert(1, unit) - - s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) - - case Duration.MinusInf => s"-∞ (minus infinity)" - case 
Duration.Inf => s"∞ (infinity)" - case _ => "undefined" - } - - private def chooseUnit(nanos: Long): TimeUnit = { - val d = nanos.nanos - - if (d.toDays > 0) DAYS - else if (d.toHours > 0) HOURS - else if (d.toMinutes > 0) MINUTES - else if (d.toSeconds > 0) SECONDS - else if (d.toMillis > 0) MILLISECONDS - else if (d.toMicros > 0) MICROSECONDS - else NANOSECONDS - } - - private def abbreviate(unit: TimeUnit): String = - unit match { - case NANOSECONDS => "ns" - case MICROSECONDS => "μs" - case MILLISECONDS => "ms" - case SECONDS => "s" - case MINUTES => "min" - case HOURS => "h" - case DAYS => "d" - } - -} +import scala.concurrent.duration.* + +object TimeUtils: + + /** Measures elapsed time for executing a block in nanoseconds */ + def time[R](block: => R): (R, Long) = + val t0 = System.nanoTime() + val result = block + val t1 = System.nanoTime() + val elapsed = t1 - t0 + (result, elapsed) + + /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ + def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) + + private def pretty(duration: Duration): String = + duration match + case d: FiniteDuration => + val nanos = d.toNanos + val unit = chooseUnit(nanos) + val value = nanos.toDouble / NANOSECONDS.convert(1, unit) + + s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) + + case Duration.MinusInf => s"-∞ (minus infinity)" + case Duration.Inf => s"∞ (infinity)" + case _ => "undefined" + + private def chooseUnit(nanos: Long): TimeUnit = + val d = nanos.nanos + + if d.toDays > 0 then DAYS + else if d.toHours > 0 then HOURS + else if d.toMinutes > 0 then MINUTES + else if d.toSeconds > 0 then SECONDS + else if d.toMillis > 0 then MILLISECONDS + else if d.toMicros > 0 then MICROSECONDS + else NANOSECONDS + + private def abbreviate(unit: TimeUnit): String = + unit match + case NANOSECONDS => "ns" + case MICROSECONDS => "μs" + case MILLISECONDS => "ms" + case SECONDS => "s" + case MINUTES => "min" + case HOURS => 
"h" + case DAYS => "d" +end TimeUtils diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala index f9edb55f..655bd2a3 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala @@ -2,11 +2,11 @@ package io.appthreat.javasrc2cpg import better.files.File import io.appthreat.javasrc2cpg.passes.{ - AstCreationPass, - ConfigFileCreationPass, - JavaTypeHintCallLinker, - JavaTypeRecoveryPass, - TypeInferencePass + AstCreationPass, + ConfigFileCreationPass, + JavaTypeHintCallLinker, + JavaTypeRecoveryPass, + TypeInferencePass } import io.appthreat.x2cpg.X2Cpg.withNewEmptyCpg import io.appthreat.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, XTypeRecoveryConfig} @@ -16,67 +16,66 @@ import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.passes.CpgPassBase import org.slf4j.LoggerFactory -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Try import scala.util.matching.Regex -class JavaSrc2Cpg extends X2CpgFrontend[Config] { - import JavaSrc2Cpg._ +class JavaSrc2Cpg extends X2CpgFrontend[Config]: + import JavaSrc2Cpg.* - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - override def createCpg(config: Config): Try[Cpg] = { - withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => - new MetaDataPass(cpg, language, config.inputPath).createAndApply() - val astCreationPass = new AstCreationPass(config, cpg) - astCreationPass.createAndApply() - new ConfigFileCreationPass(cpg).createAndApply() - if (!config.skipTypeInfPass) { - TypeNodePass.withRegisteredTypes(astCreationPass.global.usedTypes.keys().asScala.toList, cpg).createAndApply() - new 
TypeInferencePass(cpg).createAndApply() - } - } - } - -} + override def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + new MetaDataPass(cpg, language, config.inputPath).createAndApply() + val astCreationPass = new AstCreationPass(config, cpg) + astCreationPass.createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + if !config.skipTypeInfPass then + TypeNodePass.withRegisteredTypes( + astCreationPass.global.usedTypes.keys().asScala.toList, + cpg + ).createAndApply() + new TypeInferencePass(cpg).createAndApply() + } -object JavaSrc2Cpg { - val language: String = Languages.JAVASRC - private val logger = LoggerFactory.getLogger(this.getClass) +object JavaSrc2Cpg: + val language: String = Languages.JAVASRC + private val logger = LoggerFactory.getLogger(this.getClass) - val sourceFileExtensions: Set[String] = Set(".java") + val sourceFileExtensions: Set[String] = Set(".java") - val DefaultIgnoredFilesRegex: List[Regex] = List(".git", ".mvn", "test").flatMap { directory => - List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) - } + val DefaultIgnoredFilesRegex: List[Regex] = List(".git", ".mvn", "test").flatMap { directory => + List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) + } - val DefaultConfig: Config = - Config().withDefaultIgnoredFilesRegex(DefaultIgnoredFilesRegex) + val DefaultConfig: Config = + Config().withDefaultIgnoredFilesRegex(DefaultIgnoredFilesRegex) - def apply(): JavaSrc2Cpg = new JavaSrc2Cpg() + def apply(): JavaSrc2Cpg = new JavaSrc2Cpg() - def typeRecoveryPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = { - List( - new JavaTypeRecoveryPass(cpg, XTypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes))), - new JavaTypeHintCallLinker(cpg) - ) - } + def typeRecoveryPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = + List( + new 
JavaTypeRecoveryPass( + cpg, + XTypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes)) + ), + new JavaTypeHintCallLinker(cpg) + ) - def showEnv(): Unit = { - val value = - JavaSrcEnvVar.values.foreach { envVar => - val currentValue = Option(System.getenv(envVar.name)).getOrElse("") - println(s"${envVar.name}:") - println(s" Description : ${envVar.description}") - println(s" Current value: $currentValue") - } - } + def showEnv(): Unit = + val value = + JavaSrcEnvVar.values.foreach { envVar => + val currentValue = Option(System.getenv(envVar.name)).getOrElse("") + println(s"${envVar.name}:") + println(s" Description : ${envVar.description}") + println(s" Current value: $currentValue") + } - enum JavaSrcEnvVar(val name: String, val description: String) { - case JdkPath - extends JavaSrcEnvVar( - "JAVASRC_JDK_PATH", - "Path to the JDK home used for retrieving type information about builtin Java types." - ) - } -} + enum JavaSrcEnvVar(val name: String, val description: String): + case JdkPath + extends JavaSrcEnvVar( + "JAVASRC_JDK_PATH", + "Path to the JDK home used for retrieving type information about builtin Java types." 
+ ) +end JavaSrc2Cpg diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala index 0b402019..6a2c8ed5 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala @@ -22,121 +22,113 @@ final case class Config( skipTypeInfPass: Boolean = false, dumpJavaparserAsts: Boolean = false ) extends X2CpgConfig[Config] - with TypeRecoveryParserConfig[Config] { - def withInferenceJarPaths(paths: Set[String]): Config = { - copy(inferenceJarPaths = paths).withInheritedFields(this) - } - - def withFetchDependencies(value: Boolean): Config = { - copy(fetchDependencies = value).withInheritedFields(this) - } - - def withJavaFeatureSetVersion(version: String): Config = { - copy(javaFeatureSetVersion = Some(version)).withInheritedFields(this) - } - - def withDelombokJavaHome(path: String): Config = { - copy(delombokJavaHome = Some(path)).withInheritedFields(this) - } - - def withDelombokMode(mode: String): Config = { - copy(delombokMode = Some(mode)).withInheritedFields(this) - } - - def withEnableTypeRecovery(value: Boolean): Config = { - copy(enableTypeRecovery = value).withInheritedFields(this) - } - - def withJdkPath(path: String): Config = { - copy(jdkPath = Some(path)).withInheritedFields(this) - } - - def withShowEnv(value: Boolean): Config = { - copy(showEnv = value).withInheritedFields(this) - } - - def withSkipTypeInfPass(value: Boolean): Config = { - copy(skipTypeInfPass = value).withInheritedFields(this) - } - - def withDumpJavaparserAsts(value: Boolean): Config = { - copy(dumpJavaparserAsts = value).withInheritedFields(this) - } -} - -private object Frontend { - - implicit val defaultConfig: Config = JavaSrc2Cpg.DefaultConfig - - val cmdLineParser: OParser[Unit, Config] = { - val builder = OParser.builder[Config] - import builder._ 
- OParser.sequence( - programName("javasrc2cpg"), - opt[String]("inference-jar-paths") - .text(s"extra jars used only for type information") - .action((path, c) => c.withInferenceJarPaths(c.inferenceJarPaths + path)), - opt[Unit]("fetch-dependencies") - .text("attempt to fetch dependencies jars for extra type information") - .action((_, c) => c.withFetchDependencies(true)), - opt[String]("delombok-java-home") - .text("Optional override to set java home used to run Delombok. Java 17 is recommended for the best results.") - .action((path, c) => c.withDelombokJavaHome(path)), - opt[String]("delombok-mode") - .text( - """Specifies how delombok should be executed. Options are + with TypeRecoveryParserConfig[Config]: + def withInferenceJarPaths(paths: Set[String]): Config = + copy(inferenceJarPaths = paths).withInheritedFields(this) + + def withFetchDependencies(value: Boolean): Config = + copy(fetchDependencies = value).withInheritedFields(this) + + def withJavaFeatureSetVersion(version: String): Config = + copy(javaFeatureSetVersion = Some(version)).withInheritedFields(this) + + def withDelombokJavaHome(path: String): Config = + copy(delombokJavaHome = Some(path)).withInheritedFields(this) + + def withDelombokMode(mode: String): Config = + copy(delombokMode = Some(mode)).withInheritedFields(this) + + def withEnableTypeRecovery(value: Boolean): Config = + copy(enableTypeRecovery = value).withInheritedFields(this) + + def withJdkPath(path: String): Config = + copy(jdkPath = Some(path)).withInheritedFields(this) + + def withShowEnv(value: Boolean): Config = + copy(showEnv = value).withInheritedFields(this) + + def withSkipTypeInfPass(value: Boolean): Config = + copy(skipTypeInfPass = value).withInheritedFields(this) + + def withDumpJavaparserAsts(value: Boolean): Config = + copy(dumpJavaparserAsts = value).withInheritedFields(this) +end Config + +private object Frontend: + + implicit val defaultConfig: Config = JavaSrc2Cpg.DefaultConfig + + val cmdLineParser: 
OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("javasrc2cpg"), + opt[String]("inference-jar-paths") + .text(s"extra jars used only for type information") + .action((path, c) => c.withInferenceJarPaths(c.inferenceJarPaths + path)), + opt[Unit]("fetch-dependencies") + .text("attempt to fetch dependencies jars for extra type information") + .action((_, c) => c.withFetchDependencies(true)), + opt[String]("delombok-java-home") + .text( + "Optional override to set java home used to run Delombok. Java 17 is recommended for the best results." + ) + .action((path, c) => c.withDelombokJavaHome(path)), + opt[String]("delombok-mode") + .text( + """Specifies how delombok should be executed. Options are | no-delombok => do not use delombok for analysis or type information. | default => run delombok if a lombok dependency is found and analyse delomboked code. | types-only => run delombok, but use it for type information only | run-delombok => run delombok and use delomboked source for both analysis and type information.""".stripMargin + ) + .action((mode, c) => c.withDelombokMode(mode)), + opt[Unit]("enable-type-recovery") + .hidden() + .action((_, c) => c.withEnableTypeRecovery(true)) + .text("enable generic type recovery"), + XTypeRecovery.parserOptions, + opt[String]("jdk-path") + .action((path, c) => c.withJdkPath(path)) + .text( + "JDK used for resolving builtin Java types. If not set, current classpath will be used" + ), + opt[Unit]("show-env") + .action((_, c) => c.withShowEnv(true)) + .text("print information about environment variables used by javasrc2cpg and exit."), + opt[Unit]("skip-type-inf-pass") + .hidden() + .action((_, c) => c.withSkipTypeInfPass(true)) + .text( + "Skip the type inference pass. 
Results will be much worse, so should only be used for development purposes" + ), + opt[Unit]("dump-javaparser-asts") + .hidden() + .action((_, c) => c.withDumpJavaparserAsts(true)) + .text( + "Dump the javaparser asts for the given input files and terminate (for debugging)." + ) ) - .action((mode, c) => c.withDelombokMode(mode)), - opt[Unit]("enable-type-recovery") - .hidden() - .action((_, c) => c.withEnableTypeRecovery(true)) - .text("enable generic type recovery"), - XTypeRecovery.parserOptions, - opt[String]("jdk-path") - .action((path, c) => c.withJdkPath(path)) - .text("JDK used for resolving builtin Java types. If not set, current classpath will be used"), - opt[Unit]("show-env") - .action((_, c) => c.withShowEnv(true)) - .text("print information about environment variables used by javasrc2cpg and exit."), - opt[Unit]("skip-type-inf-pass") - .hidden() - .action((_, c) => c.withSkipTypeInfPass(true)) - .text( - "Skip the type inference pass. Results will be much worse, so should only be used for development purposes" - ), - opt[Unit]("dump-javaparser-asts") - .hidden() - .action((_, c) => c.withDumpJavaparserAsts(true)) - .text("Dump the javaparser asts for the given input files and terminate (for debugging).") - ) - } -} - -object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()) { - - override def main(args: Array[String]): Unit = { - // TODO: This is a hack to allow users to use the "--show-env" option without having - // to specify an input argument. Clean this up when adding this option to more frontends. 
- if (args.contains("--show-env")) { - super.main(Array("--show-env", "")) - } else { - super.main(args) - } - } - - def run(config: Config, javasrc2Cpg: JavaSrc2Cpg): Unit = { - if (config.showEnv) { - JavaSrc2Cpg.showEnv() - } else if (config.dumpJavaparserAsts) { - JavaParserAstPrinter.printJpAsts(config) - } else { - javasrc2Cpg.run(config) - } - } - - def getCmdLineParser: OParser[Unit, Config] = cmdLineParser -} + end cmdLineParser +end Frontend + +object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()): + + override def main(args: Array[String]): Unit = + // TODO: This is a hack to allow users to use the "--show-env" option without having + // to specify an input argument. Clean this up when adding this option to more frontends. + if args.contains("--show-env") then + super.main(Array("--show-env", "")) + else + super.main(args) + + def run(config: Config, javasrc2Cpg: JavaSrc2Cpg): Unit = + if config.showEnv then + JavaSrc2Cpg.showEnv() + else if config.dumpJavaparserAsts then + JavaParserAstPrinter.printJpAsts(config) + else + javasrc2Cpg.run(config) + + def getCmdLineParser: OParser[Unit, Config] = cmdLineParser +end Main diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala index 7294ab4d..7b7cf7d8 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala @@ -2,20 +2,20 @@ package io.appthreat.javasrc2cpg.jartypereader import io.appthreat.javasrc2cpg.jartypereader.descriptorparser.DescriptorParser import io.appthreat.javasrc2cpg.jartypereader.model.{ - ClassSignature, - ClassTypeSignature, - NameWithTypeArgs, - ResolvedMethod, - ResolvedTypeDecl, - ResolvedVariableType + ClassSignature, + ClassTypeSignature, + 
NameWithTypeArgs, + ResolvedMethod, + ResolvedTypeDecl, + ResolvedVariableType } import io.appthreat.javasrc2cpg.jartypereader.model.{ - ClassSignature, - ClassTypeSignature, - NameWithTypeArgs, - ResolvedMethod, - ResolvedTypeDecl, - ResolvedVariableType + ClassSignature, + ClassTypeSignature, + NameWithTypeArgs, + ResolvedMethod, + ResolvedTypeDecl, + ResolvedVariableType } import io.shiftleft.utils.ProjectRoot import javassist.{ClassPool, CtClass, CtField, CtMethod, NotFoundException} @@ -25,105 +25,105 @@ import java.util.jar.JarFile import scala.jdk.CollectionConverters.EnumerationHasAsScala import scala.util.{Failure, Success, Try} -object JarTypeReader { - - private val logger = LoggerFactory.getLogger(this.getClass) - - val ObjectTypeSignature: ClassTypeSignature = { - val packageSpecifier = Some("java.lang") - val signature = NameWithTypeArgs(name = "Object", typeArguments = Nil) - ClassTypeSignature(packageSpecifier, signature, suffix = Nil) - } - - def getTypes(jarPath: String): List[ResolvedTypeDecl] = { - val cp = new ClassPool() - cp.insertClassPath(jarPath) - - val jarFile = new JarFile(jarPath) - jarFile - .entries() - .asScala - .filter(_.getName.endsWith(".class")) - .map { entry => - entry.getRealName.replace("/", ".").dropRight(6) // Drop .class extension - } - .flatMap(getTypeDeclForEntry(cp, _)) - .toList - - } - - def getTypeEntryForMethod(method: CtMethod, parentDecl: ResolvedTypeDecl): ResolvedMethod = { - val name = method.getName - val signatureDescriptor = Option(method.getGenericSignature).getOrElse(method.getSignature) - val signature = DescriptorParser.parseMethodSignature(signatureDescriptor) - val isAbstract = method.isEmpty - - model.ResolvedMethod(name, parentDecl, signature, isAbstract) - } - - /** This should only be used for classes without a generic signature. Generic type information will be missing - * otherwise. 
- */ - private def classTypeSignatureFromString(signature: String): ClassTypeSignature = { - signature.split('.').toList match { - case Nil => - logger.debug(s"$signature is not a valid class signature") - ClassTypeSignature(None, NameWithTypeArgs("", Nil), Nil) - - case name :: Nil => ClassTypeSignature(None, typedName = NameWithTypeArgs(name, Nil), Nil) - - case nameWithPkg => - val packagePrefix = nameWithPkg.init.mkString(".") - val name = nameWithPkg.last - ClassTypeSignature(Some(packagePrefix), NameWithTypeArgs(name, Nil), Nil) - - } - } - - /** This should only be used for classes without a generic signature. Generic type information will be missing - * otherwise. - */ - private def getCtClassSignature(ctClass: CtClass): ClassSignature = { - val classFile = ctClass.getClassFile2 - val typeParameters = Nil - val superclassSignature = - Option.unless(classFile.isInterface)(classTypeSignatureFromString(classFile.getSuperclass)) - val interfacesSignatures = classFile.getInterfaces.map(classTypeSignatureFromString).toList - - ClassSignature(typeParameters, superclassSignature, interfacesSignatures) - } - - private def getResolvedField(ctField: CtField): ResolvedVariableType = { - val name = ctField.getName - val signatureDescriptor = Option(ctField.getGenericSignature).getOrElse(ctField.getSignature) - val signature = DescriptorParser.parseFieldSignature(signatureDescriptor) - - model.ResolvedVariableType(name, signature) - } - - def getTypeDeclForEntry(cp: ClassPool, name: String): Option[ResolvedTypeDecl] = { - Try(cp.get(name)) match { - case Success(ctClass) => - val name = ctClass.getSimpleName - val packageSpecifier = ctClass.getPackageName - val signature = Option(ctClass.getGenericSignature) - .map(DescriptorParser.parseClassSignature) - .getOrElse(getCtClassSignature(ctClass)) - val isInterface = ctClass.isInterface - val isAbstract = !isInterface && ctClass.getClassFile2.isAbstract - val fields = ctClass.getFields.map(getResolvedField).toList - val 
typeDecl = model.ResolvedTypeDecl(name, Some(packageSpecifier), signature, isInterface, isAbstract, fields) - val methods = ctClass.getMethods.map(getTypeEntryForMethod(_, typeDecl)).toList - typeDecl.addMethods(methods) - - Some(typeDecl) - - case Failure(_: NotFoundException) => - // TODO: Expected for interfaces, but fix this - None - - case Failure(_) => - None - } - } -} +object JarTypeReader: + + private val logger = LoggerFactory.getLogger(this.getClass) + + val ObjectTypeSignature: ClassTypeSignature = + val packageSpecifier = Some("java.lang") + val signature = NameWithTypeArgs(name = "Object", typeArguments = Nil) + ClassTypeSignature(packageSpecifier, signature, suffix = Nil) + + def getTypes(jarPath: String): List[ResolvedTypeDecl] = + val cp = new ClassPool() + cp.insertClassPath(jarPath) + + val jarFile = new JarFile(jarPath) + jarFile + .entries() + .asScala + .filter(_.getName.endsWith(".class")) + .map { entry => + entry.getRealName.replace("/", ".").dropRight(6) // Drop .class extension + } + .flatMap(getTypeDeclForEntry(cp, _)) + .toList + + def getTypeEntryForMethod(method: CtMethod, parentDecl: ResolvedTypeDecl): ResolvedMethod = + val name = method.getName + val signatureDescriptor = Option(method.getGenericSignature).getOrElse(method.getSignature) + val signature = DescriptorParser.parseMethodSignature(signatureDescriptor) + val isAbstract = method.isEmpty + + model.ResolvedMethod(name, parentDecl, signature, isAbstract) + + /** This should only be used for classes without a generic signature. Generic type information + * will be missing otherwise. 
+ */ + private def classTypeSignatureFromString(signature: String): ClassTypeSignature = + signature.split('.').toList match + case Nil => + logger.debug(s"$signature is not a valid class signature") + ClassTypeSignature(None, NameWithTypeArgs("", Nil), Nil) + + case name :: Nil => + ClassTypeSignature(None, typedName = NameWithTypeArgs(name, Nil), Nil) + + case nameWithPkg => + val packagePrefix = nameWithPkg.init.mkString(".") + val name = nameWithPkg.last + ClassTypeSignature(Some(packagePrefix), NameWithTypeArgs(name, Nil), Nil) + + /** This should only be used for classes without a generic signature. Generic type information + * will be missing otherwise. + */ + private def getCtClassSignature(ctClass: CtClass): ClassSignature = + val classFile = ctClass.getClassFile2 + val typeParameters = Nil + val superclassSignature = + Option.unless(classFile.isInterface)( + classTypeSignatureFromString(classFile.getSuperclass) + ) + val interfacesSignatures = classFile.getInterfaces.map(classTypeSignatureFromString).toList + + ClassSignature(typeParameters, superclassSignature, interfacesSignatures) + + private def getResolvedField(ctField: CtField): ResolvedVariableType = + val name = ctField.getName + val signatureDescriptor = + Option(ctField.getGenericSignature).getOrElse(ctField.getSignature) + val signature = DescriptorParser.parseFieldSignature(signatureDescriptor) + + model.ResolvedVariableType(name, signature) + + def getTypeDeclForEntry(cp: ClassPool, name: String): Option[ResolvedTypeDecl] = + Try(cp.get(name)) match + case Success(ctClass) => + val name = ctClass.getSimpleName + val packageSpecifier = ctClass.getPackageName + val signature = Option(ctClass.getGenericSignature) + .map(DescriptorParser.parseClassSignature) + .getOrElse(getCtClassSignature(ctClass)) + val isInterface = ctClass.isInterface + val isAbstract = !isInterface && ctClass.getClassFile2.isAbstract + val fields = ctClass.getFields.map(getResolvedField).toList + val typeDecl = 
model.ResolvedTypeDecl( + name, + Some(packageSpecifier), + signature, + isInterface, + isAbstract, + fields + ) + val methods = ctClass.getMethods.map(getTypeEntryForMethod(_, typeDecl)).toList + typeDecl.addMethods(methods) + + Some(typeDecl) + + case Failure(_: NotFoundException) => + // TODO: Expected for interfaces, but fix this + None + + case Failure(_) => + None +end JarTypeReader diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala index 885259f3..dcc341e9 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala @@ -1,38 +1,41 @@ package io.appthreat.javasrc2cpg.jartypereader.descriptorparser -import io.appthreat.javasrc2cpg.jartypereader.model.{ClassSignature, MethodSignature, ReferenceTypeSignature} +import io.appthreat.javasrc2cpg.jartypereader.model.{ + ClassSignature, + MethodSignature, + ReferenceTypeSignature +} import org.slf4j.LoggerFactory import scala.util.Try import scala.util.parsing.combinator.{Parsers, RegexParsers} -object DescriptorParser extends TypeParser { +object DescriptorParser extends TypeParser: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - def parseMethodSignature(descriptor: String): MethodSignature = { - parseSignature(methodSignature, descriptor) - } + def parseMethodSignature(descriptor: String): MethodSignature = + parseSignature(methodSignature, descriptor) - def parseClassSignature(descriptor: String): ClassSignature = { - parseSignature(classSignature, descriptor) - } + def parseClassSignature(descriptor: String): ClassSignature 
= + parseSignature(classSignature, descriptor) - def parseFieldSignature(descriptor: String): ReferenceTypeSignature = { - parseSignature(fieldSignature, descriptor) - } + def parseFieldSignature(descriptor: String): ReferenceTypeSignature = + parseSignature(fieldSignature, descriptor) - private def parseSignature[T](parser: Parser[T], descriptor: String): T = { - parse(parser, descriptor) match { - case Success(signature, _) => signature + private def parseSignature[T](parser: Parser[T], descriptor: String): T = + parse(parser, descriptor) match + case Success(signature, _) => signature - case Failure(err, _) => - logger.debug(s"parseClassSignature failed with $err") - throw new IllegalArgumentException(s"FAILURE: Parsing invalid signature descriptor $descriptor") + case Failure(err, _) => + logger.debug(s"parseClassSignature failed with $err") + throw new IllegalArgumentException( + s"FAILURE: Parsing invalid signature descriptor $descriptor" + ) - case Error(err, _) => - logger.debug(s"parseClassSignature raised error $err") - throw new IllegalArgumentException(s"ERROR: Parsing invalid signature descriptor $descriptor") - } - } -} + case Error(err, _) => + logger.debug(s"parseClassSignature raised error $err") + throw new IllegalArgumentException( + s"ERROR: Parsing invalid signature descriptor $descriptor" + ) +end DescriptorParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala index b5d55efe..b82240f3 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala @@ -1,59 +1,60 @@ package io.appthreat.javasrc2cpg.jartypereader.descriptorparser import 
io.appthreat.javasrc2cpg.jartypereader.model.PrimitiveType -import io.appthreat.javasrc2cpg.jartypereader.model.Model.TypeConstants._ +import io.appthreat.javasrc2cpg.jartypereader.model.Model.TypeConstants.* import org.slf4j.LoggerFactory import scala.util.parsing.combinator.RegexParsers -trait TokenParser extends RegexParsers { - private val logger = LoggerFactory.getLogger(classOf[TokenParser]) +trait TokenParser extends RegexParsers: + private val logger = LoggerFactory.getLogger(classOf[TokenParser]) - case object Colon - case object Semicolon - case object Slash - case object ClassTypeStart - case object TypeVarStart - case object ArrayStart - case object Dot - case object OpenParen - case object CloseParen - case object OpenAngle - case object CloseAngle - case object Caret + case object Colon + case object Semicolon + case object Slash + case object ClassTypeStart + case object TypeVarStart + case object ArrayStart + case object Dot + case object OpenParen + case object CloseParen + case object OpenAngle + case object CloseAngle + case object Caret - private def translateBaseType(descriptor: String): String = { - descriptor match { - case "B" => Byte - case "C" => Char - case "D" => Double - case "F" => Float - case "I" => Int - case "J" => Long - case "S" => Short - case "Z" => Boolean - // Void is not a BaseType, but sort of acts like one in method return signatures. Since - // we expect valid descriptors (and types in general) as input, treating void as a base - // type simplifies the model without introducing a significant possibility for confusion or error. 
- case "V" => Void - case unk => throw new IllegalArgumentException(s"$unk is not a valid base type descriptor.") - } - } + private def translateBaseType(descriptor: String): String = + descriptor match + case "B" => Byte + case "C" => Char + case "D" => Double + case "F" => Float + case "I" => Int + case "J" => Long + case "S" => Short + case "Z" => Boolean + // Void is not a BaseType, but sort of acts like one in method return signatures. Since + // we expect valid descriptors (and types in general) as input, treating void as a base + // type simplifies the model without introducing a significant possibility for confusion or error. + case "V" => Void + case unk => + throw new IllegalArgumentException(s"$unk is not a valid base type descriptor.") - def baseType: Parser[PrimitiveType] = "[BCDFIJSZ]".r ^^ { t => PrimitiveType(translateBaseType(t)) } - def voidDescriptor: Parser[PrimitiveType] = "V" ^^ { t => PrimitiveType(translateBaseType(t)) } - def identifier: Parser[String] = raw"[^.\[/<>:;]+".r ^^ identity - def wildcardIndicator: Parser[String] = "[+-]".r ^^ identity - def colon: Parser[Colon.type] = ":" ^^ (_ => Colon) - def semicolon: Parser[Semicolon.type] = ";" ^^ (_ => Semicolon) - def slash: Parser[Slash.type] = "/" ^^ (_ => Slash) - def classTypeStart: Parser[ClassTypeStart.type] = "L" ^^ (_ => ClassTypeStart) - def typeVarStart: Parser[TypeVarStart.type] = "T" ^^ (_ => TypeVarStart) - def arrayStart: Parser[ArrayStart.type] = "[" ^^ (_ => ArrayStart) - def dot: Parser[Dot.type] = "." 
^^ (_ => Dot) - def openParen: Parser[OpenParen.type] = "(" ^^ (_ => OpenParen) - def closeParen: Parser[CloseParen.type] = ")" ^^ (_ => CloseParen) - def openAngle: Parser[OpenAngle.type] = "<" ^^ (_ => OpenAngle) - def closeAngle: Parser[CloseAngle.type] = ">" ^^ (_ => CloseAngle) - def caret: Parser[Caret.type] = "^" ^^ (_ => Caret) -} + def baseType: Parser[PrimitiveType] = "[BCDFIJSZ]".r ^^ { t => + PrimitiveType(translateBaseType(t)) + } + def voidDescriptor: Parser[PrimitiveType] = "V" ^^ { t => PrimitiveType(translateBaseType(t)) } + def identifier: Parser[String] = raw"[^.\[/<>:;]+".r ^^ identity + def wildcardIndicator: Parser[String] = "[+-]".r ^^ identity + def colon: Parser[Colon.type] = ":" ^^ (_ => Colon) + def semicolon: Parser[Semicolon.type] = ";" ^^ (_ => Semicolon) + def slash: Parser[Slash.type] = "/" ^^ (_ => Slash) + def classTypeStart: Parser[ClassTypeStart.type] = "L" ^^ (_ => ClassTypeStart) + def typeVarStart: Parser[TypeVarStart.type] = "T" ^^ (_ => TypeVarStart) + def arrayStart: Parser[ArrayStart.type] = "[" ^^ (_ => ArrayStart) + def dot: Parser[Dot.type] = "." 
^^ (_ => Dot) + def openParen: Parser[OpenParen.type] = "(" ^^ (_ => OpenParen) + def closeParen: Parser[CloseParen.type] = ")" ^^ (_ => CloseParen) + def openAngle: Parser[OpenAngle.type] = "<" ^^ (_ => OpenAngle) + def closeAngle: Parser[CloseAngle.type] = ">" ^^ (_ => CloseAngle) + def caret: Parser[Caret.type] = "^" ^^ (_ => Caret) +end TokenParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala index aa8536ff..ac6826f9 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala @@ -3,121 +3,118 @@ package io.appthreat.javasrc2cpg.jartypereader.descriptorparser import io.appthreat.javasrc2cpg.jartypereader.model import io.appthreat.javasrc2cpg.jartypereader.model.Bound.{BoundAbove, BoundBelow} import io.appthreat.javasrc2cpg.jartypereader.model.{ - ArrayTypeSignature, - BoundWildcard, - ClassSignature, - ClassTypeSignature, - JavaTypeSignature, - MethodSignature, - NameWithTypeArgs, - ReferenceTypeSignature, - SimpleTypeArgument, - TypeArgument, - TypeParameter, - TypeVariableSignature, - UnboundWildcard + ArrayTypeSignature, + BoundWildcard, + ClassSignature, + ClassTypeSignature, + JavaTypeSignature, + MethodSignature, + NameWithTypeArgs, + ReferenceTypeSignature, + SimpleTypeArgument, + TypeArgument, + TypeParameter, + TypeVariableSignature, + UnboundWildcard } import io.appthreat.javasrc2cpg.jartypereader.model.* import org.slf4j.LoggerFactory -trait TypeParser extends TokenParser { +trait TypeParser extends TokenParser: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - def 
unboundWildcard: Parser[UnboundWildcard.type] = "*" ^^ (_ => UnboundWildcard) + def unboundWildcard: Parser[UnboundWildcard.type] = "*" ^^ (_ => UnboundWildcard) - def typeVarSignature: Parser[TypeVariableSignature] = { - (typeVarStart ~ identifier ~ semicolon) ^^ { case _ ~ name ~ _ => TypeVariableSignature(name) } - } + def typeVarSignature: Parser[TypeVariableSignature] = + (typeVarStart ~ identifier ~ semicolon) ^^ { case _ ~ name ~ _ => + TypeVariableSignature(name) + } - def packageSpecifier: Parser[String] = { - // TODO Expect these to be short, but potential room for optimizing string concatenation if performance is an issue. - (identifier ~ slash ~ rep(packageSpecifier)) ^^ { - case firstId ~ _ ~ Nil => firstId - case firstId ~ _ ~ otherSpecifiers => firstId + "." + otherSpecifiers.mkString - } - } + def packageSpecifier: Parser[String] = + // TODO Expect these to be short, but potential room for optimizing string concatenation if performance is an issue. + (identifier ~ slash ~ rep(packageSpecifier)) ^^ { + case firstId ~ _ ~ Nil => firstId + case firstId ~ _ ~ otherSpecifiers => firstId + "." 
+ otherSpecifiers.mkString + } - def typeArguments: Parser[List[TypeArgument]] = { - (openAngle ~ rep1(typeArgument) ~ closeAngle) ^^ { case _ ~ typeArgs ~ _ => typeArgs } - } + def typeArguments: Parser[List[TypeArgument]] = + (openAngle ~ rep1(typeArgument) ~ closeAngle) ^^ { case _ ~ typeArgs ~ _ => typeArgs } - def typeArgument: Parser[TypeArgument] = { - val maybeBoundTypeArgument = (opt(wildcardIndicator) ~ referenceTypeSignature) ^^ { - case Some("-") ~ typeSignature => BoundWildcard(BoundBelow, typeSignature) + def typeArgument: Parser[TypeArgument] = + val maybeBoundTypeArgument = (opt(wildcardIndicator) ~ referenceTypeSignature) ^^ { + case Some("-") ~ typeSignature => BoundWildcard(BoundBelow, typeSignature) - case Some("+") ~ typeSignature => model.BoundWildcard(BoundAbove, typeSignature) + case Some("+") ~ typeSignature => model.BoundWildcard(BoundAbove, typeSignature) - case Some(symbol) ~ typeSignature => - logger.debug(s"Invalid wildcard indicator `$symbol`. Treating as unbound wildcard") - UnboundWildcard + case Some(symbol) ~ typeSignature => + logger.debug(s"Invalid wildcard indicator `$symbol`. 
Treating as unbound wildcard") + UnboundWildcard - case None ~ typeSignature => SimpleTypeArgument(typeSignature) - } + case None ~ typeSignature => SimpleTypeArgument(typeSignature) + } - maybeBoundTypeArgument | unboundWildcard - } + maybeBoundTypeArgument | unboundWildcard - def simpleClassTypeSignature: Parser[NameWithTypeArgs] = { - (identifier ~ opt(typeArguments)) ^^ { case name ~ maybeTypes => - model.NameWithTypeArgs(name, maybeTypes.getOrElse(Nil)) - } - } + def simpleClassTypeSignature: Parser[NameWithTypeArgs] = + (identifier ~ opt(typeArguments)) ^^ { case name ~ maybeTypes => + model.NameWithTypeArgs(name, maybeTypes.getOrElse(Nil)) + } - def classTypeSignatureSuffix: Parser[NameWithTypeArgs] = { - (dot ~ simpleClassTypeSignature) ^^ { case _ ~ sig => sig } - } + def classTypeSignatureSuffix: Parser[NameWithTypeArgs] = + (dot ~ simpleClassTypeSignature) ^^ { case _ ~ sig => sig } - def classTypeSignature: Parser[ClassTypeSignature] = { - (classTypeStart ~ opt(packageSpecifier) ~ simpleClassTypeSignature ~ rep(classTypeSignatureSuffix) ~ semicolon) ^^ { - case _ ~ packageSpecifier ~ signature ~ suffix ~ _ => - ClassTypeSignature(packageSpecifier, signature, suffix) - } - } + def classTypeSignature: Parser[ClassTypeSignature] = + (classTypeStart ~ opt(packageSpecifier) ~ simpleClassTypeSignature ~ rep( + classTypeSignatureSuffix + ) ~ semicolon) ^^ { + case _ ~ packageSpecifier ~ signature ~ suffix ~ _ => + ClassTypeSignature(packageSpecifier, signature, suffix) + } - def arrayTypeSignature: Parser[ArrayTypeSignature] = { - (arrayStart ~ javaTypeSignature) ^^ { case _ ~ sig => ArrayTypeSignature(sig) } - } + def arrayTypeSignature: Parser[ArrayTypeSignature] = + (arrayStart ~ javaTypeSignature) ^^ { case _ ~ sig => ArrayTypeSignature(sig) } - def javaTypeSignature: Parser[JavaTypeSignature] = referenceTypeSignature | baseType + def javaTypeSignature: Parser[JavaTypeSignature] = referenceTypeSignature | baseType - def returnType: 
Parser[JavaTypeSignature] = javaTypeSignature | voidDescriptor + def returnType: Parser[JavaTypeSignature] = javaTypeSignature | voidDescriptor - def referenceTypeSignature: Parser[ReferenceTypeSignature] = { - classTypeSignature | typeVarSignature | arrayTypeSignature - } + def referenceTypeSignature: Parser[ReferenceTypeSignature] = + classTypeSignature | typeVarSignature | arrayTypeSignature - def classBound: Parser[Option[ReferenceTypeSignature]] = { - (colon ~ opt(referenceTypeSignature)) ^^ { case _ ~ maybeSignature => maybeSignature } - } + def classBound: Parser[Option[ReferenceTypeSignature]] = + (colon ~ opt(referenceTypeSignature)) ^^ { case _ ~ maybeSignature => maybeSignature } - def interfaceBound: Parser[ReferenceTypeSignature] = { - (colon ~ referenceTypeSignature) ^^ { case _ ~ signature => signature } - } + def interfaceBound: Parser[ReferenceTypeSignature] = + (colon ~ referenceTypeSignature) ^^ { case _ ~ signature => signature } - def typeParameter: Parser[TypeParameter] = { - (identifier ~ classBound ~ rep(interfaceBound)) ^^ { case name ~ classBound ~ interfaceBounds => - TypeParameter(name, classBound, interfaceBounds) - } - } + def typeParameter: Parser[TypeParameter] = + (identifier ~ classBound ~ rep(interfaceBound)) ^^ { + case name ~ classBound ~ interfaceBounds => + TypeParameter(name, classBound, interfaceBounds) + } - def typeParameters: Parser[List[TypeParameter]] = { - (openAngle ~ rep1(typeParameter) ~ closeAngle) ^^ { case _ ~ params ~ _ => params } - } + def typeParameters: Parser[List[TypeParameter]] = + (openAngle ~ rep1(typeParameter) ~ closeAngle) ^^ { case _ ~ params ~ _ => params } - def classSignature: Parser[ClassSignature] = { - (opt(typeParameters) ~ classTypeSignature ~ rep(classTypeSignature)) ^^ { - case typeParams ~ supClass ~ supInterfaces => - ClassSignature(typeParams.getOrElse(Nil), Some(supClass), supInterfaces) - } - } + def classSignature: Parser[ClassSignature] = + (opt(typeParameters) ~ classTypeSignature 
~ rep(classTypeSignature)) ^^ { + case typeParams ~ supClass ~ supInterfaces => + ClassSignature(typeParams.getOrElse(Nil), Some(supClass), supInterfaces) + } - def methodSignature: Parser[MethodSignature] = { - (opt(typeParameters) ~ openParen ~ rep(javaTypeSignature) ~ closeParen ~ returnType ~ rep(javaTypeSignature)) ^^ { - case typeParams ~ _ ~ paramTypes ~ _ ~ returnSignature ~ throwsSignatures => - MethodSignature(typeParams.getOrElse(Nil), paramTypes, returnSignature, throwsSignatures) - } - } + def methodSignature: Parser[MethodSignature] = + (opt(typeParameters) ~ openParen ~ rep(javaTypeSignature) ~ closeParen ~ returnType ~ rep( + javaTypeSignature + )) ^^ { + case typeParams ~ _ ~ paramTypes ~ _ ~ returnSignature ~ throwsSignatures => + MethodSignature( + typeParams.getOrElse(Nil), + paramTypes, + returnSignature, + throwsSignatures + ) + } - def fieldSignature: Parser[ReferenceTypeSignature] = referenceTypeSignature -} + def fieldSignature: Parser[ReferenceTypeSignature] = referenceTypeSignature +end TypeParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala index 5c72b768..093ab18c 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala @@ -2,30 +2,25 @@ package io.appthreat.javasrc2cpg.jartypereader.model import Bound.Bound -sealed trait Named { - def name: String - def qualifiedName: String - - override def equals(obj: Any): Boolean = { - obj match { - case other: Named => - this.name == other.name && this.qualifiedName == other.qualifiedName - case _ => false - } - } - - def buildQualifiedClassName(name: String, packageSpecifier: Option[String]): String = { - packageSpecifier.map(ps => 
s"$ps.$name").getOrElse(name) - } -} +sealed trait Named: + def name: String + def qualifiedName: String + + override def equals(obj: Any): Boolean = + obj match + case other: Named => + this.name == other.name && this.qualifiedName == other.qualifiedName + case _ => false + + def buildQualifiedClassName(name: String, packageSpecifier: Option[String]): String = + packageSpecifier.map(ps => s"$ps.$name").getOrElse(name) case class NameWithTypeArgs(name: String, typeArguments: List[TypeArgument]) -object Bound { - sealed trait Bound - case object BoundAbove extends Bound - case object BoundBelow extends Bound -} +object Bound: + sealed trait Bound + case object BoundAbove extends Bound + case object BoundBelow extends Bound sealed trait TypeArgument case class SimpleTypeArgument(typeSignature: ReferenceTypeSignature) extends TypeArgument @@ -33,10 +28,9 @@ case class BoundWildcard(bound: Bound, typeSignature: ReferenceTypeSignature) ex case object UnboundWildcard extends TypeArgument sealed trait JavaTypeSignature extends Named -case class PrimitiveType(fullName: String) extends JavaTypeSignature { - override val name: String = fullName - override val qualifiedName: String = fullName -} +case class PrimitiveType(fullName: String) extends JavaTypeSignature: + override val name: String = fullName + override val qualifiedName: String = fullName sealed trait ReferenceTypeSignature extends JavaTypeSignature @@ -44,20 +38,17 @@ case class ClassTypeSignature( packageSpecifier: Option[String], typedName: NameWithTypeArgs, suffix: List[NameWithTypeArgs] -) extends ReferenceTypeSignature { - override val name: String = (typedName :: suffix).map(_.name).mkString(".") - override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) -} - -case class TypeVariableSignature(identifier: String) extends ReferenceTypeSignature { - override val name: String = identifier - override val qualifiedName: String = identifier -} -case class ArrayTypeSignature(signature: 
JavaTypeSignature) extends ReferenceTypeSignature { - private def buildArrayName(baseName: String): String = s"[$baseName]" - override val name: String = buildArrayName(signature.name) - override val qualifiedName: String = buildArrayName(signature.qualifiedName) -} +) extends ReferenceTypeSignature: + override val name: String = (typedName :: suffix).map(_.name).mkString(".") + override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) + +case class TypeVariableSignature(identifier: String) extends ReferenceTypeSignature: + override val name: String = identifier + override val qualifiedName: String = identifier +case class ArrayTypeSignature(signature: JavaTypeSignature) extends ReferenceTypeSignature: + private def buildArrayName(baseName: String): String = s"[$baseName]" + override val name: String = buildArrayName(signature.name) + override val qualifiedName: String = buildArrayName(signature.qualifiedName) case class TypeParameter( name: String, @@ -78,10 +69,9 @@ case class MethodSignature( sealed trait ResolvedType extends Named -object Unresolved extends ResolvedType { - override val name: String = "Unresolved" - override val qualifiedName: String = name -} +object Unresolved extends ResolvedType: + override val name: String = "Unresolved" + override val qualifiedName: String = name class ResolvedTypeDecl( override val name: String, @@ -91,60 +81,61 @@ class ResolvedTypeDecl( val isAbstract: Boolean, val fields: List[ResolvedVariableType], initDeclaredMethods: List[ResolvedMethod] -) extends ResolvedType { - override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) - - private var declaredMethods = initDeclaredMethods - - def getDeclaredMethods: List[ResolvedMethod] = declaredMethods - - private[jartypereader] def addMethods(newMethods: List[ResolvedMethod]): Unit = { - declaredMethods ++= newMethods - } -} -object ResolvedTypeDecl { - def apply( - name: String, - packageSpecifier: Option[String], - 
signature: ClassSignature, - isInterface: Boolean, - isAbstract: Boolean, - fields: List[ResolvedVariableType], - methods: List[ResolvedMethod] = Nil - ): ResolvedTypeDecl = { - new ResolvedTypeDecl(name, packageSpecifier, signature, isInterface, isAbstract, fields, methods) - } -} +) extends ResolvedType: + override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) + + private var declaredMethods = initDeclaredMethods + + def getDeclaredMethods: List[ResolvedMethod] = declaredMethods + + private[jartypereader] def addMethods(newMethods: List[ResolvedMethod]): Unit = + declaredMethods ++= newMethods +object ResolvedTypeDecl: + def apply( + name: String, + packageSpecifier: Option[String], + signature: ClassSignature, + isInterface: Boolean, + isAbstract: Boolean, + fields: List[ResolvedVariableType], + methods: List[ResolvedMethod] = Nil + ): ResolvedTypeDecl = + new ResolvedTypeDecl( + name, + packageSpecifier, + signature, + isInterface, + isAbstract, + fields, + methods + ) case class ResolvedMethod( override val name: String, parentTypeDecl: ResolvedTypeDecl, signature: MethodSignature, isAbstract: Boolean -) extends ResolvedType { - override val qualifiedName: String = s"${parentTypeDecl.qualifiedName}.$name" -} - -case class ResolvedVariableType(name: String, signature: ReferenceTypeSignature) extends ResolvedType { - override val qualifiedName: String = name -} - -object Model { - // TODO: This is a duplicate of the TypeConstants object in TypeInfoCalculator. Remove the other one once - // we switch to the new solver. 
- object TypeConstants { - val Byte: String = "byte" - val Short: String = "short" - val Int: String = "int" - val Long: String = "long" - val Float: String = "float" - val Double: String = "double" - val Char: String = "char" - val Boolean: String = "boolean" - val Object: String = "java.lang.Object" - val Class: String = "java.lang.Class" - val Iterator: String = "java.util.Iterator" - val Void: String = "void" - val Any: String = "ANY" - } -} +) extends ResolvedType: + override val qualifiedName: String = s"${parentTypeDecl.qualifiedName}.$name" + +case class ResolvedVariableType(name: String, signature: ReferenceTypeSignature) + extends ResolvedType: + override val qualifiedName: String = name + +object Model: + // TODO: This is a duplicate of the TypeConstants object in TypeInfoCalculator. Remove the other one once + // we switch to the new solver. + object TypeConstants: + val Byte: String = "byte" + val Short: String = "short" + val Int: String = "int" + val Long: String = "long" + val Float: String = "float" + val Double: String = "double" + val Char: String = "char" + val Boolean: String = "boolean" + val Object: String = "java.lang.Object" + val Class: String = "java.lang.Class" + val Iterator: String = "java.util.Iterator" + val Void: String = "void" + val Any: String = "ANY" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala index d0660bb4..ff0e5df2 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala @@ -8,18 +8,16 @@ import io.shiftleft.semanticcpg.language.dotextension.Shared import java.nio.file.Path -object JavaParserAstPrinter { - def printJpAsts(config: Config): Unit = { +object 
JavaParserAstPrinter: + def printJpAsts(config: Config): Unit = - val sourceParser = util.SourceParser(config, false) - val printer = new YamlPrinter(true) + val sourceParser = util.SourceParser(config, false) + val printer = new YamlPrinter(true) - SourceParser.getSourceFilenames(config).foreach { filename => - val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString - sourceParser.parseAnalysisFile(relativeFilename).foreach { compilationUnit => - println(relativeFilename) - println(printer.output(compilationUnit)) - } - } - } -} + SourceParser.getSourceFilenames(config).foreach { filename => + val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString + sourceParser.parseAnalysisFile(relativeFilename).foreach { compilationUnit => + println(relativeFilename) + println(printer.output(compilationUnit)) + } + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala index af8e6e7f..e09fe0e8 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala @@ -3,7 +3,10 @@ package io.appthreat.javasrc2cpg.passes import better.files.File import com.github.javaparser.ast.CompilationUnit import com.github.javaparser.symbolsolver.JavaSymbolSolver -import com.github.javaparser.symbolsolver.resolution.typesolvers.{JarTypeSolver, ReflectionTypeSolver} +import com.github.javaparser.symbolsolver.resolution.typesolvers.{ + JarTypeSolver, + ReflectionTypeSolver +} import io.appthreat.javasrc2cpg.JavaSrc2Cpg.JavaSrcEnvVar import io.appthreat.javasrc2cpg.typesolvers.noncaching.JdkJarTypeSolver import io.appthreat.javasrc2cpg.typesolvers.{EagerSourceTypeSolver, SimpleCombinedTypeSolver} @@ -20,118 +23,128 @@ import 
scala.collection.mutable import scala.util.{Success, Try} class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[String]] = None) - extends ConcurrentWriterCpgPass[String](cpg) { - - val global: Global = new Global() - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - - private val sourceFilenames = SourceParser.getSourceFilenames(config, sourcesOverride) - - val (sourceParser, symbolSolver, packagesJarMappings) = initParserAndUtils(config, sourceFilenames) - - override def generateParts(): Array[String] = sourceFilenames - - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = { - val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString - sourceParser.parseAnalysisFile(relativeFilename) match { - case Some(compilationUnit) => - symbolSolver.inject(compilationUnit) - diffGraph.absorb( - new AstCreator(relativeFilename, compilationUnit, global, symbolSolver, packagesJarMappings)( - config.schemaValidation - ).createAst() - ) - - case None => logger.debug(s"Skipping AST creation for $filename") - } - } - - private def initParserAndUtils( - config: Config, - sourceFilenames: Array[String] - ): (SourceParser, JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = { - val dependencies = getDependencyList(config.inputPath) - val sourceParser = util.SourceParser(config, dependencies.exists(_.contains("lombok"))) - val (symbolSolver, packagesJarMappings) = createSymbolSolver(config, dependencies, sourceParser, sourceFilenames) - (sourceParser, symbolSolver, packagesJarMappings) - } - - private def getDependencyList(inputPath: String): List[String] = { - if (config.fetchDependencies) { - DependencyResolver.getDependencies(Paths.get(inputPath)) match { - case Some(deps) => deps.toList - case None => - logger.debug(s"Could not fetch dependencies for project at path $inputPath") - List() - } - } else { - logger.debug("dependency resolving disabled") - List() - } - } - - 
private def createSymbolSolver( - config: Config, - dependencies: List[String], - sourceParser: SourceParser, - sourceFilenames: Array[String] - ): (JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = { - val combinedTypeSolver = new SimpleCombinedTypeSolver() - val symbolSolver = new JavaSymbolSolver(combinedTypeSolver) - - val jdkPathFromEnvVar = Option(System.getenv(JavaSrcEnvVar.JdkPath.name)) - val jdkPath = (config.jdkPath, jdkPathFromEnvVar) match { - case (None, None) => - val javaHome = System.getProperty("java.home") - logger.debug(s"No explicit jdk-path set in , so using system java.home for JDK type information: $javaHome") - javaHome - - case (None, Some(jdkPath)) => - logger.debug( - s"Using JDK path from environment variable ${JavaSrcEnvVar.JdkPath.name} for JDK type information: $jdkPath" - ) - jdkPath - - case (Some(jdkPath), _) => - logger.debug(s"Using JDK path set with jdk-path option for JDK type information: $jdkPath") - jdkPath - } - - val jdkJarTypeSolver = JdkJarTypeSolver.fromJdkPath(jdkPath) - combinedTypeSolver.addNonCachingTypeSolver(jdkJarTypeSolver) - - val relativeSourceFilenames = - sourceFilenames.map(filename => Path.of(config.inputPath).relativize(Path.of(filename)).toString) - - val sourceTypeSolver = - EagerSourceTypeSolver(relativeSourceFilenames, sourceParser, combinedTypeSolver, symbolSolver) - - combinedTypeSolver.addCachingTypeSolver(sourceTypeSolver) - combinedTypeSolver.addNonCachingTypeSolver(new ReflectionTypeSolver()) - // Add solvers for inference jars - val jarsList = config.inferenceJarPaths.flatMap(recursiveJarsFromPath).toList - (jarsList ++ dependencies) - .flatMap { path => - Try(new JarTypeSolver(path)).toOption - } - .foreach { combinedTypeSolver.addNonCachingTypeSolver(_) } - - (symbolSolver, jdkJarTypeSolver.packagesJarMappings) - } - - private def recursiveJarsFromPath(path: String): List[String] = { - Try(File(path)) match { - case Success(file) if file.isDirectory => - file.listRecursively - 
.map(_.canonicalPath) - .filter(_.endsWith(".jar")) - .toList - - case Success(file) if file.canonicalPath.endsWith(".jar") => - List(file.canonicalPath) - - case _ => - Nil - } - } -} + extends ConcurrentWriterCpgPass[String](cpg): + + val global: Global = new Global() + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + + private val sourceFilenames = SourceParser.getSourceFilenames(config, sourcesOverride) + + val (sourceParser, symbolSolver, packagesJarMappings) = + initParserAndUtils(config, sourceFilenames) + + override def generateParts(): Array[String] = sourceFilenames + + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = + val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString + sourceParser.parseAnalysisFile(relativeFilename) match + case Some(compilationUnit) => + symbolSolver.inject(compilationUnit) + diffGraph.absorb( + new AstCreator( + relativeFilename, + compilationUnit, + global, + symbolSolver, + packagesJarMappings + )( + config.schemaValidation + ).createAst() + ) + + case None => logger.debug(s"Skipping AST creation for $filename") + + private def initParserAndUtils( + config: Config, + sourceFilenames: Array[String] + ): (SourceParser, JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = + val dependencies = getDependencyList(config.inputPath) + val sourceParser = util.SourceParser(config, dependencies.exists(_.contains("lombok"))) + val (symbolSolver, packagesJarMappings) = + createSymbolSolver(config, dependencies, sourceParser, sourceFilenames) + (sourceParser, symbolSolver, packagesJarMappings) + + private def getDependencyList(inputPath: String): List[String] = + if config.fetchDependencies then + DependencyResolver.getDependencies(Paths.get(inputPath)) match + case Some(deps) => deps.toList + case None => + logger.debug(s"Could not fetch dependencies for project at path $inputPath") + List() + else + logger.debug("dependency resolving disabled") 
+ List() + + private def createSymbolSolver( + config: Config, + dependencies: List[String], + sourceParser: SourceParser, + sourceFilenames: Array[String] + ): (JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = + val combinedTypeSolver = new SimpleCombinedTypeSolver() + val symbolSolver = new JavaSymbolSolver(combinedTypeSolver) + + val jdkPathFromEnvVar = Option(System.getenv(JavaSrcEnvVar.JdkPath.name)) + val jdkPath = (config.jdkPath, jdkPathFromEnvVar) match + case (None, None) => + val javaHome = System.getProperty("java.home") + logger.debug( + s"No explicit jdk-path set in , so using system java.home for JDK type information: $javaHome" + ) + javaHome + + case (None, Some(jdkPath)) => + logger.debug( + s"Using JDK path from environment variable ${JavaSrcEnvVar.JdkPath.name} for JDK type information: $jdkPath" + ) + jdkPath + + case (Some(jdkPath), _) => + logger.debug( + s"Using JDK path set with jdk-path option for JDK type information: $jdkPath" + ) + jdkPath + + val jdkJarTypeSolver = JdkJarTypeSolver.fromJdkPath(jdkPath) + combinedTypeSolver.addNonCachingTypeSolver(jdkJarTypeSolver) + + val relativeSourceFilenames = + sourceFilenames.map(filename => + Path.of(config.inputPath).relativize(Path.of(filename)).toString + ) + + val sourceTypeSolver = + EagerSourceTypeSolver( + relativeSourceFilenames, + sourceParser, + combinedTypeSolver, + symbolSolver + ) + + combinedTypeSolver.addCachingTypeSolver(sourceTypeSolver) + combinedTypeSolver.addNonCachingTypeSolver(new ReflectionTypeSolver()) + // Add solvers for inference jars + val jarsList = config.inferenceJarPaths.flatMap(recursiveJarsFromPath).toList + (jarsList ++ dependencies) + .flatMap { path => + Try(new JarTypeSolver(path)).toOption + } + .foreach { combinedTypeSolver.addNonCachingTypeSolver(_) } + + (symbolSolver, jdkJarTypeSolver.packagesJarMappings) + end createSymbolSolver + + private def recursiveJarsFromPath(path: String): List[String] = + Try(File(path)) match + case Success(file) 
if file.isDirectory => + file.listRecursively + .map(_.canonicalPath) + .filter(_.endsWith(".jar")) + .toList + + case Success(file) if file.canonicalPath.endsWith(".jar") => + List(file.canonicalPath) + + case _ => + Nil +end AstCreationPass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala index 990523d5..6368d473 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala @@ -3,152 +3,162 @@ package io.appthreat.javasrc2cpg.passes import com.github.javaparser.ast.`type`.TypeParameter import com.github.javaparser.ast.{CompilationUnit, Node, NodeList, PackageDeclaration} import com.github.javaparser.ast.body.{ - AnnotationDeclaration, - BodyDeclaration, - CallableDeclaration, - ClassOrInterfaceDeclaration, - ConstructorDeclaration, - EnumConstantDeclaration, - FieldDeclaration, - InitializerDeclaration, - MethodDeclaration, - Parameter, - TypeDeclaration, - VariableDeclarator + AnnotationDeclaration, + BodyDeclaration, + CallableDeclaration, + ClassOrInterfaceDeclaration, + ConstructorDeclaration, + EnumConstantDeclaration, + FieldDeclaration, + InitializerDeclaration, + MethodDeclaration, + Parameter, + TypeDeclaration, + VariableDeclarator } import com.github.javaparser.ast.expr.AssignExpr.Operator import com.github.javaparser.ast.expr.{ - AnnotationExpr, - ArrayAccessExpr, - ArrayCreationExpr, - ArrayInitializerExpr, - AssignExpr, - BinaryExpr, - BooleanLiteralExpr, - CastExpr, - CharLiteralExpr, - ClassExpr, - ConditionalExpr, - DoubleLiteralExpr, - EnclosedExpr, - Expression, - FieldAccessExpr, - InstanceOfExpr, - IntegerLiteralExpr, - LambdaExpr, - LiteralExpr, - LongLiteralExpr, - MarkerAnnotationExpr, - MethodCallExpr, - NameExpr, - NormalAnnotationExpr, - 
NullLiteralExpr, - ObjectCreationExpr, - SingleMemberAnnotationExpr, - StringLiteralExpr, - SuperExpr, - TextBlockLiteralExpr, - ThisExpr, - UnaryExpr, - VariableDeclarationExpr + AnnotationExpr, + ArrayAccessExpr, + ArrayCreationExpr, + ArrayInitializerExpr, + AssignExpr, + BinaryExpr, + BooleanLiteralExpr, + CastExpr, + CharLiteralExpr, + ClassExpr, + ConditionalExpr, + DoubleLiteralExpr, + EnclosedExpr, + Expression, + FieldAccessExpr, + InstanceOfExpr, + IntegerLiteralExpr, + LambdaExpr, + LiteralExpr, + LongLiteralExpr, + MarkerAnnotationExpr, + MethodCallExpr, + NameExpr, + NormalAnnotationExpr, + NullLiteralExpr, + ObjectCreationExpr, + SingleMemberAnnotationExpr, + StringLiteralExpr, + SuperExpr, + TextBlockLiteralExpr, + ThisExpr, + UnaryExpr, + VariableDeclarationExpr } import com.github.javaparser.ast.nodeTypes.{NodeWithName, NodeWithSimpleName} import com.github.javaparser.ast.stmt.{ - AssertStmt, - BlockStmt, - BreakStmt, - CatchClause, - ContinueStmt, - DoStmt, - EmptyStmt, - ExplicitConstructorInvocationStmt, - ExpressionStmt, - ForEachStmt, - ForStmt, - IfStmt, - LabeledStmt, - ReturnStmt, - Statement, - SwitchEntry, - SwitchStmt, - SynchronizedStmt, - ThrowStmt, - TryStmt, - WhileStmt + AssertStmt, + BlockStmt, + BreakStmt, + CatchClause, + ContinueStmt, + DoStmt, + EmptyStmt, + ExplicitConstructorInvocationStmt, + ExpressionStmt, + ForEachStmt, + ForStmt, + IfStmt, + LabeledStmt, + ReturnStmt, + Statement, + SwitchEntry, + SwitchStmt, + SynchronizedStmt, + ThrowStmt, + TryStmt, + WhileStmt } import com.github.javaparser.resolution.UnsolvedSymbolException import com.github.javaparser.resolution.declarations.{ - ResolvedFieldDeclaration, - ResolvedMethodDeclaration, - ResolvedMethodLikeDeclaration, - ResolvedReferenceTypeDeclaration, - ResolvedTypeParameterDeclaration + ResolvedFieldDeclaration, + ResolvedMethodDeclaration, + ResolvedMethodLikeDeclaration, + ResolvedReferenceTypeDeclaration, + ResolvedTypeParameterDeclaration } import 
com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap -import com.github.javaparser.resolution.types.{ResolvedReferenceType, ResolvedType, ResolvedTypeVariable} +import com.github.javaparser.resolution.types.{ + ResolvedReferenceType, + ResolvedType, + ResolvedTypeVariable +} import com.github.javaparser.symbolsolver.JavaSymbolSolver -import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.{ObjectMethodSignatures, TypeConstants} +import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.{ + ObjectMethodSignatures, + TypeConstants +} import io.appthreat.javasrc2cpg.util.BindingTable.createBindingTable import io.appthreat.x2cpg.utils.NodeBuilders.{ - newAnnotationLiteralNode, - newBindingNode, - newCallNode, - newClosureBindingNode, - newFieldIdentifierNode, - newIdentifierNode, - newMethodReturnNode, - newModifierNode, - newOperatorCallNode + newAnnotationLiteralNode, + newBindingNode, + newCallNode, + newClosureBindingNode, + newFieldIdentifierNode, + newIdentifierNode, + newMethodReturnNode, + newModifierNode, + newOperatorCallNode } import io.appthreat.javasrc2cpg.scope.Scope.* import io.appthreat.javasrc2cpg.util.{ - BindingTable, - BindingTableAdapterForJavaparser, - BindingTableAdapterForLambdas, - BindingTableEntry, - LambdaBindingInfo, - NameConstants + BindingTable, + BindingTableAdapterForJavaparser, + BindingTableAdapterForLambdas, + BindingTableEntry, + LambdaBindingInfo, + NameConstants +} +import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.{ + ObjectMethodSignatures, + TypeConstants } -import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.{ObjectMethodSignatures, TypeConstants} import io.appthreat.javasrc2cpg.util.Util.{ - composeMethodFullName, - composeMethodLikeSignature, - composeUnresolvedSignature + composeMethodFullName, + composeMethodLikeSignature, + composeUnresolvedSignature } import io.appthreat.x2cpg.Defines.* import io.shiftleft.codepropertygraph.generated.{ - 
ControlStructureTypes, - DispatchTypes, - EdgeTypes, - EvaluationStrategies, - ModifierTypes, - NodeTypes, - Operators + ControlStructureTypes, + DispatchTypes, + EdgeTypes, + EvaluationStrategies, + ModifierTypes, + NodeTypes, + Operators } import io.shiftleft.codepropertygraph.generated.nodes.{ - NewAnnotation, - NewArrayInitializer, - NewBlock, - NewCall, - NewClosureBinding, - NewControlStructure, - NewFieldIdentifier, - NewIdentifier, - NewImport, - NewJumpTarget, - NewLiteral, - NewLocal, - NewMember, - NewMethod, - NewMethodParameterIn, - NewMethodRef, - NewMethodReturn, - NewModifier, - NewNamespaceBlock, - NewNode, - NewReturn, - NewTypeDecl, - NewTypeRef + NewAnnotation, + NewArrayInitializer, + NewBlock, + NewCall, + NewClosureBinding, + NewControlStructure, + NewFieldIdentifier, + NewIdentifier, + NewImport, + NewJumpTarget, + NewLiteral, + NewLocal, + NewMember, + NewMethod, + NewMethodParameterIn, + NewMethodRef, + NewMethodReturn, + NewModifier, + NewNamespaceBlock, + NewNode, + NewReturn, + NewTypeDecl, + NewTypeRef } import io.appthreat.x2cpg.{Ast, AstCreatorBase, Defines, ValidationMode} import io.appthreat.x2cpg.datastructures.Global @@ -173,12 +183,12 @@ import io.appthreat.javasrc2cpg.scope.{NodeTypeInfo, Scope} import io.appthreat.javasrc2cpg.scope.Scope.ScopeVariable import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator import io.appthreat.javasrc2cpg.util.{ - BindingTable, - BindingTableAdapterForJavaparser, - BindingTableAdapterForLambdas, - BindingTableEntry, - LambdaBindingInfo, - NameConstants + BindingTable, + BindingTableAdapterForJavaparser, + BindingTableAdapterForLambdas, + BindingTableEntry, + LambdaBindingInfo, + NameConstants } import io.appthreat.x2cpg.Defines.StaticInitMethodName import io.shiftleft.codepropertygraph.generated.nodes.NewTypeParameter @@ -193,22 +203,19 @@ case class LambdaImplementedInfo( case class PartialConstructor(initNode: NewCall, initArgs: Seq[Ast], blockAst: Ast) case class 
ExpectedType(fullName: Option[String], resolvedType: Option[ResolvedType] = None) -object ExpectedType { - def empty: ExpectedType = ExpectedType(None, None) - val Int: ExpectedType = ExpectedType(Some(TypeConstants.Int)) - val Boolean: ExpectedType = ExpectedType(Some(TypeConstants.Boolean)) - val Void: ExpectedType = ExpectedType(Some(TypeConstants.Void)) -} +object ExpectedType: + def empty: ExpectedType = ExpectedType(None, None) + val Int: ExpectedType = ExpectedType(Some(TypeConstants.Int)) + val Boolean: ExpectedType = ExpectedType(Some(TypeConstants.Boolean)) + val Void: ExpectedType = ExpectedType(Some(TypeConstants.Void)) case class AstWithStaticInit(ast: Seq[Ast], staticInits: Seq[Ast]) -object AstWithStaticInit { - val empty: AstWithStaticInit = AstWithStaticInit(Seq.empty, Seq.empty) +object AstWithStaticInit: + val empty: AstWithStaticInit = AstWithStaticInit(Seq.empty, Seq.empty) - def apply(ast: Ast): AstWithStaticInit = { - AstWithStaticInit(Seq(ast), staticInits = Seq.empty) - } -} + def apply(ast: Ast): AstWithStaticInit = + AstWithStaticInit(Seq(ast), staticInits = Seq.empty) /** Translate a Java Parser AST into a CPG AST */ @@ -220,2993 +227,3182 @@ class AstCreator( packagesJarMappings: mutable.Map[String, mutable.Set[String]] )(implicit withSchemaValidation: ValidationMode) extends AstCreatorBase(filename) - with AstNodeBuilder[Node, AstCreator] { - - private val logger = LoggerFactory.getLogger(this.getClass) - - private val scope = Scope() - - private val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver) - private val partialConstructorQueue: mutable.ArrayBuffer[PartialConstructor] = mutable.ArrayBuffer.empty - private val bindingTableCache = mutable.HashMap.empty[String, BindingTable] - - // TODO: Perhaps move this to a NameProvider or some such? Look at kt2cpg to see if some unified representation - // makes sense. 
- private val LambdaNamePrefix = "lambda$" - private val lambdaKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - private val IndexNamePrefix = "$idx" - private val indexKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - private val IterableNamePrefix = "$iterLocal" - private val iterableKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - - /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a DiffGraph containing the - * corresponding CPG AST. - */ - def createAst(): DiffGraphBuilder = { - val ast = astForTranslationUnit(javaParserAst) - storeInDiffGraph(ast) - diffGraph - } - - /** Copy nodes/edges of given `AST` into the diff graph - */ - def storeInDiffGraph(ast: Ast): Unit = { - Ast.storeInDiffGraph(ast, diffGraph) - } - - protected def line(node: Node): Option[Integer] = node.getBegin.map(x => Integer.valueOf(x.line)).toScala - protected def column(node: Node): Option[Integer] = node.getBegin.map(x => Integer.valueOf(x.column)).toScala - protected def lineEnd(node: Node): Option[Integer] = node.getEnd.map(x => Integer.valueOf(x.line)).toScala - protected def columnEnd(node: Node): Option[Integer] = node.getEnd.map(x => Integer.valueOf(x.line)).toScala - - // TODO: Handle static imports correctly. 
- private def addImportsToScope(compilationUnit: CompilationUnit): Seq[NewImport] = { - val (asteriskImports, specificImports) = compilationUnit.getImports.asScala.toList.partition(_.isAsterisk) - val specificImportNodes = specificImports.map { importStmt => - val name = importStmt.getName.getIdentifier - val typeFullName = importStmt.getNameAsString // fully qualified name - typeInfoCalc.registerType(typeFullName) - val importNode = NewImport() - .importedAs(name) - .importedEntity(typeFullName) - - if (importStmt.isStatic) { - scope.addStaticImport(importNode) - } else { - scope.addType(name, typeFullName) - } - importNode - } - - val asteriskImportNodes = asteriskImports match { - case imp :: Nil => - val name = NameConstants.WildcardImportName - val typeFullName = imp.getNameAsString - val importNode = NewImport() - .importedAs(name) - .importedEntity(typeFullName) - .isWildcard(true) - scope.addWildcardImport(typeFullName) - Seq(importNode) - case _ => // Only try to guess a wildcard import if exactly one is defined - Seq.empty - } - specificImportNodes ++ asteriskImportNodes - } - - /** Translate compilation unit into AST - */ - private def astForTranslationUnit(compilationUnit: CompilationUnit): Ast = { - - try { - val namespaceBlock = namespaceBlockForPackageDecl(compilationUnit.getPackageDeclaration.toScala) - - scope.pushNamespaceScope(namespaceBlock) - - val importNodes = addImportsToScope(compilationUnit).map(Ast(_)) - - val typeDeclAsts = compilationUnit.getTypes.asScala.map { typ => - astForTypeDecl(typ, astParentType = NodeTypes.NAMESPACE_BLOCK, astParentFullName = namespaceBlock.fullName) - } - - // TODO: Add ASTs - scope.popScope() - Ast(namespaceBlock).withChildren(typeDeclAsts).withChildren(importNodes) - } catch { - case t: UnsolvedSymbolException => - logger.error(s"Unsolved symbol exception caught in $filename") - Ast() - case t: Throwable => - logger.debug(s"Parsing file $filename failed with $t") - logger.debug(s"Caused by ${t.getCause}") - 
Ast() - } - } - - /** Translate package declaration into AST consisting of a corresponding namespace block. - */ - private def namespaceBlockForPackageDecl(packageDecl: Option[PackageDeclaration]): NewNamespaceBlock = { - packageDecl match { - case Some(decl) => - val packageName = decl.getName.toString - val fullName = s"$filename:$packageName" - NewNamespaceBlock() - .name(packageName) - .fullName(fullName) - .filename(filename) - case None => - globalNamespaceBlock() - } - } - - private def tryWithSafeStackOverflow[T](expr: => T): Try[T] = { - try { - Try(expr) - } catch { - // This is really, really ugly, but there's a bug in the JavaParser symbol solver that can lead to - // unterminated recursion in some cases where types cannot be resolved. - case e: StackOverflowError => - logger.debug(s"Caught StackOverflowError in $filename") - Failure(e) - } - } - - private def composeSignature( - maybeReturnType: Option[String], - maybeParameterTypes: Option[List[String]], - parameterCount: Int - ): String = { - (maybeReturnType, maybeParameterTypes) match { - case (Some(returnType), Some(parameterTypes)) => - composeMethodLikeSignature(returnType, parameterTypes) - - case _ => - composeUnresolvedSignature(parameterCount) - } - } - - private def methodSignature(method: ResolvedMethodDeclaration, typeParamValues: ResolvedTypeParametersMap): String = { - val maybeParameterTypes = calcParameterTypes(method, typeParamValues) - - val maybeReturnType = - Try(method.getReturnType).toOption - .flatMap(returnType => typeInfoCalc.fullName(returnType, typeParamValues)) - - composeSignature(maybeReturnType, maybeParameterTypes, method.getNumberOfParams) - } - - private def toOptionList[T](items: collection.Seq[Option[T]]): Option[List[T]] = { - items.foldLeft[Option[List[T]]](Some(Nil)) { - case (Some(acc), Some(value)) => Some(acc :+ value) - case _ => None - } - } - - private def calcParameterTypes( - methodLike: ResolvedMethodLikeDeclaration, - typeParamValues: 
ResolvedTypeParametersMap - ): Option[List[String]] = { - val parameterTypes = - Range(0, methodLike.getNumberOfParams) - .flatMap { index => - Try(methodLike.getParam(index)).toOption - } - .map { param => - Try(param.getType).toOption - .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) + with AstNodeBuilder[Node, AstCreator]: + + private val logger = LoggerFactory.getLogger(this.getClass) + + private val scope = Scope() + + private val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver) + private val partialConstructorQueue: mutable.ArrayBuffer[PartialConstructor] = + mutable.ArrayBuffer.empty + private val bindingTableCache = mutable.HashMap.empty[String, BindingTable] + + // TODO: Perhaps move this to a NameProvider or some such? Look at kt2cpg to see if some unified representation + // makes sense. + private val LambdaNamePrefix = "lambda$" + private val lambdaKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + private val IndexNamePrefix = "$idx" + private val indexKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + private val IterableNamePrefix = "$iterLocal" + private val iterableKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + + /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a + * DiffGraph containing the corresponding CPG AST. 
+ */ + def createAst(): DiffGraphBuilder = + val ast = astForTranslationUnit(javaParserAst) + storeInDiffGraph(ast) + diffGraph + + /** Copy nodes/edges of given `AST` into the diff graph + */ + def storeInDiffGraph(ast: Ast): Unit = + Ast.storeInDiffGraph(ast, diffGraph) + + protected def line(node: Node): Option[Integer] = + node.getBegin.map(x => Integer.valueOf(x.line)).toScala + protected def column(node: Node): Option[Integer] = + node.getBegin.map(x => Integer.valueOf(x.column)).toScala + protected def lineEnd(node: Node): Option[Integer] = + node.getEnd.map(x => Integer.valueOf(x.line)).toScala + protected def columnEnd(node: Node): Option[Integer] = + node.getEnd.map(x => Integer.valueOf(x.line)).toScala + + // TODO: Handle static imports correctly. + private def addImportsToScope(compilationUnit: CompilationUnit): Seq[NewImport] = + val (asteriskImports, specificImports) = + compilationUnit.getImports.asScala.toList.partition(_.isAsterisk) + val specificImportNodes = specificImports.map { importStmt => + val name = importStmt.getName.getIdentifier + val typeFullName = importStmt.getNameAsString // fully qualified name + typeInfoCalc.registerType(typeFullName) + val importNode = NewImport() + .importedAs(name) + .importedEntity(typeFullName) + + if importStmt.isStatic then + scope.addStaticImport(importNode) + else + scope.addType(name, typeFullName) + importNode } - toOptionList(parameterTypes) - } - - def getBindingTable(typeDecl: ResolvedReferenceTypeDeclaration): BindingTable = { - val fullName = typeInfoCalc.fullName(typeDecl).getOrElse { - val qualifiedName = typeDecl.getQualifiedName - logger.debug(s"Could not get full name for resolved type decl $qualifiedName. 
THIS SHOULD NOT HAPPEN!") - qualifiedName - } - bindingTableCache.getOrElseUpdate( - fullName, - createBindingTable(fullName, typeDecl, getBindingTable, new BindingTableAdapterForJavaparser(methodSignature)) - ) - } - - private def getLambdaBindingTable(lambdaBindingInfo: LambdaBindingInfo): BindingTable = { - val fullName = lambdaBindingInfo.fullName - - bindingTableCache.getOrElseUpdate( - fullName, - createBindingTable( - fullName, - lambdaBindingInfo, - getBindingTable, - new BindingTableAdapterForLambdas(methodSignature) - ) - ) - } - - private def createBindingNodes(typeDeclNode: NewTypeDecl, bindingTable: BindingTable): Unit = { - // We only sort to get stable output. - val sortedEntries = - bindingTable.getEntries.toBuffer.sortBy((entry: BindingTableEntry) => s"${entry.name}${entry.signature}") - - sortedEntries.foreach { entry => - val bindingNode = newBindingNode(entry.name, entry.signature, entry.implementingMethodFullName) - - diffGraph.addNode(bindingNode) - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - } - } - - private def astForTypeDeclMember(member: BodyDeclaration[_], astParentFullName: String): AstWithStaticInit = { - member match { - case constructor: ConstructorDeclaration => - val ast = astForConstructor(constructor) - - AstWithStaticInit(ast) - - case method: MethodDeclaration => - val ast = astForMethod(method) - - AstWithStaticInit(ast) - - case typeDeclaration: TypeDeclaration[_] => - AstWithStaticInit(astForTypeDecl(typeDeclaration, NodeTypes.TYPE_DECL, astParentFullName)) - - case fieldDeclaration: FieldDeclaration => - val memberAsts = fieldDeclaration.getVariables.asScala.toList.map { variable => - astForFieldVariable(variable, fieldDeclaration) + val asteriskImportNodes = asteriskImports match + case imp :: Nil => + val name = NameConstants.WildcardImportName + val typeFullName = imp.getNameAsString + val importNode = NewImport() + .importedAs(name) + .importedEntity(typeFullName) + .isWildcard(true) + 
scope.addWildcardImport(typeFullName) + Seq(importNode) + case _ => // Only try to guess a wildcard import if exactly one is defined + Seq.empty + specificImportNodes ++ asteriskImportNodes + end addImportsToScope + + /** Translate compilation unit into AST + */ + private def astForTranslationUnit(compilationUnit: CompilationUnit): Ast = + try + val namespaceBlock = + namespaceBlockForPackageDecl(compilationUnit.getPackageDeclaration.toScala) + + scope.pushNamespaceScope(namespaceBlock) + + val importNodes = addImportsToScope(compilationUnit).map(Ast(_)) + + val typeDeclAsts = compilationUnit.getTypes.asScala.map { typ => + astForTypeDecl( + typ, + astParentType = NodeTypes.NAMESPACE_BLOCK, + astParentFullName = namespaceBlock.fullName + ) + } + + // TODO: Add ASTs + scope.popScope() + Ast(namespaceBlock).withChildren(typeDeclAsts).withChildren(importNodes) + catch + case t: UnsolvedSymbolException => + logger.error(s"Unsolved symbol exception caught in $filename") + Ast() + case t: Throwable => + logger.debug(s"Parsing file $filename failed with $t") + logger.debug(s"Caused by ${t.getCause}") + Ast() + + /** Translate package declaration into AST consisting of a corresponding namespace block. + */ + private def namespaceBlockForPackageDecl(packageDecl: Option[PackageDeclaration]) + : NewNamespaceBlock = + packageDecl match + case Some(decl) => + val packageName = decl.getName.toString + val fullName = s"$filename:$packageName" + NewNamespaceBlock() + .name(packageName) + .fullName(fullName) + .filename(filename) + case None => + globalNamespaceBlock() + + private def tryWithSafeStackOverflow[T](expr: => T): Try[T] = + try + Try(expr) + catch + // This is really, really ugly, but there's a bug in the JavaParser symbol solver that can lead to + // unterminated recursion in some cases where types cannot be resolved. 
+ case e: StackOverflowError => + logger.debug(s"Caught StackOverflowError in $filename") + Failure(e) + + private def composeSignature( + maybeReturnType: Option[String], + maybeParameterTypes: Option[List[String]], + parameterCount: Int + ): String = + (maybeReturnType, maybeParameterTypes) match + case (Some(returnType), Some(parameterTypes)) => + composeMethodLikeSignature(returnType, parameterTypes) + + case _ => + composeUnresolvedSignature(parameterCount) + + private def methodSignature( + method: ResolvedMethodDeclaration, + typeParamValues: ResolvedTypeParametersMap + ): String = + val maybeParameterTypes = calcParameterTypes(method, typeParamValues) + + val maybeReturnType = + Try(method.getReturnType).toOption + .flatMap(returnType => typeInfoCalc.fullName(returnType, typeParamValues)) + + composeSignature(maybeReturnType, maybeParameterTypes, method.getNumberOfParams) + + private def toOptionList[T](items: collection.Seq[Option[T]]): Option[List[T]] = + items.foldLeft[Option[List[T]]](Some(Nil)) { + case (Some(acc), Some(value)) => Some(acc :+ value) + case _ => None } - val assignments = assignmentsForVarDecl( - fieldDeclaration.getVariables.asScala.toList, - line(fieldDeclaration), - column(fieldDeclaration) - ) + private def calcParameterTypes( + methodLike: ResolvedMethodLikeDeclaration, + typeParamValues: ResolvedTypeParametersMap + ): Option[List[String]] = + val parameterTypes = + Range(0, methodLike.getNumberOfParams) + .flatMap { index => + Try(methodLike.getParam(index)).toOption + } + .map { param => + Try(param.getType).toOption + .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) + } - val staticInitAsts = if (fieldDeclaration.isStatic) assignments else Nil - if (!fieldDeclaration.isStatic) scope.addMemberInitializers(assignments) - - AstWithStaticInit(memberAsts, staticInitAsts) - - case initDeclaration: InitializerDeclaration => - val stmts = initDeclaration.getBody.getStatements - val asts = 
stmts.asScala.flatMap(astsForStatement).toList - AstWithStaticInit(ast = Seq.empty, staticInits = asts) - - case unhandled => - // AnnotationMemberDeclarations and InitializerDeclarations as children of typeDecls are the - // expected cases. - logger.debug(s"Found unhandled typeDecl member ${unhandled.getClass} in file $filename") - AstWithStaticInit.empty - } - } - - private def identifierForResolvedTypeParameter(typeParameter: ResolvedTypeParameterDeclaration): NewIdentifier = { - val name = typeParameter.getName - val typeFullName = Try(typeParameter.getUpperBound).toOption - .flatMap(typeInfoCalc.fullName) - .getOrElse(TypeConstants.Object) - typeInfoCalc.registerType(typeFullName) - newIdentifierNode(name, typeFullName) - } - - private def clinitAstFromStaticInits(staticInits: Seq[Ast]): Option[Ast] = { - Option.when(staticInits.nonEmpty) { - val signature = composeMethodLikeSignature(TypeConstants.Void, Nil) - val enclosingDeclName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val fullName = composeMethodFullName(enclosingDeclName, Defines.StaticInitMethodName, signature) - staticInitMethodAst(staticInits.toList, fullName, Some(signature), TypeConstants.Void) - } - } - - private def codeForTypeDecl(typ: TypeDeclaration[_], isInterface: Boolean): String = { - val codeBuilder = new mutable.StringBuilder() - if (typ.isPublic) { - codeBuilder.append("public ") - } else if (typ.isPrivate) { - codeBuilder.append("private ") - } else if (typ.isProtected) { - codeBuilder.append("protected ") - } - - if (typ.isStatic) { - codeBuilder.append("static ") - } - - val classPrefix = - if (isInterface) - "interface " - else if (typ.isEnumDeclaration) - "enum " - else - "class " - codeBuilder.append(classPrefix) - codeBuilder.append(typ.getNameAsString) - - codeBuilder.toString() - } - - private def modifiersForTypeDecl(typ: TypeDeclaration[_], isInterface: Boolean): List[NewModifier] = { - val accessModifierType = if (typ.isPublic) { - 
Some(ModifierTypes.PUBLIC) - } else if (typ.isPrivate) { - Some(ModifierTypes.PRIVATE) - } else if (typ.isProtected) { - Some(ModifierTypes.PROTECTED) - } else { - None - } - val accessModifier = accessModifierType.map(newModifierNode) - - val abstractModifier = - Option.when(isInterface || typ.getMethods.asScala.exists(_.isAbstract))(newModifierNode(ModifierTypes.ABSTRACT)) - - List(accessModifier, abstractModifier).flatten - } - - private def createTypeDeclNode( - typ: TypeDeclaration[_], - astParentType: String, - astParentFullName: String, - isInterface: Boolean - ): NewTypeDecl = { - val baseTypeFullNames = if (typ.isClassOrInterfaceDeclaration) { - val decl = typ.asClassOrInterfaceDeclaration() - val extendedTypes = decl.getExtendedTypes.asScala - val implementedTypes = decl.getImplementedTypes.asScala - val inheritsFromTypeNames = - (extendedTypes ++ implementedTypes).flatMap { typ => - typeInfoCalc.fullName(typ).orElse(scope.lookupType(typ.getNameAsString)) + toOptionList(parameterTypes) + + def getBindingTable(typeDecl: ResolvedReferenceTypeDeclaration): BindingTable = + val fullName = typeInfoCalc.fullName(typeDecl).getOrElse { + val qualifiedName = typeDecl.getQualifiedName + logger.debug( + s"Could not get full name for resolved type decl $qualifiedName. THIS SHOULD NOT HAPPEN!" 
+ ) + qualifiedName } - val maybeJavaObjectType = if (extendedTypes.isEmpty) { - typeInfoCalc.registerType(TypeConstants.Object) - Seq(TypeConstants.Object) - } else { - Seq() - } - maybeJavaObjectType ++ inheritsFromTypeNames - } else { - List.empty[String] - } - - val resolvedType = tryWithSafeStackOverflow(typ.resolve()).toOption - val defaultFullName = s"${Defines.UnresolvedNamespace}.${typ.getNameAsString}" - val name = resolvedType.flatMap(typeInfoCalc.name).getOrElse(typ.getNameAsString) - val typeFullName = resolvedType.flatMap(typeInfoCalc.fullName).getOrElse(defaultFullName) - val code = codeForTypeDecl(typ, isInterface) - val typeDecl = NewTypeDecl() - .name(name) - .fullName(typeFullName) - .lineNumber(line(typ)) - .columnNumber(column(typ)) - .inheritsFromTypeFullName(baseTypeFullNames) - .filename(filename) - .code(code) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - if (packagesJarMappings.contains(typeFullName)) - typeDecl.aliasTypeFullName(packagesJarMappings.getOrElse(typeFullName, mutable.Set.empty).headOption) - typeDecl - } - - private def addTypeDeclTypeParamsToScope(typ: TypeDeclaration[_]): Unit = { - tryWithSafeStackOverflow(typ.resolve()).map(_.getTypeParameters.asScala) match { - case Success(resolvedTypeParams) => - resolvedTypeParams - .map(identifierForResolvedTypeParameter) - .foreach { typeParamIdentifier => - scope.addType(typeParamIdentifier.name, typeParamIdentifier.typeFullName) - } - - case _ => // Nothing to do here - } - } - - private def astForTypeDecl(typ: TypeDeclaration[_], astParentType: String, astParentFullName: String): Ast = { - val isInterface = typ match { - case classDeclaration: ClassOrInterfaceDeclaration => classDeclaration.isInterface - case _ => false - } - - val typeDeclNode = createTypeDeclNode(typ, astParentType, astParentFullName, isInterface) - - scope.pushTypeDeclScope(typeDeclNode) - addTypeDeclTypeParamsToScope(typ) - - val enumEntryAsts = if (typ.isEnumDeclaration) { - 
typ.asEnumDeclaration().getEntries.asScala.map(astForEnumEntry).toList - } else { - List.empty - } - - val staticInits: mutable.Buffer[Ast] = mutable.Buffer() - val memberAsts = typ.getMembers.asScala.flatMap { member => - val astWithInits = - astForTypeDeclMember(member, astParentFullName = NodeTypes.TYPE_DECL) - staticInits.appendAll(astWithInits.staticInits) - astWithInits.ast - } - - val defaultConstructorAst = if (!isInterface && typ.getConstructors.isEmpty) { - Some(astForDefaultConstructor()) - } else { - None - } - - val annotationAsts = typ.getAnnotations.asScala.map(astForAnnotationExpr) - - val clinitAst = clinitAstFromStaticInits(staticInits.toSeq) - - val localDecls = scope.localDeclsInScope - val lambdaMethods = scope.lambdaMethodsInScope - - val modifiers = modifiersForTypeDecl(typ, isInterface) - - val typeDeclAst = Ast(typeDeclNode) - .withChildren(enumEntryAsts) - .withChildren(memberAsts) - .withChildren(defaultConstructorAst.toList) - .withChildren(annotationAsts) - .withChildren(clinitAst.toSeq) - .withChildren(localDecls) - .withChildren(lambdaMethods) - .withChildren(modifiers.map(Ast(_))) - - val defaultConstructorBindingEntry = - defaultConstructorAst - .flatMap(_.root) - .collect { case defaultConstructor: NewMethod => - BindingTableEntry( - io.appthreat.x2cpg.Defines.ConstructorMethodName, - defaultConstructor.signature, - defaultConstructor.fullName + bindingTableCache.getOrElseUpdate( + fullName, + createBindingTable( + fullName, + typeDecl, + getBindingTable, + new BindingTableAdapterForJavaparser(methodSignature) + ) + ) + + private def getLambdaBindingTable(lambdaBindingInfo: LambdaBindingInfo): BindingTable = + val fullName = lambdaBindingInfo.fullName + + bindingTableCache.getOrElseUpdate( + fullName, + createBindingTable( + fullName, + lambdaBindingInfo, + getBindingTable, + new BindingTableAdapterForLambdas(methodSignature) ) + ) + + private def createBindingNodes(typeDeclNode: NewTypeDecl, bindingTable: BindingTable): Unit = + 
// We only sort to get stable output. + val sortedEntries = + bindingTable.getEntries.toBuffer.sortBy((entry: BindingTableEntry) => + s"${entry.name}${entry.signature}" + ) + + sortedEntries.foreach { entry => + val bindingNode = + newBindingNode(entry.name, entry.signature, entry.implementingMethodFullName) + + diffGraph.addNode(bindingNode) + diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) } - // Annotation declarations need no binding table as objects of this - // typ never get called from user code. - // Furthermore the parser library throws an exception when trying to - // access e.g. the declared methods of an annotation declaration. - if (!typ.isInstanceOf[AnnotationDeclaration]) { - tryWithSafeStackOverflow(typ.resolve()).toOption.foreach { resolvedTypeDecl => - val bindingTable = getBindingTable(resolvedTypeDecl) - defaultConstructorBindingEntry.foreach(bindingTable.add) - createBindingNodes(typeDeclNode, bindingTable) - } - } - - scope.popScope() - - typeDeclAst - } - - private def astForDefaultConstructor(): Ast = { - val typeFullName = scope.enclosingTypeDeclFullName - val signature = s"${TypeConstants.Void}()" - val fullName = composeMethodFullName( - typeFullName.getOrElse(Defines.UnresolvedNamespace), - Defines.ConstructorMethodName, - signature - ) - val constructorNode = NewMethod() - .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) - .fullName(fullName) - .signature(signature) - .filename(filename) - .isExternal(false) - - val thisAst = Ast(thisNodeForMethod(typeFullName, lineNumber = None)) - val bodyAst = Ast(NewBlock()).withChildren(scope.memberInitializers) - - val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None) - - val modifiers = List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - - methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) - } - - private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = { - // 
TODO Fix enum entries in general - val typeFullName = - tryWithSafeStackOverflow(entry.resolve().getType).toOption.flatMap(typeInfoCalc.fullName) - - val entryNode = memberNode(entry, entry.getNameAsString, entry.toString, typeFullName.getOrElse("ANY")) - - val name = s"${typeFullName.getOrElse(Defines.UnresolvedNamespace)}.${Defines.ConstructorMethodName}" - - Ast(entryNode) - } - - private def modifiersForFieldDeclaration(decl: FieldDeclaration): Seq[Ast] = { - val staticModifier = - Option.when(decl.isStatic)(newModifierNode(ModifierTypes.STATIC)) - - val accessModifierType = - if (decl.isPublic) - Some(ModifierTypes.PUBLIC) - else if (decl.isPrivate) - Some(ModifierTypes.PRIVATE) - else if (decl.isProtected) - Some(ModifierTypes.PROTECTED) - else - None - - val accessModifier = accessModifierType.map(newModifierNode) - - List(staticModifier, accessModifier).flatten.map(Ast(_)) - } - - private def astForFieldVariable(v: VariableDeclarator, fieldDeclaration: FieldDeclaration): Ast = { - // TODO: Should be able to find expected type here - val annotations = fieldDeclaration.getAnnotations - - // variable can be declared with generic type, so we need to get rid of the <> part of it to get the package information - // and append the <> when forming the typeFullName again - // Ex - private Consumer consumer; - // From Consumer we need to get to Consumer so splitting it by '<' and then combining with '<' to - // form typeFullName as Consumer - - val typeFullNameWithoutGenericSplit = typeInfoCalc - .fullName(v.getType) - .orElse(scope.lookupType(v.getTypeAsString)) - .getOrElse(s"${Defines.UnresolvedNamespace}.${v.getTypeAsString}") - val typeFullName = { - // Check if the typeFullName is unresolved and if it has generic information to resolve the typeFullName - if ( - typeFullNameWithoutGenericSplit - .contains(Defines.UnresolvedNamespace) && v.getTypeAsString.contains(Defines.LeftAngularBracket) - ) { - val splitByLeftAngular = 
v.getTypeAsString.split(Defines.LeftAngularBracket) - scope.lookupType(splitByLeftAngular.head) match { - case Some(fullName) => - fullName + splitByLeftAngular - .slice(1, splitByLeftAngular.size) - .mkString(Defines.LeftAngularBracket, Defines.LeftAngularBracket, "") - case None => typeFullNameWithoutGenericSplit + private def astForTypeDeclMember( + member: BodyDeclaration[?], + astParentFullName: String + ): AstWithStaticInit = + member match + case constructor: ConstructorDeclaration => + val ast = astForConstructor(constructor) + + AstWithStaticInit(ast) + + case method: MethodDeclaration => + val ast = astForMethod(method) + + AstWithStaticInit(ast) + + case typeDeclaration: TypeDeclaration[?] => + AstWithStaticInit(astForTypeDecl( + typeDeclaration, + NodeTypes.TYPE_DECL, + astParentFullName + )) + + case fieldDeclaration: FieldDeclaration => + val memberAsts = fieldDeclaration.getVariables.asScala.toList.map { variable => + astForFieldVariable(variable, fieldDeclaration) + } + + val assignments = assignmentsForVarDecl( + fieldDeclaration.getVariables.asScala.toList, + line(fieldDeclaration), + column(fieldDeclaration) + ) + + val staticInitAsts = if fieldDeclaration.isStatic then assignments else Nil + if !fieldDeclaration.isStatic then scope.addMemberInitializers(assignments) + + AstWithStaticInit(memberAsts, staticInitAsts) + + case initDeclaration: InitializerDeclaration => + val stmts = initDeclaration.getBody.getStatements + val asts = stmts.asScala.flatMap(astsForStatement).toList + AstWithStaticInit(ast = Seq.empty, staticInits = asts) + + case unhandled => + // AnnotationMemberDeclarations and InitializerDeclarations as children of typeDecls are the + // expected cases. 
+ logger.debug( + s"Found unhandled typeDecl member ${unhandled.getClass} in file $filename" + ) + AstWithStaticInit.empty + + private def identifierForResolvedTypeParameter(typeParameter: ResolvedTypeParameterDeclaration) + : NewIdentifier = + val name = typeParameter.getName + val typeFullName = Try(typeParameter.getUpperBound).toOption + .flatMap(typeInfoCalc.fullName) + .getOrElse(TypeConstants.Object) + typeInfoCalc.registerType(typeFullName) + newIdentifierNode(name, typeFullName) + + private def clinitAstFromStaticInits(staticInits: Seq[Ast]): Option[Ast] = + Option.when(staticInits.nonEmpty) { + val signature = composeMethodLikeSignature(TypeConstants.Void, Nil) + val enclosingDeclName = + scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val fullName = + composeMethodFullName(enclosingDeclName, Defines.StaticInitMethodName, signature) + staticInitMethodAst(staticInits.toList, fullName, Some(signature), TypeConstants.Void) } - } else typeFullNameWithoutGenericSplit - } - val name = v.getName.toString - val node = memberNode(v, name, s"$typeFullName $name", typeFullName) - val memberAst = Ast(node) - val annotationAsts = annotations.asScala.map(astForAnnotationExpr) - - val fieldDeclModifiers = modifiersForFieldDeclaration(fieldDeclaration) - - scope.addMember(node, fieldDeclaration.isStatic) - - memberAst - .withChildren(annotationAsts) - .withChildren(fieldDeclModifiers) - } - - private def astForConstructor(constructorDeclaration: ConstructorDeclaration): Ast = { - val constructorNode = createPartialMethod(constructorDeclaration) - .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) - - scope.pushMethodScope(constructorNode, ExpectedType.Void) - val maybeResolved = tryWithSafeStackOverflow(constructorDeclaration.resolve()) - - val parameterAsts = astsForParameterList(constructorDeclaration.getParameters).toList - val paramTypes = argumentTypesForMethodLike(maybeResolved) - val signature = 
composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) - val typeFullName = scope.enclosingTypeDeclFullName - val fullName = - composeMethodFullName( - typeFullName.getOrElse(Defines.UnresolvedNamespace), - Defines.ConstructorMethodName, - signature - ) - val typeNameLookup = fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - constructorNode - .fullName(fullName) - .signature(signature) - if (packagesJarMappings.contains(typeNameLookup)) - constructorNode.astParentType(packagesJarMappings.getOrElse(typeNameLookup, mutable.Set.empty).head) - - parameterAsts.foreach { ast => - ast.root match { - case Some(parameter: NewMethodParameterIn) => scope.addParameter(parameter) - case _ => // This should never happen - } - } - - val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration)) - scope.addParameter(thisNode) - val thisAst = Ast(thisNode) - - val bodyAst = astForConstructorBody(Some(constructorDeclaration.getBody)) - val methodReturn = constructorReturnNode(constructorDeclaration) - - val annotationAsts = constructorDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toList - - val modifiers = - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR) :: modifiersForMethod(constructorDeclaration).filterNot( - _.modifierType == ModifierTypes.VIRTUAL - ) - - scope.popScope() - - methodAstWithAnnotations( - constructorNode, - thisAst :: parameterAsts, - bodyAst, - methodReturn, - modifiers, - annotationAsts - ) - } - - private def thisNodeForMethod( - maybeTypeFullName: Option[String], - lineNumber: Option[Integer] - ): NewMethodParameterIn = { - val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(TypeConstants.Any)) - NodeBuilders.newThisParameterNode( - typeFullName = typeFullName, - dynamicTypeHintFullName = maybeTypeFullName.toSeq, - line = lineNumber - ) - } - - private def convertAnnotationValueExpr(expr: Expression): Option[Ast] = { - expr match { - case arrayInit: 
ArrayInitializerExpr => - val arrayInitNode = NewArrayInitializer() - .code(arrayInit.toString) - val initElementAsts = arrayInit.getValues.asScala.toList.map { value => - convertAnnotationValueExpr(value) + + private def codeForTypeDecl(typ: TypeDeclaration[?], isInterface: Boolean): String = + val codeBuilder = new mutable.StringBuilder() + if typ.isPublic then + codeBuilder.append("public ") + else if typ.isPrivate then + codeBuilder.append("private ") + else if typ.isProtected then + codeBuilder.append("protected ") + + if typ.isStatic then + codeBuilder.append("static ") + + val classPrefix = + if isInterface then + "interface " + else if typ.isEnumDeclaration then + "enum " + else + "class " + codeBuilder.append(classPrefix) + codeBuilder.append(typ.getNameAsString) + + codeBuilder.toString() + end codeForTypeDecl + + private def modifiersForTypeDecl( + typ: TypeDeclaration[?], + isInterface: Boolean + ): List[NewModifier] = + val accessModifierType = if typ.isPublic then + Some(ModifierTypes.PUBLIC) + else if typ.isPrivate then + Some(ModifierTypes.PRIVATE) + else if typ.isProtected then + Some(ModifierTypes.PROTECTED) + else + None + val accessModifier = accessModifierType.map(newModifierNode) + + val abstractModifier = + Option.when(isInterface || typ.getMethods.asScala.exists(_.isAbstract))(newModifierNode( + ModifierTypes.ABSTRACT + )) + + List(accessModifier, abstractModifier).flatten + end modifiersForTypeDecl + + private def createTypeDeclNode( + typ: TypeDeclaration[?], + astParentType: String, + astParentFullName: String, + isInterface: Boolean + ): NewTypeDecl = + val baseTypeFullNames = if typ.isClassOrInterfaceDeclaration then + val decl = typ.asClassOrInterfaceDeclaration() + val extendedTypes = decl.getExtendedTypes.asScala + val implementedTypes = decl.getImplementedTypes.asScala + val inheritsFromTypeNames = + (extendedTypes ++ implementedTypes).flatMap { typ => + typeInfoCalc.fullName(typ).orElse(scope.lookupType(typ.getNameAsString)) + } + 
val maybeJavaObjectType = if extendedTypes.isEmpty then + typeInfoCalc.registerType(TypeConstants.Object) + Seq(TypeConstants.Object) + else + Seq() + maybeJavaObjectType ++ inheritsFromTypeNames + else + List.empty[String] + + val resolvedType = tryWithSafeStackOverflow(typ.resolve()).toOption + val defaultFullName = s"${Defines.UnresolvedNamespace}.${typ.getNameAsString}" + val name = resolvedType.flatMap(typeInfoCalc.name).getOrElse(typ.getNameAsString) + val typeFullName = resolvedType.flatMap(typeInfoCalc.fullName).getOrElse(defaultFullName) + val code = codeForTypeDecl(typ, isInterface) + val typeDecl = NewTypeDecl() + .name(name) + .fullName(typeFullName) + .lineNumber(line(typ)) + .columnNumber(column(typ)) + .inheritsFromTypeFullName(baseTypeFullNames) + .filename(filename) + .code(code) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + if packagesJarMappings.contains(typeFullName) then + typeDecl.aliasTypeFullName(packagesJarMappings.getOrElse( + typeFullName, + mutable.Set.empty + ).headOption) + typeDecl + end createTypeDeclNode + + private def addTypeDeclTypeParamsToScope(typ: TypeDeclaration[?]): Unit = + tryWithSafeStackOverflow(typ.resolve()).map(_.getTypeParameters.asScala) match + case Success(resolvedTypeParams) => + resolvedTypeParams + .map(identifierForResolvedTypeParameter) + .foreach { typeParamIdentifier => + scope.addType(typeParamIdentifier.name, typeParamIdentifier.typeFullName) + } + + case _ => // Nothing to do here + private def astForTypeDecl( + typ: TypeDeclaration[?], + astParentType: String, + astParentFullName: String + ): Ast = + val isInterface = typ match + case classDeclaration: ClassOrInterfaceDeclaration => classDeclaration.isInterface + case _ => false + + val typeDeclNode = createTypeDeclNode(typ, astParentType, astParentFullName, isInterface) + + scope.pushTypeDeclScope(typeDeclNode) + addTypeDeclTypeParamsToScope(typ) + + val enumEntryAsts = if typ.isEnumDeclaration then + 
typ.asEnumDeclaration().getEntries.asScala.map(astForEnumEntry).toList + else + List.empty + + val staticInits: mutable.Buffer[Ast] = mutable.Buffer() + val memberAsts = typ.getMembers.asScala.flatMap { member => + val astWithInits = + astForTypeDeclMember(member, astParentFullName = NodeTypes.TYPE_DECL) + staticInits.appendAll(astWithInits.staticInits) + astWithInits.ast } - setArgumentIndices(initElementAsts.flatten) + val defaultConstructorAst = if !isInterface && typ.getConstructors.isEmpty then + Some(astForDefaultConstructor()) + else + None + + val annotationAsts = typ.getAnnotations.asScala.map(astForAnnotationExpr) + + val clinitAst = clinitAstFromStaticInits(staticInits.toSeq) + + val localDecls = scope.localDeclsInScope + val lambdaMethods = scope.lambdaMethodsInScope + + val modifiers = modifiersForTypeDecl(typ, isInterface) + + val typeDeclAst = Ast(typeDeclNode) + .withChildren(enumEntryAsts) + .withChildren(memberAsts) + .withChildren(defaultConstructorAst.toList) + .withChildren(annotationAsts) + .withChildren(clinitAst.toSeq) + .withChildren(localDecls) + .withChildren(lambdaMethods) + .withChildren(modifiers.map(Ast(_))) + + val defaultConstructorBindingEntry = + defaultConstructorAst + .flatMap(_.root) + .collect { case defaultConstructor: NewMethod => + BindingTableEntry( + io.appthreat.x2cpg.Defines.ConstructorMethodName, + defaultConstructor.signature, + defaultConstructor.fullName + ) + } + + // Annotation declarations need no binding table as objects of this + // typ never get called from user code. + // Furthermore the parser library throws an exception when trying to + // access e.g. the declared methods of an annotation declaration. 
+ if !typ.isInstanceOf[AnnotationDeclaration] then + tryWithSafeStackOverflow(typ.resolve()).toOption.foreach { resolvedTypeDecl => + val bindingTable = getBindingTable(resolvedTypeDecl) + defaultConstructorBindingEntry.foreach(bindingTable.add) + createBindingNodes(typeDeclNode, bindingTable) + } - val returnAst = initElementAsts.foldLeft(Ast(arrayInitNode)) { - case (ast, Some(elementAst)) => - ast.withChild(elementAst) - case (ast, _) => ast + scope.popScope() + + typeDeclAst + end astForTypeDecl + + private def astForDefaultConstructor(): Ast = + val typeFullName = scope.enclosingTypeDeclFullName + val signature = s"${TypeConstants.Void}()" + val fullName = composeMethodFullName( + typeFullName.getOrElse(Defines.UnresolvedNamespace), + Defines.ConstructorMethodName, + signature + ) + val constructorNode = NewMethod() + .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) + .fullName(fullName) + .signature(signature) + .filename(filename) + .isExternal(false) + + val thisAst = Ast(thisNodeForMethod(typeFullName, lineNumber = None)) + val bodyAst = Ast(NewBlock()).withChildren(scope.memberInitializers) + + val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None) + + val modifiers = + List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) + + methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) + end astForDefaultConstructor + + private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = + // TODO Fix enum entries in general + val typeFullName = + tryWithSafeStackOverflow(entry.resolve().getType).toOption.flatMap( + typeInfoCalc.fullName + ) + + val entryNode = + memberNode(entry, entry.getNameAsString, entry.toString, typeFullName.getOrElse("ANY")) + + val name = + s"${typeFullName.getOrElse(Defines.UnresolvedNamespace)}.${Defines.ConstructorMethodName}" + + Ast(entryNode) + + private def modifiersForFieldDeclaration(decl: FieldDeclaration): Seq[Ast] = + val 
staticModifier = + Option.when(decl.isStatic)(newModifierNode(ModifierTypes.STATIC)) + + val accessModifierType = + if decl.isPublic then + Some(ModifierTypes.PUBLIC) + else if decl.isPrivate then + Some(ModifierTypes.PRIVATE) + else if decl.isProtected then + Some(ModifierTypes.PROTECTED) + else + None + + val accessModifier = accessModifierType.map(newModifierNode) + + List(staticModifier, accessModifier).flatten.map(Ast(_)) + + private def astForFieldVariable( + v: VariableDeclarator, + fieldDeclaration: FieldDeclaration + ): Ast = + // TODO: Should be able to find expected type here + val annotations = fieldDeclaration.getAnnotations + + // variable can be declared with generic type, so we need to get rid of the <> part of it to get the package information + // and append the <> when forming the typeFullName again + // Ex - private Consumer consumer; + // From Consumer we need to get to Consumer so splitting it by '<' and then combining with '<' to + // form typeFullName as Consumer + + val typeFullNameWithoutGenericSplit = typeInfoCalc + .fullName(v.getType) + .orElse(scope.lookupType(v.getTypeAsString)) + .getOrElse(s"${Defines.UnresolvedNamespace}.${v.getTypeAsString}") + val typeFullName = + // Check if the typeFullName is unresolved and if it has generic information to resolve the typeFullName + if + typeFullNameWithoutGenericSplit + .contains(Defines.UnresolvedNamespace) && v.getTypeAsString.contains( + Defines.LeftAngularBracket + ) + then + val splitByLeftAngular = v.getTypeAsString.split(Defines.LeftAngularBracket) + scope.lookupType(splitByLeftAngular.head) match + case Some(fullName) => + fullName + splitByLeftAngular + .slice(1, splitByLeftAngular.size) + .mkString(Defines.LeftAngularBracket, Defines.LeftAngularBracket, "") + case None => typeFullNameWithoutGenericSplit + else typeFullNameWithoutGenericSplit + val name = v.getName.toString + val node = memberNode(v, name, s"$typeFullName $name", typeFullName) + val memberAst = Ast(node) + val 
annotationAsts = annotations.asScala.map(astForAnnotationExpr) + + val fieldDeclModifiers = modifiersForFieldDeclaration(fieldDeclaration) + + scope.addMember(node, fieldDeclaration.isStatic) + + memberAst + .withChildren(annotationAsts) + .withChildren(fieldDeclModifiers) + end astForFieldVariable + + private def astForConstructor(constructorDeclaration: ConstructorDeclaration): Ast = + val constructorNode = createPartialMethod(constructorDeclaration) + .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) + + scope.pushMethodScope(constructorNode, ExpectedType.Void) + val maybeResolved = tryWithSafeStackOverflow(constructorDeclaration.resolve()) + + val parameterAsts = astsForParameterList(constructorDeclaration.getParameters).toList + val paramTypes = argumentTypesForMethodLike(maybeResolved) + val signature = composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) + val typeFullName = scope.enclosingTypeDeclFullName + val fullName = + composeMethodFullName( + typeFullName.getOrElse(Defines.UnresolvedNamespace), + Defines.ConstructorMethodName, + signature + ) + val typeNameLookup = fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + constructorNode + .fullName(fullName) + .signature(signature) + if packagesJarMappings.contains(typeNameLookup) then + constructorNode.astParentType(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).head) + + parameterAsts.foreach { ast => + ast.root match + case Some(parameter: NewMethodParameterIn) => scope.addParameter(parameter) + case _ => // This should never happen } - Some(returnAst) - - case annotationExpr: AnnotationExpr => - Some(astForAnnotationExpr(annotationExpr)) - - case literalExpr: LiteralExpr => - Some(astForAnnotationLiteralExpr(literalExpr)) - - case _: ClassExpr => - // TODO: Implement for known case - None - - case _: FieldAccessExpr => - // TODO: Implement for known case - None - - case _: BinaryExpr => - // TODO: Implement for known case - None - 
- case _: NameExpr => - // TODO: Implement for known case - None - - case _ => - logger.debug(s"convertAnnotationValueExpr not yet implemented for unknown case ${expr.getClass}") - None - } - } - - private def astForAnnotationLiteralExpr(literalExpr: LiteralExpr): Ast = { - val valueNode = - literalExpr match { - case literal: StringLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: IntegerLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: BooleanLiteralExpr => newAnnotationLiteralNode(java.lang.Boolean.toString(literal.getValue)) - case literal: CharLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: DoubleLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: LongLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case _: NullLiteralExpr => newAnnotationLiteralNode("null") - case literal: TextBlockLiteralExpr => newAnnotationLiteralNode(literal.getValue) - } - - Ast(valueNode) - } - - private def exprNameFromStack(expr: Expression): Option[String] = expr match { - case annotation: AnnotationExpr => - scope.lookupType(annotation.getNameAsString) - case namedExpr: NodeWithName[_] => - scope.lookupVariableOrType(namedExpr.getNameAsString) - case namedExpr: NodeWithSimpleName[_] => - scope.lookupVariableOrType(namedExpr.getNameAsString) - // JavaParser doesn't handle literals well for some reason - case _: BooleanLiteralExpr => Some("boolean") - case _: CharLiteralExpr => Some("char") - case _: DoubleLiteralExpr => Some("double") - case _: IntegerLiteralExpr => Some("int") - case _: LongLiteralExpr => Some("long") - case _: NullLiteralExpr => Some("null") - case _: StringLiteralExpr => Some("java.lang.String") - case _: TextBlockLiteralExpr => Some("java.lang.String") - case _ => None - } - - private def expressionReturnTypeFullName(expr: Expression): Option[String] = { - - val resolvedTypeOption = tryWithSafeStackOverflow(expr.calculateResolvedType()) match { - case 
Failure(ex) => - ex match { - // If ast parser fails to resolve type, try resolving locally by using name - // Precaution when resolving by name, we only want to resolve for case when the expr is solely a MethodCallExpr - // and doesn't have a scope to it - case symbolException: UnsolvedSymbolException => - expr match { - case callExpr: MethodCallExpr => - callExpr.getScope.toScala match { - case Some(_: Expression) => None - case _ => scope.lookupType(symbolException.getName) + + val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration)) + scope.addParameter(thisNode) + val thisAst = Ast(thisNode) + + val bodyAst = astForConstructorBody(Some(constructorDeclaration.getBody)) + val methodReturn = constructorReturnNode(constructorDeclaration) + + val annotationAsts = + constructorDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toList + + val modifiers = + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR) :: modifiersForMethod( + constructorDeclaration + ).filterNot( + _.modifierType == ModifierTypes.VIRTUAL + ) + + scope.popScope() + + methodAstWithAnnotations( + constructorNode, + thisAst :: parameterAsts, + bodyAst, + methodReturn, + modifiers, + annotationAsts + ) + end astForConstructor + + private def thisNodeForMethod( + maybeTypeFullName: Option[String], + lineNumber: Option[Integer] + ): NewMethodParameterIn = + val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(TypeConstants.Any)) + NodeBuilders.newThisParameterNode( + typeFullName = typeFullName, + dynamicTypeHintFullName = maybeTypeFullName.toSeq, + line = lineNumber + ) + + private def convertAnnotationValueExpr(expr: Expression): Option[Ast] = + expr match + case arrayInit: ArrayInitializerExpr => + val arrayInitNode = NewArrayInitializer() + .code(arrayInit.toString) + val initElementAsts = arrayInit.getValues.asScala.toList.map { value => + convertAnnotationValueExpr(value) } - case _ => None - } - case _ => None + + 
setArgumentIndices(initElementAsts.flatten) + + val returnAst = initElementAsts.foldLeft(Ast(arrayInitNode)) { + case (ast, Some(elementAst)) => + ast.withChild(elementAst) + case (ast, _) => ast + } + Some(returnAst) + + case annotationExpr: AnnotationExpr => + Some(astForAnnotationExpr(annotationExpr)) + + case literalExpr: LiteralExpr => + Some(astForAnnotationLiteralExpr(literalExpr)) + + case _: ClassExpr => + // TODO: Implement for known case + None + + case _: FieldAccessExpr => + // TODO: Implement for known case + None + + case _: BinaryExpr => + // TODO: Implement for known case + None + + case _: NameExpr => + // TODO: Implement for known case + None + + case _ => + logger.debug( + s"convertAnnotationValueExpr not yet implemented for unknown case ${expr.getClass}" + ) + None + + private def astForAnnotationLiteralExpr(literalExpr: LiteralExpr): Ast = + val valueNode = + literalExpr match + case literal: StringLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: IntegerLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: BooleanLiteralExpr => + newAnnotationLiteralNode(java.lang.Boolean.toString(literal.getValue)) + case literal: CharLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: DoubleLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: LongLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case _: NullLiteralExpr => newAnnotationLiteralNode("null") + case literal: TextBlockLiteralExpr => newAnnotationLiteralNode(literal.getValue) + + Ast(valueNode) + + private def exprNameFromStack(expr: Expression): Option[String] = expr match + case annotation: AnnotationExpr => + scope.lookupType(annotation.getNameAsString) + case namedExpr: NodeWithName[?] => + scope.lookupVariableOrType(namedExpr.getNameAsString) + case namedExpr: NodeWithSimpleName[?] 
=> + scope.lookupVariableOrType(namedExpr.getNameAsString) + // JavaParser doesn't handle literals well for some reason + case _: BooleanLiteralExpr => Some("boolean") + case _: CharLiteralExpr => Some("char") + case _: DoubleLiteralExpr => Some("double") + case _: IntegerLiteralExpr => Some("int") + case _: LongLiteralExpr => Some("long") + case _: NullLiteralExpr => Some("null") + case _: StringLiteralExpr => Some("java.lang.String") + case _: TextBlockLiteralExpr => Some("java.lang.String") + case _ => None + + private def expressionReturnTypeFullName(expr: Expression): Option[String] = + + val resolvedTypeOption = tryWithSafeStackOverflow(expr.calculateResolvedType()) match + case Failure(ex) => + ex match + // If ast parser fails to resolve type, try resolving locally by using name + // Precaution when resolving by name, we only want to resolve for case when the expr is solely a MethodCallExpr + // and doesn't have a scope to it + case symbolException: UnsolvedSymbolException => + expr match + case callExpr: MethodCallExpr => + callExpr.getScope.toScala match + case Some(_: Expression) => None + case _ => scope.lookupType(symbolException.getName) + case _ => None + case _ => None + case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) + resolvedTypeOption.orElse(exprNameFromStack(expr)) + + private def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = + val fallbackType = s"${Defines.UnresolvedNamespace}.${annotationExpr.getNameAsString}" + val fullName = expressionReturnTypeFullName(annotationExpr).getOrElse(fallbackType) + val code = annotationExpr.toString + val name = annotationExpr.getName.getIdentifier + val node = annotationNode(annotationExpr, code, name, fullName) + annotationExpr match + case _: MarkerAnnotationExpr => + annotationAst(node, List.empty) + case normal: NormalAnnotationExpr => + val assignmentAsts = normal.getPairs.asScala.toList.map { pair => + annotationAssignmentAst( + pair.getName.getIdentifier, + 
pair.toString, + convertAnnotationValueExpr(pair.getValue).getOrElse(Ast()) + ) + } + annotationAst(node, assignmentAsts) + case single: SingleMemberAnnotationExpr => + val assignmentAsts = List( + annotationAssignmentAst( + "value", + single.getMemberValue.toString, + convertAnnotationValueExpr(single.getMemberValue).getOrElse(Ast()) + ) + ) + annotationAst(node, assignmentAsts) + end match + end astForAnnotationExpr + + private def abstractModifierForCallable( + callableDeclaration: CallableDeclaration[?], + isInterfaceMethod: Boolean + ): Option[NewModifier] = + callableDeclaration match + case methodDeclaration: MethodDeclaration => + Option.when( + methodDeclaration.isAbstract || (isInterfaceMethod && !methodDeclaration.isDefault) + ) { + newModifierNode(ModifierTypes.ABSTRACT) + } + + case _ => None + + private def modifiersForMethod(methodDeclaration: CallableDeclaration[?]): List[NewModifier] = + val isInterfaceMethod = scope.enclosingTypeDecl.exists(_.code.contains("interface ")) + + val abstractModifier = abstractModifierForCallable(methodDeclaration, isInterfaceMethod) + + val staticVirtualModifierType = + if methodDeclaration.isStatic then ModifierTypes.STATIC else ModifierTypes.VIRTUAL + val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) + + val accessModifierType = if methodDeclaration.isPublic then + Some(ModifierTypes.PUBLIC) + else if methodDeclaration.isPrivate then + Some(ModifierTypes.PRIVATE) + else if methodDeclaration.isProtected then + Some(ModifierTypes.PROTECTED) + else if isInterfaceMethod then + // TODO: more robust interface check + Some(ModifierTypes.PUBLIC) + else + None + val accessModifier = accessModifierType.map(newModifierNode) + + List(accessModifier, abstractModifier, staticVirtualModifier).flatten + end modifiersForMethod + + private def getIdentifiersForTypeParameters(methodDeclaration: MethodDeclaration) + : List[NewIdentifier] = + methodDeclaration.getTypeParameters.asScala.map { typeParameter => + 
val name = typeParameter.getNameAsString + val typeFullName = typeParameter.getTypeBound.asScala.headOption + .flatMap(typeInfoCalc.fullName) + .getOrElse(TypeConstants.Object) + typeInfoCalc.registerType(typeFullName) + + NewIdentifier().name(name).typeFullName(typeFullName) + }.toList + + private def astForMethod(methodDeclaration: MethodDeclaration): Ast = + val methodNode = createPartialMethod(methodDeclaration) + + val typeParameters = getIdentifiersForTypeParameters(methodDeclaration) + + val maybeResolved = tryWithSafeStackOverflow(methodDeclaration.resolve()) + val expectedReturnType = Try(symbolSolver.toResolvedType( + methodDeclaration.getType, + classOf[ResolvedType] + )).toOption + val simpleMethodReturnType = methodDeclaration.getTypeAsString() + val returnTypeFullName = expectedReturnType + .flatMap(typeInfoCalc.fullName) + .orElse(scope.lookupType(simpleMethodReturnType)) + .orElse(typeParameters.find(_.name == simpleMethodReturnType).map(_.typeFullName)) + + scope.pushMethodScope(methodNode, ExpectedType(returnTypeFullName, expectedReturnType)) + typeParameters.foreach { typeParameter => + scope.addType(typeParameter.name, typeParameter.typeFullName) } - case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) - } - resolvedTypeOption.orElse(exprNameFromStack(expr)) - } - - private def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = { - val fallbackType = s"${Defines.UnresolvedNamespace}.${annotationExpr.getNameAsString}" - val fullName = expressionReturnTypeFullName(annotationExpr).getOrElse(fallbackType) - val code = annotationExpr.toString - val name = annotationExpr.getName.getIdentifier - val node = annotationNode(annotationExpr, code, name, fullName) - annotationExpr match { - case _: MarkerAnnotationExpr => - annotationAst(node, List.empty) - case normal: NormalAnnotationExpr => - val assignmentAsts = normal.getPairs.asScala.toList.map { pair => - annotationAssignmentAst( - pair.getName.getIdentifier, - pair.toString, - 
convertAnnotationValueExpr(pair.getValue).getOrElse(Ast()) - ) + + val parameterAsts = astsForParameterList(methodDeclaration.getParameters) + val parameterTypes = argumentTypesForMethodLike(maybeResolved) + val signature = composeSignature(returnTypeFullName, parameterTypes, parameterAsts.size) + val namespaceName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val methodFullName = + composeMethodFullName(namespaceName, methodDeclaration.getNameAsString, signature) + + methodNode + .fullName(methodFullName) + .signature(signature) + val typeNameLookup = + methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + if packagesJarMappings.contains(typeNameLookup) then + methodNode.astParentType(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).head) + val thisNode = Option.when(!methodDeclaration.isStatic) { + val typeFullName = scope.enclosingTypeDeclFullName + thisNodeForMethod(typeFullName, line(methodDeclaration)) } - annotationAst(node, assignmentAsts) - case single: SingleMemberAnnotationExpr => - val assignmentAsts = List( - annotationAssignmentAst( - "value", - single.getMemberValue.toString, - convertAnnotationValueExpr(single.getMemberValue).getOrElse(Ast()) - ) + val thisAst = thisNode.map(Ast(_)).toList + + thisNode.foreach { node => + scope.addParameter(node) + } + + val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast( + NewBlock() + )) + val methodReturn = newMethodReturnNode( + returnTypeFullName.getOrElse(TypeConstants.Any), + None, + line(methodDeclaration.getType), + column(methodDeclaration.getType) + ) + + val annotationAsts = + methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq + + val modifiers = modifiersForMethod(methodDeclaration) + + scope.popScope() + + methodAstWithAnnotations( + methodNode, + thisAst ++ parameterAsts, + bodyAst, + methodReturn, + modifiers, + annotationAsts ) - annotationAst(node, 
assignmentAsts) - } - } - - private def abstractModifierForCallable( - callableDeclaration: CallableDeclaration[_], - isInterfaceMethod: Boolean - ): Option[NewModifier] = { - callableDeclaration match { - case methodDeclaration: MethodDeclaration => - Option.when(methodDeclaration.isAbstract || (isInterfaceMethod && !methodDeclaration.isDefault)) { - newModifierNode(ModifierTypes.ABSTRACT) + end astForMethod + + private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration) + : NewMethodReturn = + val line = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.line)).toScala + val column = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.column)).toScala + newMethodReturnNode(TypeConstants.Void, None, line, column) + + /** Constructor and Method declarations share a lot of fields, so this method adds the fields + * they have in common. `fullName` and `signature` are omitted + */ + private def createPartialMethod(declaration: CallableDeclaration[?]): NewMethod = + val code = declaration.getDeclarationAsString.trim + val columnNumber = declaration.getBegin.map(x => Integer.valueOf(x.column)).toScala + val endLine = declaration.getEnd.map(x => Integer.valueOf(x.line)).toScala + val endColumn = declaration.getEnd.map(x => Integer.valueOf(x.column)).toScala + + val methodNode = NewMethod() + .name(declaration.getNameAsString) + .code(code) + .isExternal(false) + .filename(filename) + .lineNumber(line(declaration)) + .columnNumber(columnNumber) + .lineNumberEnd(endLine) + .columnNumberEnd(endColumn) + + methodNode + + private def astForConstructorBody(body: Option[BlockStmt]): Ast = + val containsThisInvocation = + body + .flatMap(_.getStatements.asScala.headOption) + .collect { case e: ExplicitConstructorInvocationStmt => e } + .exists(_.isThis) + + val memberInitializers = + if containsThisInvocation then + Seq.empty + else + scope.memberInitializers + + body match + case Some(b) => astForBlockStatement(b, prefixAsts = 
memberInitializers) + + case None => Ast(NewBlock()).withChildren(memberInitializers) + + private def astsForLabeledStatement(stmt: LabeledStmt): Seq[Ast] = + val jumpTargetAst = Ast(NewJumpTarget().name(stmt.getLabel.toString)) + val stmtAst = astsForStatement(stmt.getStatement).toList + + jumpTargetAst :: stmtAst + + private def astForThrow(stmt: ThrowStmt): Ast = + val throwNode = NewCall() + .name(".throw") + .methodFullName(".throw") + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString()) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + + val args = astsForExpression(stmt.getExpression, ExpectedType.empty) + + callAst(throwNode, args) + + private def astForCatchClause(catchClause: CatchClause): Ast = + astForBlockStatement(catchClause.getBody) + + private def astsForTry(stmt: TryStmt): Seq[Ast] = + val tryNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.TRY) + .code("try") + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val resources = stmt.getResources.asScala.flatMap(astsForExpression( + _, + expectedType = ExpectedType.empty + )).toList + val tryAst = astForBlockStatement(stmt.getTryBlock, codeStr = "try") + val catchAsts = stmt.getCatchClauses.asScala.map(astForCatchClause) + val catchBlock = Option + .when(catchAsts.nonEmpty) { + Ast(NewBlock().code("catch")).withChildren(catchAsts) + } + .toList + val finallyAst = + stmt.getFinallyBlock.toScala.map(astForBlockStatement(_, "finally")).toList + + val controlStructureAst = Ast(tryNode) + .withChild(tryAst) + .withChildren(catchBlock) + .withChildren(finallyAst) + + resources.appended(controlStructureAst) + end astsForTry + + private def astsForStatement(statement: Statement): Seq[Ast] = + // TODO: Implement missing handlers + // case _: LocalClassDeclarationStmt => Seq() + // case _: LocalRecordDeclarationStmt => Seq() + // case _: YieldStmt => Seq() + statement match + case x: ExplicitConstructorInvocationStmt => + 
Seq(astForExplicitConstructorInvocation(x)) + case x: AssertStmt => Seq(astForAssertStatement(x)) + case x: BlockStmt => Seq(astForBlockStatement(x)) + case x: BreakStmt => Seq(astForBreakStatement(x)) + case x: ContinueStmt => Seq(astForContinueStatement(x)) + case x: DoStmt => Seq(astForDo(x)) + case _: EmptyStmt => Seq() // Intentionally skipping this + case x: ExpressionStmt => astsForExpression(x.getExpression, ExpectedType.Void) + case x: ForEachStmt => astForForEach(x) + case x: ForStmt => Seq(astForFor(x)) + case x: IfStmt => Seq(astForIf(x)) + case x: LabeledStmt => astsForLabeledStatement(x) + case x: ReturnStmt => Seq(astForReturnNode(x)) + case x: SwitchStmt => Seq(astForSwitchStatement(x)) + case x: SynchronizedStmt => Seq(astForSynchronizedStatement(x)) + case x: ThrowStmt => Seq(astForThrow(x)) + case x: TryStmt => astsForTry(x) + case x: WhileStmt => Seq(astForWhile(x)) + // case x: LocalClassDeclarationStmt => Seq(astForLocalClassDeclarationStmt(x)) + case x => + logger.debug( + s"Attempting to generate AST for unknown statement of type ${x.getClass}" + ) + Seq(unknownAst(x)) + + private def astForElse(maybeStmt: Option[Statement]): Option[Ast] = + maybeStmt.map { stmt => + val elseAsts = astsForStatement(stmt) + + val elseNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.ELSE) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code("else") + + Ast(elseNode).withChildren(elseAsts) + } + + def astForIf(stmt: IfStmt): Ast = + val ifNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.IF) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(s"if (${stmt.getCondition.toString})") + + val conditionAst = + astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList + + val thenAsts = astsForStatement(stmt.getThenStmt) + val elseAst = astForElse(stmt.getElseStmt.toScala).toList + + val ast = Ast(ifNode) + .withChildren(conditionAst) + .withChildren(thenAsts) + 
.withChildren(elseAst) + + conditionAst.flatMap(_.root.toList) match + case r :: Nil => + ast.withConditionEdge(ifNode, r) + case _ => + ast + end astForIf + + def astForWhile(stmt: WhileStmt): Ast = + val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption + val stmtAsts = astsForStatement(stmt.getBody) + val code = s"while (${stmt.getCondition.toString})" + val lineNumber = line(stmt) + val columnNumber = column(stmt) + + whileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + + private def astForDo(stmt: DoStmt): Ast = + val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption + val stmtAsts = astsForStatement(stmt.getBody) + val code = s"do {...} while (${stmt.getCondition.toString})" + val lineNumber = line(stmt) + val columnNumber = column(stmt) + + doWhileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + + private def astForBreakStatement(stmt: BreakStmt): Ast = + val node = NewControlStructure() + .controlStructureType(ControlStructureTypes.BREAK) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString) + Ast(node) + + private def astForContinueStatement(stmt: ContinueStmt): Ast = + val node = NewControlStructure() + .controlStructureType(ControlStructureTypes.CONTINUE) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString) + Ast(node) + + private def getForCode(stmt: ForStmt): String = + val init = stmt.getInitialization.asScala.map(_.toString).mkString(", ") + val compare = stmt.getCompare.toScala.map(_.toString) + val update = stmt.getUpdate.asScala.map(_.toString).mkString(", ") + s"for ($init; $compare; $update)" + def astForFor(stmt: ForStmt): Ast = + val forNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.FOR) + .code(getForCode(stmt)) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val initAsts = + stmt.getInitialization.asScala.flatMap(astsForExpression( + _, + 
expectedType = ExpectedType.empty + )) + + val compareAsts = stmt.getCompare.toScala.toList.flatMap { + astsForExpression(_, ExpectedType.Boolean) + } + + val updateAsts = stmt.getUpdate.asScala.toList.flatMap { + astsForExpression(_, ExpectedType.empty) } - case _ => None - } - } - - private def modifiersForMethod(methodDeclaration: CallableDeclaration[_]): List[NewModifier] = { - val isInterfaceMethod = scope.enclosingTypeDecl.exists(_.code.contains("interface ")) - - val abstractModifier = abstractModifierForCallable(methodDeclaration, isInterfaceMethod) - - val staticVirtualModifierType = if (methodDeclaration.isStatic) ModifierTypes.STATIC else ModifierTypes.VIRTUAL - val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) - - val accessModifierType = if (methodDeclaration.isPublic) { - Some(ModifierTypes.PUBLIC) - } else if (methodDeclaration.isPrivate) { - Some(ModifierTypes.PRIVATE) - } else if (methodDeclaration.isProtected) { - Some(ModifierTypes.PROTECTED) - } else if (isInterfaceMethod) { - // TODO: more robust interface check - Some(ModifierTypes.PUBLIC) - } else { - None - } - val accessModifier = accessModifierType.map(newModifierNode) - - List(accessModifier, abstractModifier, staticVirtualModifier).flatten - } - - private def getIdentifiersForTypeParameters(methodDeclaration: MethodDeclaration): List[NewIdentifier] = { - methodDeclaration.getTypeParameters.asScala.map { typeParameter => - val name = typeParameter.getNameAsString - val typeFullName = typeParameter.getTypeBound.asScala.headOption - .flatMap(typeInfoCalc.fullName) - .getOrElse(TypeConstants.Object) - typeInfoCalc.registerType(typeFullName) - - NewIdentifier().name(name).typeFullName(typeFullName) - }.toList - } - - private def astForMethod(methodDeclaration: MethodDeclaration): Ast = { - val methodNode = createPartialMethod(methodDeclaration) - - val typeParameters = getIdentifiersForTypeParameters(methodDeclaration) - - val maybeResolved = 
tryWithSafeStackOverflow(methodDeclaration.resolve()) - val expectedReturnType = Try(symbolSolver.toResolvedType(methodDeclaration.getType, classOf[ResolvedType])).toOption - val simpleMethodReturnType = methodDeclaration.getTypeAsString() - val returnTypeFullName = expectedReturnType - .flatMap(typeInfoCalc.fullName) - .orElse(scope.lookupType(simpleMethodReturnType)) - .orElse(typeParameters.find(_.name == simpleMethodReturnType).map(_.typeFullName)) - - scope.pushMethodScope(methodNode, ExpectedType(returnTypeFullName, expectedReturnType)) - typeParameters.foreach { typeParameter => scope.addType(typeParameter.name, typeParameter.typeFullName) } - - val parameterAsts = astsForParameterList(methodDeclaration.getParameters) - val parameterTypes = argumentTypesForMethodLike(maybeResolved) - val signature = composeSignature(returnTypeFullName, parameterTypes, parameterAsts.size) - val namespaceName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val methodFullName = composeMethodFullName(namespaceName, methodDeclaration.getNameAsString, signature) - - methodNode - .fullName(methodFullName) - .signature(signature) - val typeNameLookup = methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - if (packagesJarMappings.contains(typeNameLookup)) - methodNode.astParentType(packagesJarMappings.getOrElse(typeNameLookup, mutable.Set.empty).head) - val thisNode = Option.when(!methodDeclaration.isStatic) { - val typeFullName = scope.enclosingTypeDeclFullName - thisNodeForMethod(typeFullName, line(methodDeclaration)) - } - val thisAst = thisNode.map(Ast(_)).toList - - thisNode.foreach { node => - scope.addParameter(node) - } - - val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast(NewBlock())) - val methodReturn = newMethodReturnNode( - returnTypeFullName.getOrElse(TypeConstants.Any), - None, - line(methodDeclaration.getType), - column(methodDeclaration.getType) - ) - - val annotationAsts = 
methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq - - val modifiers = modifiersForMethod(methodDeclaration) - - scope.popScope() - - methodAstWithAnnotations(methodNode, thisAst ++ parameterAsts, bodyAst, methodReturn, modifiers, annotationAsts) - } - - private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration): NewMethodReturn = { - val line = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.line)).toScala - val column = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.column)).toScala - newMethodReturnNode(TypeConstants.Void, None, line, column) - } - - /** Constructor and Method declarations share a lot of fields, so this method adds the fields they have in common. - * `fullName` and `signature` are omitted - */ - private def createPartialMethod(declaration: CallableDeclaration[_]): NewMethod = { - val code = declaration.getDeclarationAsString.trim - val columnNumber = declaration.getBegin.map(x => Integer.valueOf(x.column)).toScala - val endLine = declaration.getEnd.map(x => Integer.valueOf(x.line)).toScala - val endColumn = declaration.getEnd.map(x => Integer.valueOf(x.column)).toScala - - val methodNode = NewMethod() - .name(declaration.getNameAsString) - .code(code) - .isExternal(false) - .filename(filename) - .lineNumber(line(declaration)) - .columnNumber(columnNumber) - .lineNumberEnd(endLine) - .columnNumberEnd(endColumn) - - methodNode - } - - private def astForConstructorBody(body: Option[BlockStmt]): Ast = { - val containsThisInvocation = - body - .flatMap(_.getStatements.asScala.headOption) - .collect { case e: ExplicitConstructorInvocationStmt => e } - .exists(_.isThis) - - val memberInitializers = - if (containsThisInvocation) - Seq.empty - else - scope.memberInitializers - - body match { - case Some(b) => astForBlockStatement(b, prefixAsts = memberInitializers) - - case None => Ast(NewBlock()).withChildren(memberInitializers) - } - } - - private def astsForLabeledStatement(stmt: 
LabeledStmt): Seq[Ast] = { - val jumpTargetAst = Ast(NewJumpTarget().name(stmt.getLabel.toString)) - val stmtAst = astsForStatement(stmt.getStatement).toList - - jumpTargetAst :: stmtAst - } - - private def astForThrow(stmt: ThrowStmt): Ast = { - val throwNode = NewCall() - .name(".throw") - .methodFullName(".throw") - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(stmt.toString()) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - - val args = astsForExpression(stmt.getExpression, ExpectedType.empty) - - callAst(throwNode, args) - } - - private def astForCatchClause(catchClause: CatchClause): Ast = { - astForBlockStatement(catchClause.getBody) - } - - private def astsForTry(stmt: TryStmt): Seq[Ast] = { - val tryNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.TRY) - .code("try") - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - val resources = stmt.getResources.asScala.flatMap(astsForExpression(_, expectedType = ExpectedType.empty)).toList - val tryAst = astForBlockStatement(stmt.getTryBlock, codeStr = "try") - val catchAsts = stmt.getCatchClauses.asScala.map(astForCatchClause) - val catchBlock = Option - .when(catchAsts.nonEmpty) { - Ast(NewBlock().code("catch")).withChildren(catchAsts) - } - .toList - val finallyAst = - stmt.getFinallyBlock.toScala.map(astForBlockStatement(_, "finally")).toList - - val controlStructureAst = Ast(tryNode) - .withChild(tryAst) - .withChildren(catchBlock) - .withChildren(finallyAst) - - resources.appended(controlStructureAst) - } - - private def astsForStatement(statement: Statement): Seq[Ast] = { - // TODO: Implement missing handlers - // case _: LocalClassDeclarationStmt => Seq() - // case _: LocalRecordDeclarationStmt => Seq() - // case _: YieldStmt => Seq() - statement match { - case x: ExplicitConstructorInvocationStmt => - Seq(astForExplicitConstructorInvocation(x)) - case x: AssertStmt => Seq(astForAssertStatement(x)) - case x: BlockStmt => Seq(astForBlockStatement(x)) - case 
x: BreakStmt => Seq(astForBreakStatement(x)) - case x: ContinueStmt => Seq(astForContinueStatement(x)) - case x: DoStmt => Seq(astForDo(x)) - case _: EmptyStmt => Seq() // Intentionally skipping this - case x: ExpressionStmt => astsForExpression(x.getExpression, ExpectedType.Void) - case x: ForEachStmt => astForForEach(x) - case x: ForStmt => Seq(astForFor(x)) - case x: IfStmt => Seq(astForIf(x)) - case x: LabeledStmt => astsForLabeledStatement(x) - case x: ReturnStmt => Seq(astForReturnNode(x)) - case x: SwitchStmt => Seq(astForSwitchStatement(x)) - case x: SynchronizedStmt => Seq(astForSynchronizedStatement(x)) - case x: ThrowStmt => Seq(astForThrow(x)) - case x: TryStmt => astsForTry(x) - case x: WhileStmt => Seq(astForWhile(x)) - // case x: LocalClassDeclarationStmt => Seq(astForLocalClassDeclarationStmt(x)) - case x => - logger.debug(s"Attempting to generate AST for unknown statement of type ${x.getClass}") - Seq(unknownAst(x)) - } - } - - private def astForElse(maybeStmt: Option[Statement]): Option[Ast] = { - maybeStmt.map { stmt => - val elseAsts = astsForStatement(stmt) - - val elseNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.ELSE) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code("else") - - Ast(elseNode).withChildren(elseAsts) - } - } - - def astForIf(stmt: IfStmt): Ast = { - val ifNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.IF) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(s"if (${stmt.getCondition.toString})") - - val conditionAst = - astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList - - val thenAsts = astsForStatement(stmt.getThenStmt) - val elseAst = astForElse(stmt.getElseStmt.toScala).toList - - val ast = Ast(ifNode) - .withChildren(conditionAst) - .withChildren(thenAsts) - .withChildren(elseAst) - - conditionAst.flatMap(_.root.toList) match { - case r :: Nil => - ast.withConditionEdge(ifNode, r) - case _ => + val stmtAsts = + 
astsForStatement(stmt.getBody) + + val ast = Ast(forNode) + .withChildren(initAsts) + .withChildren(compareAsts) + .withChildren(updateAsts) + .withChildren(stmtAsts) + + compareAsts.flatMap(_.root) match + case c :: Nil => + ast.withConditionEdge(forNode, c) + case _ => ast + end astForFor + + private def iterableAssignAstsForNativeForEach( + iterableExpression: Expression, + iterableType: Option[String] + ): (NodeTypeInfo, Seq[Ast]) = + val lineNo = line(iterableExpression) + val expectedType = ExpectedType(iterableType) + + val iterableAst = astsForExpression(iterableExpression, expectedType = expectedType) match + case Nil => + logger.debug( + s"Could not create AST for iterable expr $iterableExpression: $filename:l$lineNo" + ) + Ast() + case iterableAstHead :: Nil => iterableAstHead + case iterableAsts => + logger.debug( + s"Found multiple ASTS for iterable expr $iterableExpression: $filename:l$lineNo\nDropping all but the first!" + ) + iterableAsts.head + + val iterableName = nextIterableName() + val iterableLocalNode = + localNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) + val iterableLocalAst = Ast(iterableLocalNode) + + val iterableAssignNode = + newOperatorCallNode( + Operators.assignment, + code = "", + line = lineNo, + typeFullName = iterableType + ) + val iterableAssignIdentifier = + identifierNode( + iterableExpression, + iterableName, + iterableName, + iterableType.getOrElse("ANY") + ) + val iterableAssignArgs = List(Ast(iterableAssignIdentifier), iterableAst) + val iterableAssignAst = + callAst(iterableAssignNode, iterableAssignArgs) + .withRefEdge(iterableAssignIdentifier, iterableLocalNode) + + ( + NodeTypeInfo( + iterableLocalNode, + iterableLocalNode.name, + Some(iterableLocalNode.typeFullName) + ), + List(iterableLocalAst, iterableAssignAst) + ) + end iterableAssignAstsForNativeForEach + + private def nativeForEachIdxLocalNode(lineNo: Option[Integer]): NewLocal = + val idxName = nextIndexName() + val 
typeFullName = TypeConstants.Int + val idxLocal = + NewLocal() + .name(idxName) + .typeFullName(typeFullName) + .code(idxName) + .lineNumber(lineNo) + scope.addLocal(idxLocal) + idxLocal + + private def nativeForEachIdxInitializerAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = + val idxName = idxLocal.name + val idxInitializerCallNode = newOperatorCallNode( + Operators.assignment, + code = s"int $idxName = 0", + line = lineNo, + typeFullName = Some(TypeConstants.Int) + ) + val idxIdentifierArg = newIdentifierNode(idxName, idxLocal.typeFullName) + val zeroLiteral = + NewLiteral() + .code("0") + .typeFullName(TypeConstants.Int) + .lineNumber(lineNo) + val idxInitializerArgAsts = List(Ast(idxIdentifierArg), Ast(zeroLiteral)) + callAst(idxInitializerCallNode, idxInitializerArgAsts) + .withRefEdge(idxIdentifierArg, idxLocal) + + private def nativeForEachCompareAst( + lineNo: Option[Integer], + iterableSource: NodeTypeInfo, + idxLocal: NewLocal + ): Ast = + val idxName = idxLocal.name + + val compareNode = newOperatorCallNode( + Operators.lessThan, + code = s"$idxName < ${iterableSource.name}.length", + typeFullName = Some(TypeConstants.Boolean), + line = lineNo + ) + val comparisonIdxIdentifier = newIdentifierNode(idxName, idxLocal.typeFullName) + val comparisonFieldAccess = newOperatorCallNode( + Operators.fieldAccess, + code = s"${iterableSource.name}.length", + typeFullName = Some(TypeConstants.Int), + line = lineNo + ) + val fieldAccessIdentifier = + newIdentifierNode(iterableSource.name, iterableSource.typeFullName.getOrElse("ANY")) + val fieldAccessFieldIdentifier = newFieldIdentifierNode("length", lineNo) + val fieldAccessArgs = List(fieldAccessIdentifier, fieldAccessFieldIdentifier).map(Ast(_)) + val fieldAccessAst = callAst(comparisonFieldAccess, fieldAccessArgs) + val compareArgs = List(Ast(comparisonIdxIdentifier), fieldAccessAst) + + // TODO: This is a workaround for a crash when looping over statically imported members. Handle those properly. 
+ val iterableSourceNode = localParamOrMemberFromNode(iterableSource) + + callAst(compareNode, compareArgs) + .withRefEdge(comparisonIdxIdentifier, idxLocal) + .withRefEdges(fieldAccessIdentifier, iterableSourceNode.toList) + end nativeForEachCompareAst + + private def nativeForEachIncrementAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = + val incrementNode = newOperatorCallNode( + Operators.postIncrement, + code = s"${idxLocal.name}++", + typeFullName = Some(TypeConstants.Int), + line = lineNo + ) + val incrementArg = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) + val incrementArgAst = Ast(incrementArg) + callAst(incrementNode, List(incrementArgAst)) + .withRefEdge(incrementArg, idxLocal) + + private def variableLocalForForEachBody(stmt: ForEachStmt): NewLocal = + val lineNo = line(stmt) + // Create item local + val maybeVariable = stmt.getVariable.getVariables.asScala.toList match + case Nil => + logger.debug(s"ForEach statement has empty variable list: $filename$lineNo") + None + case variable :: Nil => Some(variable) + case variable :: _ => + logger.debug( + s"ForEach statement defines multiple variables. 
Dropping all but the first: $filename$lineNo" + ) + Some(variable) + + val partialLocalNode = NewLocal().lineNumber(lineNo) + + maybeVariable match + case Some(variable) => + val name = variable.getNameAsString + val typeFullName = typeInfoCalc.fullName(variable.getType).getOrElse("ANY") + val localNode = partialLocalNode + .name(name) + .code(variable.getNameAsString) + .typeFullName(typeFullName) + + scope.addLocal(localNode) + localNode + + case None => + // Returning partialLocalNode here is fine since getting to this case means everything is broken anyways :) + partialLocalNode + end variableLocalForForEachBody + + private def localParamOrMemberFromNode(nodeTypeInfo: NodeTypeInfo): Option[NewNode] = + nodeTypeInfo.node match + case localNode: NewLocal => Some(localNode) + case memberNode: NewMember => Some(memberNode) + case parameterNode: NewMethodParameterIn => Some(parameterNode) + case _ => None + private def variableAssignForNativeForEachBody( + variableLocal: NewLocal, + idxLocal: NewLocal, + iterable: NodeTypeInfo + ): Ast = + // Everything will be on the same line as the `for` statement, but this is the most useful + // solution for debugging. 
+ val lineNo = variableLocal.lineNumber + val varAssignNode = + newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(variableLocal.typeFullName), + lineNo + ) + + val targetNode = newIdentifierNode(variableLocal.name, variableLocal.typeFullName) + + val indexAccessTypeFullName = iterable.typeFullName.map(_.replaceAll(raw"\[]", "")) + val indexAccess = + newOperatorCallNode( + Operators.indexAccess, + PropertyDefaults.Code, + indexAccessTypeFullName, + lineNo + ) + + val indexAccessIdentifier = + newIdentifierNode(iterable.name, iterable.typeFullName.getOrElse("ANY")) + val indexAccessIndex = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) + + val indexAccessArgsAsts = List(indexAccessIdentifier, indexAccessIndex).map(Ast(_)) + val indexAccessAst = callAst(indexAccess, indexAccessArgsAsts) + + val iterableSourceNode = localParamOrMemberFromNode(iterable) + + val assignArgsAsts = List(Ast(targetNode), indexAccessAst) + callAst(varAssignNode, assignArgsAsts) + .withRefEdge(targetNode, variableLocal) + .withRefEdges(indexAccessIdentifier, iterableSourceNode.toList) + .withRefEdge(indexAccessIndex, idxLocal) + end variableAssignForNativeForEachBody + + private def nativeForEachBodyAst( + stmt: ForEachStmt, + idxLocal: NewLocal, + iterable: NodeTypeInfo + ): Ast = + val variableLocal = variableLocalForForEachBody(stmt) + val variableLocalAst = Ast(variableLocal) + val variableAssignAst = + variableAssignForNativeForEachBody(variableLocal, idxLocal, iterable) + + stmt.getBody match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = List(variableLocalAst, variableAssignAst)) + case statement => + val stmtAsts = astsForStatement(statement) + val blockNode = NewBlock().lineNumber(variableLocal.lineNumber) + Ast(blockNode) + .withChild(variableLocalAst) + .withChild(variableAssignAst) + .withChildren(stmtAsts) + end nativeForEachBodyAst + + private def astsForNativeForEach(stmt: ForEachStmt, iterableType: Option[String]): 
Seq[Ast] = + + // This is ugly, but for a case like `for (int x : new int[] { ... })` this creates a new LOCAL + // with the assignment `int[] $iterLocal0 = new int[] { ... }` before the FOR loop. + // TODO: Fix this + val (iterableSource: NodeTypeInfo, tempIterableInitAsts) = stmt.getIterable match + case nameExpr: NameExpr => + scope.lookupVariable(nameExpr.getNameAsString).asNodeInfoOption match + // If this is not the case, then the code is broken (iterable not in scope). + case Some(nodeTypeInfo) => (nodeTypeInfo, Nil) + case _ => iterableAssignAstsForNativeForEach(nameExpr, iterableType) + case iterableExpr => iterableAssignAstsForNativeForEach(iterableExpr, iterableType) + + val forNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.FOR) + + val lineNo = line(stmt) + + val idxLocal = nativeForEachIdxLocalNode(lineNo) + val idxInitializerAst = nativeForEachIdxInitializerAst(lineNo, idxLocal) + // TODO next: pass NodeTypeInfo around + val compareAst = nativeForEachCompareAst(lineNo, iterableSource, idxLocal) + val incrementAst = nativeForEachIncrementAst(lineNo, idxLocal) + val bodyAst = nativeForEachBodyAst(stmt, idxLocal, iterableSource) + + val forAst = Ast(forNode) + .withChild(Ast(idxLocal)) + .withChild(idxInitializerAst) + .withChild(compareAst) + .withChild(incrementAst) + .withChild(bodyAst) + .withConditionEdges(forNode, compareAst.root.toList) + + tempIterableInitAsts ++ Seq(forAst) + end astsForNativeForEach + + private def iteratorLocalForForEach(lineNumber: Option[Integer]): NewLocal = + val iteratorLocalName = nextIterableName() + NewLocal() + .name(iteratorLocalName) + .code(iteratorLocalName) + .typeFullName(TypeConstants.Iterator) + .lineNumber(lineNumber) + + private def iteratorAssignAstForForEach( + iterExpr: Expression, + iteratorLocalNode: NewLocal, + iterableType: Option[String], + lineNo: Option[Integer] + ): Ast = + val iteratorAssignNode = + newOperatorCallNode( + Operators.assignment, + code = "", + 
typeFullName = Some(TypeConstants.Iterator), + line = lineNo + ) + val iteratorAssignIdentifier = + identifierNode( + iterExpr, + iteratorLocalNode.name, + iteratorLocalNode.name, + iteratorLocalNode.typeFullName + ) + + val iteratorCallNode = + newCallNode( + "iterator", + iterableType, + TypeConstants.Iterator, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) + + val actualIteratorAst = + astsForExpression(iterExpr, expectedType = ExpectedType.empty).toList match + case Nil => + logger.debug(s"Could not create receiver ast for iterator $iterExpr") + None + + case ast :: Nil => Some(ast) + + case ast :: _ => + logger.debug( + s"Created multiple receiver asts for $iterExpr. Dropping all but the first." + ) + Some(ast) + + val iteratorCallAst = + callAst(iteratorCallNode, base = actualIteratorAst) + + callAst(iteratorAssignNode, List(Ast(iteratorAssignIdentifier), iteratorCallAst)) + .withRefEdge(iteratorAssignIdentifier, iteratorLocalNode) + end iteratorAssignAstForForEach + + private def hasNextCallAstForForEach( + iteratorLocalNode: NewLocal, + lineNo: Option[Integer] + ): Ast = + val iteratorHasNextCallNode = + newCallNode( + "hasNext", + Some(TypeConstants.Iterator), + TypeConstants.Boolean, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) + val iteratorHasNextCallReceiver = + newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) + + callAst(iteratorHasNextCallNode, base = Some(Ast(iteratorHasNextCallReceiver))) + .withRefEdge(iteratorHasNextCallReceiver, iteratorLocalNode) + + private def astForIterableForEachItemAssign( + iteratorLocalNode: NewLocal, + variableLocal: NewLocal + ): Ast = + val lineNo = variableLocal.lineNumber + val forVariableType = variableLocal.typeFullName + val varLocalAssignNode = + newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(forVariableType), + lineNo + ) + val varLocalAssignIdentifier = + newIdentifierNode(variableLocal.name, variableLocal.typeFullName) + + val 
iterNextCallNode = + newCallNode( + "next", + Some(TypeConstants.Iterator), + TypeConstants.Object, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) + val iterNextCallReceiver = + newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) + val iterNextCallAst = + callAst(iterNextCallNode, base = Some(Ast(iterNextCallReceiver))) + .withRefEdge(iterNextCallReceiver, iteratorLocalNode) + + callAst(varLocalAssignNode, List(Ast(varLocalAssignIdentifier), iterNextCallAst)) + .withRefEdge(varLocalAssignIdentifier, variableLocal) + end astForIterableForEachItemAssign + + private def astForIterableForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = + val lineNo = line(stmt) + + val iteratorLocalNode = iteratorLocalForForEach(lineNo) + val iteratorAssignAst = + iteratorAssignAstForForEach(stmt.getIterable, iteratorLocalNode, iterableType, lineNo) + val iteratorHasNextCallAst = hasNextCallAstForForEach(iteratorLocalNode, lineNo) + val variableLocal = variableLocalForForEachBody(stmt) + val variableAssignAst = astForIterableForEachItemAssign(iteratorLocalNode, variableLocal) + + val bodyPrefixAsts = Seq(Ast(variableLocal), variableAssignAst) + val bodyAst = stmt.getBody match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = bodyPrefixAsts) + + case bodyStmt => + val bodyBlockNode = NewBlock().lineNumber(lineNo) + val bodyStmtAsts = astsForStatement(bodyStmt) + Ast(bodyBlockNode) + .withChildren(bodyPrefixAsts) + .withChildren(bodyStmtAsts) + + val forNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.WHILE) + .code(ControlStructureTypes.FOR) + .lineNumber(lineNo) + .columnNumber(column(stmt)) + + val forAst = controlStructureAst(forNode, Some(iteratorHasNextCallAst), List(bodyAst)) + + Seq(Ast(iteratorLocalNode), iteratorAssignAst, forAst) + end astForIterableForEach + + private def astForForEach(stmt: ForEachStmt): Seq[Ast] = + scope.pushBlockScope() + + val ast = 
expressionReturnTypeFullName(stmt.getIterable) match + case Some(typeFullName) if typeFullName.endsWith("[]") => + astsForNativeForEach(stmt, Some(typeFullName)) + + case maybeType => + astForIterableForEach(stmt, maybeType) + + scope.popScope() ast - } - } - - def astForWhile(stmt: WhileStmt): Ast = { - val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) - val code = s"while (${stmt.getCondition.toString})" - val lineNumber = line(stmt) - val columnNumber = column(stmt) - - whileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) - } - - private def astForDo(stmt: DoStmt): Ast = { - val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) - val code = s"do {...} while (${stmt.getCondition.toString})" - val lineNumber = line(stmt) - val columnNumber = column(stmt) - - doWhileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) - } - - private def astForBreakStatement(stmt: BreakStmt): Ast = { - val node = NewControlStructure() - .controlStructureType(ControlStructureTypes.BREAK) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(stmt.toString) - Ast(node) - } - - private def astForContinueStatement(stmt: ContinueStmt): Ast = { - val node = NewControlStructure() - .controlStructureType(ControlStructureTypes.CONTINUE) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(stmt.toString) - Ast(node) - } - - private def getForCode(stmt: ForStmt): String = { - val init = stmt.getInitialization.asScala.map(_.toString).mkString(", ") - val compare = stmt.getCompare.toScala.map(_.toString) - val update = stmt.getUpdate.asScala.map(_.toString).mkString(", ") - s"for ($init; $compare; $update)" - } - def astForFor(stmt: ForStmt): Ast = { - val forNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.FOR) - .code(getForCode(stmt)) - 
.lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - val initAsts = - stmt.getInitialization.asScala.flatMap(astsForExpression(_, expectedType = ExpectedType.empty)) - - val compareAsts = stmt.getCompare.toScala.toList.flatMap { - astsForExpression(_, ExpectedType.Boolean) - } - - val updateAsts = stmt.getUpdate.asScala.toList.flatMap { - astsForExpression(_, ExpectedType.empty) - } - - val stmtAsts = - astsForStatement(stmt.getBody) - - val ast = Ast(forNode) - .withChildren(initAsts) - .withChildren(compareAsts) - .withChildren(updateAsts) - .withChildren(stmtAsts) - - compareAsts.flatMap(_.root) match { - case c :: Nil => - ast.withConditionEdge(forNode, c) - case _ => ast - } - } - - private def iterableAssignAstsForNativeForEach( - iterableExpression: Expression, - iterableType: Option[String] - ): (NodeTypeInfo, Seq[Ast]) = { - val lineNo = line(iterableExpression) - val expectedType = ExpectedType(iterableType) - - val iterableAst = astsForExpression(iterableExpression, expectedType = expectedType) match { - case Nil => - logger.debug(s"Could not create AST for iterable expr $iterableExpression: $filename:l$lineNo") - Ast() - case iterableAstHead :: Nil => iterableAstHead - case iterableAsts => - logger.debug( - s"Found multiple ASTS for iterable expr $iterableExpression: $filename:l$lineNo\nDropping all but the first!" 
+ + private def astForSwitchStatement(stmt: SwitchStmt): Ast = + val switchNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.SWITCH) + .code(s"switch(${stmt.getSelector.toString})") + + val selectorAsts = astsForExpression(stmt.getSelector, ExpectedType.empty) + val selectorNode = selectorAsts.head.root.get + + val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry) + + val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts) + + Ast(switchNode) + .withChildren(selectorAsts) + .withChild(switchBodyAst) + .withConditionEdge(switchNode, selectorNode) + + private def astForSynchronizedStatement(stmt: SynchronizedStmt): Ast = + val parentNode = + NewBlock() + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val modifier = Ast(newModifierNode("SYNCHRONIZED")) + + val exprAsts = astsForExpression(stmt.getExpression, ExpectedType.empty) + val bodyAst = astForBlockStatement(stmt.getBody) + + Ast(parentNode) + .withChild(modifier) + .withChildren(exprAsts) + .withChild(bodyAst) + + private def astsForSwitchCases(entry: SwitchEntry): Seq[Ast] = + entry.getLabels.asScala.toList match + case Nil => + val target = NewJumpTarget() + .name("default") + .code("default") + Seq(Ast(target)) + + case labels => + labels.flatMap { label => + val jumpTarget = NewJumpTarget() + .name("case") + .code(label.toString) + val labelAsts = astsForExpression(label, ExpectedType.empty).toList + + Ast(jumpTarget) :: labelAsts + } + + private def astForSwitchEntry(entry: SwitchEntry): Seq[Ast] = + val labelAsts = astsForSwitchCases(entry) + + val statementAsts = entry.getStatements.asScala.flatMap(astsForStatement) + + labelAsts ++ statementAsts + + private def astForAssertStatement(stmt: AssertStmt): Ast = + val callNode = NewCall() + .name("assert") + .methodFullName("assert") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(stmt.toString) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val args = astsForExpression(stmt.getCheck, 
ExpectedType.Boolean) + callAst(callNode, args) + + private def astForBlockStatement( + stmt: BlockStmt, + codeStr: String = "", + prefixAsts: Seq[Ast] = Seq.empty + ): Ast = + + val block = NewBlock() + .code(codeStr) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + scope.pushBlockScope() + + val stmtAsts = stmt.getStatements.asScala.flatMap(astsForStatement) + + scope.popScope() + Ast(block) + .withChildren(prefixAsts) + .withChildren(stmtAsts) + end astForBlockStatement + + private def astForReturnNode(ret: ReturnStmt): Ast = + val returnNode = NewReturn() + .lineNumber(line(ret)) + .columnNumber(column(ret)) + .code(ret.toString) + if ret.getExpression.isPresent then + val expectedType = scope.enclosingMethodReturnType.getOrElse(ExpectedType.empty) + val exprAsts = astsForExpression(ret.getExpression.get(), expectedType) + returnAst(returnNode, exprAsts) + else + Ast(returnNode) + + private def astForUnaryExpr(expr: UnaryExpr, expectedType: ExpectedType): Ast = + val operatorName = expr.getOperator match + case UnaryExpr.Operator.LOGICAL_COMPLEMENT => Operators.logicalNot + case UnaryExpr.Operator.POSTFIX_DECREMENT => Operators.postDecrement + case UnaryExpr.Operator.POSTFIX_INCREMENT => Operators.postIncrement + case UnaryExpr.Operator.PREFIX_DECREMENT => Operators.preDecrement + case UnaryExpr.Operator.PREFIX_INCREMENT => Operators.preIncrement + case UnaryExpr.Operator.BITWISE_COMPLEMENT => Operators.not + case UnaryExpr.Operator.PLUS => Operators.plus + case UnaryExpr.Operator.MINUS => Operators.minus + + val argsAsts = astsForExpression(expr.getExpression, expectedType) + + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(argsAsts.headOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + operatorName, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) ) - iterableAsts.head - } - - val 
iterableName = nextIterableName() - val iterableLocalNode = localNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) - val iterableLocalAst = Ast(iterableLocalNode) - - val iterableAssignNode = - newOperatorCallNode(Operators.assignment, code = "", line = lineNo, typeFullName = iterableType) - val iterableAssignIdentifier = - identifierNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) - val iterableAssignArgs = List(Ast(iterableAssignIdentifier), iterableAst) - val iterableAssignAst = - callAst(iterableAssignNode, iterableAssignArgs) - .withRefEdge(iterableAssignIdentifier, iterableLocalNode) - - ( - NodeTypeInfo(iterableLocalNode, iterableLocalNode.name, Some(iterableLocalNode.typeFullName)), - List(iterableLocalAst, iterableAssignAst) - ) - } - - private def nativeForEachIdxLocalNode(lineNo: Option[Integer]): NewLocal = { - val idxName = nextIndexName() - val typeFullName = TypeConstants.Int - val idxLocal = - NewLocal() - .name(idxName) - .typeFullName(typeFullName) - .code(idxName) - .lineNumber(lineNo) - scope.addLocal(idxLocal) - idxLocal - } - - private def nativeForEachIdxInitializerAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = { - val idxName = idxLocal.name - val idxInitializerCallNode = newOperatorCallNode( - Operators.assignment, - code = s"int $idxName = 0", - line = lineNo, - typeFullName = Some(TypeConstants.Int) - ) - val idxIdentifierArg = newIdentifierNode(idxName, idxLocal.typeFullName) - val zeroLiteral = - NewLiteral() - .code("0") - .typeFullName(TypeConstants.Int) - .lineNumber(lineNo) - val idxInitializerArgAsts = List(Ast(idxIdentifierArg), Ast(zeroLiteral)) - callAst(idxInitializerCallNode, idxInitializerArgAsts) - .withRefEdge(idxIdentifierArg, idxLocal) - } - - private def nativeForEachCompareAst( - lineNo: Option[Integer], - iterableSource: NodeTypeInfo, - idxLocal: NewLocal - ): Ast = { - val idxName = idxLocal.name - - val compareNode = newOperatorCallNode( - 
Operators.lessThan, - code = s"$idxName < ${iterableSource.name}.length", - typeFullName = Some(TypeConstants.Boolean), - line = lineNo - ) - val comparisonIdxIdentifier = newIdentifierNode(idxName, idxLocal.typeFullName) - val comparisonFieldAccess = newOperatorCallNode( - Operators.fieldAccess, - code = s"${iterableSource.name}.length", - typeFullName = Some(TypeConstants.Int), - line = lineNo - ) - val fieldAccessIdentifier = newIdentifierNode(iterableSource.name, iterableSource.typeFullName.getOrElse("ANY")) - val fieldAccessFieldIdentifier = newFieldIdentifierNode("length", lineNo) - val fieldAccessArgs = List(fieldAccessIdentifier, fieldAccessFieldIdentifier).map(Ast(_)) - val fieldAccessAst = callAst(comparisonFieldAccess, fieldAccessArgs) - val compareArgs = List(Ast(comparisonIdxIdentifier), fieldAccessAst) - - // TODO: This is a workaround for a crash when looping over statically imported members. Handle those properly. - val iterableSourceNode = localParamOrMemberFromNode(iterableSource) - - callAst(compareNode, compareArgs) - .withRefEdge(comparisonIdxIdentifier, idxLocal) - .withRefEdges(fieldAccessIdentifier, iterableSourceNode.toList) - } - - private def nativeForEachIncrementAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = { - val incrementNode = newOperatorCallNode( - Operators.postIncrement, - code = s"${idxLocal.name}++", - typeFullName = Some(TypeConstants.Int), - line = lineNo - ) - val incrementArg = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) - val incrementArgAst = Ast(incrementArg) - callAst(incrementNode, List(incrementArgAst)) - .withRefEdge(incrementArg, idxLocal) - } - - private def variableLocalForForEachBody(stmt: ForEachStmt): NewLocal = { - val lineNo = line(stmt) - // Create item local - val maybeVariable = stmt.getVariable.getVariables.asScala.toList match { - case Nil => - logger.debug(s"ForEach statement has empty variable list: $filename$lineNo") - None - case variable :: Nil => Some(variable) - case variable 
:: _ => - logger.debug(s"ForEach statement defines multiple variables. Dropping all but the first: $filename$lineNo") - Some(variable) - } - - val partialLocalNode = NewLocal().lineNumber(lineNo) - - maybeVariable match { - case Some(variable) => - val name = variable.getNameAsString - val typeFullName = typeInfoCalc.fullName(variable.getType).getOrElse("ANY") - val localNode = partialLocalNode - .name(name) - .code(variable.getNameAsString) - .typeFullName(typeFullName) - - scope.addLocal(localNode) - localNode - - case None => - // Returning partialLocalNode here is fine since getting to this case means everything is broken anyways :) - partialLocalNode - } - } - - private def localParamOrMemberFromNode(nodeTypeInfo: NodeTypeInfo): Option[NewNode] = { - nodeTypeInfo.node match { - case localNode: NewLocal => Some(localNode) - case memberNode: NewMember => Some(memberNode) - case parameterNode: NewMethodParameterIn => Some(parameterNode) - case _ => None - } - } - private def variableAssignForNativeForEachBody( - variableLocal: NewLocal, - idxLocal: NewLocal, - iterable: NodeTypeInfo - ): Ast = { - // Everything will be on the same line as the `for` statement, but this is the most useful - // solution for debugging. 
- val lineNo = variableLocal.lineNumber - val varAssignNode = - newOperatorCallNode(Operators.assignment, PropertyDefaults.Code, Some(variableLocal.typeFullName), lineNo) - - val targetNode = newIdentifierNode(variableLocal.name, variableLocal.typeFullName) - - val indexAccessTypeFullName = iterable.typeFullName.map(_.replaceAll(raw"\[]", "")) - val indexAccess = - newOperatorCallNode(Operators.indexAccess, PropertyDefaults.Code, indexAccessTypeFullName, lineNo) - - val indexAccessIdentifier = newIdentifierNode(iterable.name, iterable.typeFullName.getOrElse("ANY")) - val indexAccessIndex = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) - - val indexAccessArgsAsts = List(indexAccessIdentifier, indexAccessIndex).map(Ast(_)) - val indexAccessAst = callAst(indexAccess, indexAccessArgsAsts) - - val iterableSourceNode = localParamOrMemberFromNode(iterable) - - val assignArgsAsts = List(Ast(targetNode), indexAccessAst) - callAst(varAssignNode, assignArgsAsts) - .withRefEdge(targetNode, variableLocal) - .withRefEdges(indexAccessIdentifier, iterableSourceNode.toList) - .withRefEdge(indexAccessIndex, idxLocal) - } - - private def nativeForEachBodyAst(stmt: ForEachStmt, idxLocal: NewLocal, iterable: NodeTypeInfo): Ast = { - val variableLocal = variableLocalForForEachBody(stmt) - val variableLocalAst = Ast(variableLocal) - val variableAssignAst = variableAssignForNativeForEachBody(variableLocal, idxLocal, iterable) - - stmt.getBody match { - case block: BlockStmt => - astForBlockStatement(block, prefixAsts = List(variableLocalAst, variableAssignAst)) - case statement => - val stmtAsts = astsForStatement(statement) - val blockNode = NewBlock().lineNumber(variableLocal.lineNumber) - Ast(blockNode) - .withChild(variableLocalAst) - .withChild(variableAssignAst) - .withChildren(stmtAsts) - } - } - - private def astsForNativeForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = { - - // This is ugly, but for a case like `for (int x : new int[] { ... 
})` this creates a new LOCAL - // with the assignment `int[] $iterLocal0 = new int[] { ... }` before the FOR loop. - // TODO: Fix this - val (iterableSource: NodeTypeInfo, tempIterableInitAsts) = stmt.getIterable match { - case nameExpr: NameExpr => - scope.lookupVariable(nameExpr.getNameAsString).asNodeInfoOption match { - // If this is not the case, then the code is broken (iterable not in scope). - case Some(nodeTypeInfo) => (nodeTypeInfo, Nil) - case _ => iterableAssignAstsForNativeForEach(nameExpr, iterableType) + + callAst(callNode, argsAsts) + end astForUnaryExpr + + private def astForArrayAccessExpr(expr: ArrayAccessExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + val callNode = newOperatorCallNode( + Operators.indexAccess, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val arrayExpectedType = expectedType.copy(fullName = expectedType.fullName.map(_ ++ "[]")) + val nameAst = astsForExpression(expr.getName, arrayExpectedType) + val indexAst = astsForExpression(expr.getIndex, ExpectedType.Int) + val args = nameAst ++ indexAst + callAst(callNode, args) + + private def astForArrayCreationExpr(expr: ArrayCreationExpr, expectedType: ExpectedType): Ast = + val maybeInitializerAst = + expr.getInitializer.toScala.map(astForArrayInitializerExpr(_, expectedType)) + + maybeInitializerAst.flatMap(_.root) match + case Some(initializerRoot: NewCall) => initializerRoot.code(expr.toString) + case _ => // This should never happen + maybeInitializerAst.getOrElse { + val typeFullName = expressionReturnTypeFullName(expr).orElse( + expectedType.fullName + ).getOrElse(TypeConstants.Any) + val callNode = newOperatorCallNode( + Operators.alloc, + code = expr.toString, + typeFullName = Some(typeFullName) + ) + val levelAsts = expr.getLevels.asScala.flatMap { lvl => + lvl.getDimension.toScala match + 
case Some(dimension) => astsForExpression(dimension, ExpectedType.Int) + + case None => Seq.empty + }.toSeq + callAst(callNode, levelAsts) } - case iterableExpr => iterableAssignAstsForNativeForEach(iterableExpr, iterableType) - } - - val forNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.FOR) - - val lineNo = line(stmt) - - val idxLocal = nativeForEachIdxLocalNode(lineNo) - val idxInitializerAst = nativeForEachIdxInitializerAst(lineNo, idxLocal) - // TODO next: pass NodeTypeInfo around - val compareAst = nativeForEachCompareAst(lineNo, iterableSource, idxLocal) - val incrementAst = nativeForEachIncrementAst(lineNo, idxLocal) - val bodyAst = nativeForEachBodyAst(stmt, idxLocal, iterableSource) - - val forAst = Ast(forNode) - .withChild(Ast(idxLocal)) - .withChild(idxInitializerAst) - .withChild(compareAst) - .withChild(incrementAst) - .withChild(bodyAst) - .withConditionEdges(forNode, compareAst.root.toList) - - tempIterableInitAsts ++ Seq(forAst) - } - - private def iteratorLocalForForEach(lineNumber: Option[Integer]): NewLocal = { - val iteratorLocalName = nextIterableName() - NewLocal() - .name(iteratorLocalName) - .code(iteratorLocalName) - .typeFullName(TypeConstants.Iterator) - .lineNumber(lineNumber) - } - - private def iteratorAssignAstForForEach( - iterExpr: Expression, - iteratorLocalNode: NewLocal, - iterableType: Option[String], - lineNo: Option[Integer] - ): Ast = { - val iteratorAssignNode = - newOperatorCallNode(Operators.assignment, code = "", typeFullName = Some(TypeConstants.Iterator), line = lineNo) - val iteratorAssignIdentifier = - identifierNode(iterExpr, iteratorLocalNode.name, iteratorLocalNode.name, iteratorLocalNode.typeFullName) - - val iteratorCallNode = - newCallNode("iterator", iterableType, TypeConstants.Iterator, DispatchTypes.DYNAMIC_DISPATCH, lineNumber = lineNo) - - val actualIteratorAst = astsForExpression(iterExpr, expectedType = ExpectedType.empty).toList match { - case Nil => - logger.debug(s"Could 
not create receiver ast for iterator $iterExpr") - None - - case ast :: Nil => Some(ast) - - case ast :: _ => - logger.debug(s"Created multiple receiver asts for $iterExpr. Dropping all but the first.") - Some(ast) - } - - val iteratorCallAst = - callAst(iteratorCallNode, base = actualIteratorAst) - - callAst(iteratorAssignNode, List(Ast(iteratorAssignIdentifier), iteratorCallAst)) - .withRefEdge(iteratorAssignIdentifier, iteratorLocalNode) - } - - private def hasNextCallAstForForEach(iteratorLocalNode: NewLocal, lineNo: Option[Integer]): Ast = { - val iteratorHasNextCallNode = - newCallNode( - "hasNext", - Some(TypeConstants.Iterator), - TypeConstants.Boolean, - DispatchTypes.DYNAMIC_DISPATCH, - lineNumber = lineNo - ) - val iteratorHasNextCallReceiver = - newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) - - callAst(iteratorHasNextCallNode, base = Some(Ast(iteratorHasNextCallReceiver))) - .withRefEdge(iteratorHasNextCallReceiver, iteratorLocalNode) - } - - private def astForIterableForEachItemAssign(iteratorLocalNode: NewLocal, variableLocal: NewLocal): Ast = { - val lineNo = variableLocal.lineNumber - val forVariableType = variableLocal.typeFullName - val varLocalAssignNode = - newOperatorCallNode(Operators.assignment, PropertyDefaults.Code, Some(forVariableType), lineNo) - val varLocalAssignIdentifier = newIdentifierNode(variableLocal.name, variableLocal.typeFullName) - - val iterNextCallNode = - newCallNode( - "next", - Some(TypeConstants.Iterator), - TypeConstants.Object, - DispatchTypes.DYNAMIC_DISPATCH, - lineNumber = lineNo - ) - val iterNextCallReceiver = newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) - val iterNextCallAst = - callAst(iterNextCallNode, base = Some(Ast(iterNextCallReceiver))) - .withRefEdge(iterNextCallReceiver, iteratorLocalNode) - - callAst(varLocalAssignNode, List(Ast(varLocalAssignIdentifier), iterNextCallAst)) - .withRefEdge(varLocalAssignIdentifier, variableLocal) - } - - private def 
astForIterableForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = { - val lineNo = line(stmt) - - val iteratorLocalNode = iteratorLocalForForEach(lineNo) - val iteratorAssignAst = - iteratorAssignAstForForEach(stmt.getIterable, iteratorLocalNode, iterableType, lineNo) - val iteratorHasNextCallAst = hasNextCallAstForForEach(iteratorLocalNode, lineNo) - val variableLocal = variableLocalForForEachBody(stmt) - val variableAssignAst = astForIterableForEachItemAssign(iteratorLocalNode, variableLocal) - - val bodyPrefixAsts = Seq(Ast(variableLocal), variableAssignAst) - val bodyAst = stmt.getBody match { - case block: BlockStmt => - astForBlockStatement(block, prefixAsts = bodyPrefixAsts) - - case bodyStmt => - val bodyBlockNode = NewBlock().lineNumber(lineNo) - val bodyStmtAsts = astsForStatement(bodyStmt) - Ast(bodyBlockNode) - .withChildren(bodyPrefixAsts) - .withChildren(bodyStmtAsts) - } - - val forNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.WHILE) - .code(ControlStructureTypes.FOR) - .lineNumber(lineNo) - .columnNumber(column(stmt)) - - val forAst = controlStructureAst(forNode, Some(iteratorHasNextCallAst), List(bodyAst)) - - Seq(Ast(iteratorLocalNode), iteratorAssignAst, forAst) - } - - private def astForForEach(stmt: ForEachStmt): Seq[Ast] = { - scope.pushBlockScope() - - val ast = expressionReturnTypeFullName(stmt.getIterable) match { - case Some(typeFullName) if typeFullName.endsWith("[]") => - astsForNativeForEach(stmt, Some(typeFullName)) - - case maybeType => - astForIterableForEach(stmt, maybeType) - } - - scope.popScope() - ast - } - - private def astForSwitchStatement(stmt: SwitchStmt): Ast = { - val switchNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.SWITCH) - .code(s"switch(${stmt.getSelector.toString})") - - val selectorAsts = astsForExpression(stmt.getSelector, ExpectedType.empty) - val selectorNode = selectorAsts.head.root.get - - val entryAsts = 
stmt.getEntries.asScala.flatMap(astForSwitchEntry) - - val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts) - - Ast(switchNode) - .withChildren(selectorAsts) - .withChild(switchBodyAst) - .withConditionEdge(switchNode, selectorNode) - } - - private def astForSynchronizedStatement(stmt: SynchronizedStmt): Ast = { - val parentNode = - NewBlock() - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - val modifier = Ast(newModifierNode("SYNCHRONIZED")) - - val exprAsts = astsForExpression(stmt.getExpression, ExpectedType.empty) - val bodyAst = astForBlockStatement(stmt.getBody) - - Ast(parentNode) - .withChild(modifier) - .withChildren(exprAsts) - .withChild(bodyAst) - } - - private def astsForSwitchCases(entry: SwitchEntry): Seq[Ast] = { - entry.getLabels.asScala.toList match { - case Nil => - val target = NewJumpTarget() - .name("default") - .code("default") - Seq(Ast(target)) - - case labels => - labels.flatMap { label => - val jumpTarget = NewJumpTarget() - .name("case") - .code(label.toString) - val labelAsts = astsForExpression(label, ExpectedType.empty).toList - - Ast(jumpTarget) :: labelAsts + end astForArrayCreationExpr + + private def astForArrayInitializerExpr( + expr: ArrayInitializerExpr, + expectedType: ExpectedType + ): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + val callNode = newOperatorCallNode( + Operators.arrayInitializer, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val MAX_INITIALIZERS = 1000 + + val expectedValueType = expr.getValues.asScala.headOption.map { value => + // typeName and resolvedType may represent different types since typeName can fall + // back to known information or primitive types. While this certainly isn't ideal, + // it shouldn't cause issues since resolvedType is only used where the extra type + // information not available in typeName is necessary. 
+ val typeName = expressionReturnTypeFullName(value) + val resolvedType = tryWithSafeStackOverflow(value.calculateResolvedType()).toOption + ExpectedType(typeName, resolvedType) } - } - } - - private def astForSwitchEntry(entry: SwitchEntry): Seq[Ast] = { - val labelAsts = astsForSwitchCases(entry) - - val statementAsts = entry.getStatements.asScala.flatMap(astsForStatement) - - labelAsts ++ statementAsts - } - - private def astForAssertStatement(stmt: AssertStmt): Ast = { - val callNode = NewCall() - .name("assert") - .methodFullName("assert") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(stmt.toString) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - val args = astsForExpression(stmt.getCheck, ExpectedType.Boolean) - callAst(callNode, args) - } - - private def astForBlockStatement( - stmt: BlockStmt, - codeStr: String = "", - prefixAsts: Seq[Ast] = Seq.empty - ): Ast = { - - val block = NewBlock() - .code(codeStr) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - scope.pushBlockScope() - - val stmtAsts = stmt.getStatements.asScala.flatMap(astsForStatement) - - scope.popScope() - Ast(block) - .withChildren(prefixAsts) - .withChildren(stmtAsts) - } - - private def astForReturnNode(ret: ReturnStmt): Ast = { - val returnNode = NewReturn() - .lineNumber(line(ret)) - .columnNumber(column(ret)) - .code(ret.toString) - if (ret.getExpression.isPresent) { - val expectedType = scope.enclosingMethodReturnType.getOrElse(ExpectedType.empty) - val exprAsts = astsForExpression(ret.getExpression.get(), expectedType) - returnAst(returnNode, exprAsts) - } else { - Ast(returnNode) - } - } - - private def astForUnaryExpr(expr: UnaryExpr, expectedType: ExpectedType): Ast = { - val operatorName = expr.getOperator match { - case UnaryExpr.Operator.LOGICAL_COMPLEMENT => Operators.logicalNot - case UnaryExpr.Operator.POSTFIX_DECREMENT => Operators.postDecrement - case UnaryExpr.Operator.POSTFIX_INCREMENT => Operators.postIncrement - case 
UnaryExpr.Operator.PREFIX_DECREMENT => Operators.preDecrement - case UnaryExpr.Operator.PREFIX_INCREMENT => Operators.preIncrement - case UnaryExpr.Operator.BITWISE_COMPLEMENT => Operators.not - case UnaryExpr.Operator.PLUS => Operators.plus - case UnaryExpr.Operator.MINUS => Operators.minus - } - - val argsAsts = astsForExpression(expr.getExpression, expectedType) - - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(argsAsts.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = newOperatorCallNode( - operatorName, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - callAst(callNode, argsAsts) - } - - private def astForArrayAccessExpr(expr: ArrayAccessExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - val callNode = newOperatorCallNode( - Operators.indexAccess, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - val arrayExpectedType = expectedType.copy(fullName = expectedType.fullName.map(_ ++ "[]")) - val nameAst = astsForExpression(expr.getName, arrayExpectedType) - val indexAst = astsForExpression(expr.getIndex, ExpectedType.Int) - val args = nameAst ++ indexAst - callAst(callNode, args) - } - - private def astForArrayCreationExpr(expr: ArrayCreationExpr, expectedType: ExpectedType): Ast = { - val maybeInitializerAst = expr.getInitializer.toScala.map(astForArrayInitializerExpr(_, expectedType)) - - maybeInitializerAst.flatMap(_.root) match { - case Some(initializerRoot: NewCall) => initializerRoot.code(expr.toString) - case _ => // This should never happen - } - - maybeInitializerAst.getOrElse { - val typeFullName = expressionReturnTypeFullName(expr).orElse(expectedType.fullName).getOrElse(TypeConstants.Any) - val callNode = 
newOperatorCallNode(Operators.alloc, code = expr.toString, typeFullName = Some(typeFullName)) - val levelAsts = expr.getLevels.asScala.flatMap { lvl => - lvl.getDimension.toScala match { - case Some(dimension) => astsForExpression(dimension, ExpectedType.Int) - - case None => Seq.empty + val args = expr.getValues.asScala + .slice(0, MAX_INITIALIZERS) + .flatMap(astsForExpression(_, expectedValueType.getOrElse(ExpectedType.empty))) + .toSeq + + val ast = callAst(callNode, args) + + if expr.getValues.size() > MAX_INITIALIZERS then + val placeholder = NewLiteral() + .typeFullName(TypeConstants.Any) + .code("") + .lineNumber(line(expr)) + .columnNumber(column(expr)) + ast.withChild(Ast(placeholder)).withArgEdge(callNode, placeholder) + else + ast + end astForArrayInitializerExpr + + def astForBinaryExpr(expr: BinaryExpr, expectedType: ExpectedType): Ast = + val operatorName = expr.getOperator match + case BinaryExpr.Operator.OR => Operators.logicalOr + case BinaryExpr.Operator.AND => Operators.logicalAnd + case BinaryExpr.Operator.BINARY_OR => Operators.or + case BinaryExpr.Operator.BINARY_AND => Operators.and + case BinaryExpr.Operator.DIVIDE => Operators.division + case BinaryExpr.Operator.EQUALS => Operators.equals + case BinaryExpr.Operator.GREATER => Operators.greaterThan + case BinaryExpr.Operator.GREATER_EQUALS => Operators.greaterEqualsThan + case BinaryExpr.Operator.LESS => Operators.lessThan + case BinaryExpr.Operator.LESS_EQUALS => Operators.lessEqualsThan + case BinaryExpr.Operator.LEFT_SHIFT => Operators.shiftLeft + case BinaryExpr.Operator.SIGNED_RIGHT_SHIFT => Operators.arithmeticShiftRight /* Java `>>` sign-extends */ + case BinaryExpr.Operator.UNSIGNED_RIGHT_SHIFT => Operators.logicalShiftRight /* Java `>>>` zero-fills */ + case BinaryExpr.Operator.XOR => Operators.xor + case BinaryExpr.Operator.NOT_EQUALS => Operators.notEquals + case BinaryExpr.Operator.PLUS => Operators.addition + case BinaryExpr.Operator.MINUS => Operators.subtraction + case BinaryExpr.Operator.MULTIPLY => Operators.multiplication + case 
BinaryExpr.Operator.REMAINDER => Operators.modulo + + val args = + astsForExpression(expr.getLeft, expectedType) ++ astsForExpression( + expr.getRight, + expectedType + ) + + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(args.headOption.flatMap(_.rootType)) + .orElse(args.lastOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + operatorName, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + callAst(callNode, args) + end astForBinaryExpr + + private def astForCastExpr(expr: CastExpr, expectedType: ExpectedType): Ast = + val typeFullName = + typeInfoCalc + .fullName(expr.getType) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + Operators.cast, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val typeNode = NewTypeRef() + .code(expr.getType.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + .typeFullName(typeFullName) + val typeAst = Ast(typeNode) + + val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) + + callAst(callNode, Seq(typeAst) ++ exprAst) + end astForCastExpr + + private def astsForAssignExpr(expr: AssignExpr, expectedExprType: ExpectedType): Seq[Ast] = + val operatorName = expr.getOperator match + case Operator.ASSIGN => Operators.assignment + case Operator.PLUS => Operators.assignmentPlus + case Operator.MINUS => Operators.assignmentMinus + case Operator.MULTIPLY => Operators.assignmentMultiplication + case Operator.DIVIDE => Operators.assignmentDivision + case Operator.BINARY_AND => Operators.assignmentAnd + case Operator.BINARY_OR => Operators.assignmentOr + case Operator.XOR => Operators.assignmentXor + case Operator.REMAINDER => Operators.assignmentModulo + case Operator.LEFT_SHIFT => Operators.assignmentShiftLeft + case 
Operator.SIGNED_RIGHT_SHIFT => Operators.assignmentArithmeticShiftRight + case Operator.UNSIGNED_RIGHT_SHIFT => Operators.assignmentLogicalShiftRight + + val maybeResolvedType = Try(expr.getTarget.calculateResolvedType()).toOption + val expectedType = maybeResolvedType + .map { resolvedType => + ExpectedType(typeInfoCalc.fullName(resolvedType), Some(resolvedType)) + } + .getOrElse(expectedExprType) // resolved target type should be more accurate + val targetAst = astsForExpression(expr.getTarget, expectedType) + val argsAsts = astsForExpression(expr.getValue, expectedType) + val valueType = argsAsts.headOption.flatMap(_.rootType) + + val typeFullName = + targetAst.headOption + .flatMap(_.rootType) + .orElse(valueType) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val code = + s"${targetAst.rootCodeOrEmpty} ${expr.getOperator.asString} ${argsAsts.rootCodeOrEmpty}" + + val callNode = + newOperatorCallNode(operatorName, code, Some(typeFullName), line(expr), column(expr)) + + if partialConstructorQueue.isEmpty then + val assignAst = callAst(callNode, targetAst ++ argsAsts) + Seq(assignAst) + else + if partialConstructorQueue.size > 1 then + logger.debug( + "BUG: Received multiple partial constructors from assignment. Dropping all but the first." + ) + val partialConstructor = partialConstructorQueue.head + partialConstructorQueue.clear() + + targetAst.flatMap(_.root).toList match + case List(identifier: NewIdentifier) => + // In this case we have a simple assign. No block needed. + // e.g. Foo f = new Foo(); + val initAst = + completeInitForConstructor(partialConstructor, Ast(identifier.copy)) + Seq(callAst(callNode, targetAst ++ argsAsts), initAst) + + case _ => + // In this case the left hand side is more complex than an identifier, so + // we need to contain the constructor in a block. + // e.g. 
items[10] = new Foo(); + val valueAst = partialConstructor.blockAst + Seq(callAst(callNode, targetAst ++ Seq(valueAst))) + end if + end astsForAssignExpr + + private def localsForVarDecl(varDecl: VariableDeclarationExpr): List[NewLocal] = + varDecl.getVariables.asScala.map { variable => + val name = variable.getName.toString + val typeFullName = + tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten + .orElse(scope.lookupType(variable.getTypeAsString)) + .getOrElse(TypeConstants.Any) + val code = s"${variable.getType} $name" + NewLocal() + .name(name) + .code(code) + .typeFullName(typeFullName) + .lineNumber(line(varDecl)) + .columnNumber(column(varDecl)) + }.toList + + private def copyAstForVarDeclInit(targetAst: Ast): Ast = + targetAst.root match + case Some(identifier: NewIdentifier) => Ast(identifier.copy) + + case Some(fieldAccess: NewCall) if fieldAccess.name == Operators.fieldAccess => + val maybeIdentifier = targetAst.nodes.collectFirst { + case node if node.isInstanceOf[NewIdentifier] => node + } + val maybeField = targetAst.nodes.collectFirst { + case node if node.isInstanceOf[NewFieldIdentifier] => node + } + + (maybeIdentifier, maybeField) match + case (Some(identifier), Some(fieldIdentifier)) => + val args = List(identifier, fieldIdentifier).map(node => Ast(node.copy)) + callAst(fieldAccess.copy, args) + + case _ => + logger.debug( + s"Attempting to copy field access without required children: ${fieldAccess.code}" + ) + Ast() + + case Some(root) => + logger.debug(s"Attempting to copy unhandled root type for var decl init: $root") + Ast() + + case None => + Ast() + + private def assignmentsForVarDecl( + variables: Iterable[VariableDeclarator], + lineNumber: Option[Integer], + columnNumber: Option[Integer] + ): Seq[Ast] = + val variablesWithInitializers = + variables.filter(_.getInitializer.toScala.isDefined) + val assignments = variablesWithInitializers.flatMap { variable => + val name = variable.getName.toString + val 
initializer = variable.getInitializer.toScala.get // Won't crash because of filter + val initializerTypeFullName = + variable.getInitializer.toScala.flatMap(expressionReturnTypeFullName) + val javaParserVarType = variable.getTypeAsString + val variableTypeFullName = + tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten + // TODO: Surely the variable being declared can't already be in scope? + .orElse(scope.lookupVariable(name).typeFullName) + .orElse(scope.lookupType(javaParserVarType)) + + val typeFullName = + variableTypeFullName.orElse(initializerTypeFullName) + + // Need the actual resolvedType here for when the RHS is a lambda expression. + val resolvedExpectedType = + tryWithSafeStackOverflow(symbolSolver.toResolvedType( + variable.getType, + classOf[ResolvedType] + )).toOption + val initializerAsts = + astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType)) + + val typeName = typeFullName + .map(TypeNodePass.fullToShortName) + .getOrElse(s"${Defines.UnresolvedNamespace}.${variable.getTypeAsString}") + val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}" + + val callNode = newOperatorCallNode( + Operators.assignment, + code, + typeFullName, + lineNumber, + columnNumber + ) + + val targetAst = scope.lookupVariable(name).getVariable() match + // TODO: This definitely feels like a bug. Why is the found member not being used for anything? + case Some(ScopeMember(_, false)) => + val thisType = scope.enclosingTypeDeclFullName + fieldAccessAst( + NameConstants.This, + thisType, + name, + typeFullName, + line(variable), + column(variable) + ) + + case maybeCorrespNode => + val identifier = identifierNode( + variable, + name, + name, + typeFullName.getOrElse(TypeConstants.Any) + ) + Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) + + // Since all partial constructors will be dealt with here, don't pass them up. 
+ val declAst = callAst(callNode, Seq(targetAst) ++ initializerAsts) + + val constructorAsts = partialConstructorQueue.map(completeInitForConstructor( + _, + copyAstForVarDeclInit(targetAst) + )) + partialConstructorQueue.clear() + + Seq(declAst) ++ constructorAsts } - }.toSeq - callAst(callNode, levelAsts) - } - } - - private def astForArrayInitializerExpr(expr: ArrayInitializerExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - val callNode = newOperatorCallNode( - Operators.arrayInitializer, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - val MAX_INITIALIZERS = 1000 - - val expectedValueType = expr.getValues.asScala.headOption.map { value => - // typeName and resolvedType may represent different types since typeName can fall - // back to known information or primitive types. While this certainly isn't ideal, - // it shouldn't cause issues since resolvedType is only used where the extra type - // information not available in typeName is necessary. 
- val typeName = expressionReturnTypeFullName(value) - val resolvedType = tryWithSafeStackOverflow(value.calculateResolvedType()).toOption - ExpectedType(typeName, resolvedType) - } - val args = expr.getValues.asScala - .slice(0, MAX_INITIALIZERS) - .flatMap(astsForExpression(_, expectedValueType.getOrElse(ExpectedType.empty))) - .toSeq - - val ast = callAst(callNode, args) - - if (expr.getValues.size() > MAX_INITIALIZERS) { - val placeholder = NewLiteral() - .typeFullName(TypeConstants.Any) - .code("") - .lineNumber(line(expr)) - .columnNumber(column(expr)) - ast.withChild(Ast(placeholder)).withArgEdge(callNode, placeholder) - } else { - ast - } - } - - def astForBinaryExpr(expr: BinaryExpr, expectedType: ExpectedType): Ast = { - val operatorName = expr.getOperator match { - case BinaryExpr.Operator.OR => Operators.logicalOr - case BinaryExpr.Operator.AND => Operators.logicalAnd - case BinaryExpr.Operator.BINARY_OR => Operators.or - case BinaryExpr.Operator.BINARY_AND => Operators.and - case BinaryExpr.Operator.DIVIDE => Operators.division - case BinaryExpr.Operator.EQUALS => Operators.equals - case BinaryExpr.Operator.GREATER => Operators.greaterThan - case BinaryExpr.Operator.GREATER_EQUALS => Operators.greaterEqualsThan - case BinaryExpr.Operator.LESS => Operators.lessThan - case BinaryExpr.Operator.LESS_EQUALS => Operators.lessEqualsThan - case BinaryExpr.Operator.LEFT_SHIFT => Operators.shiftLeft - case BinaryExpr.Operator.SIGNED_RIGHT_SHIFT => Operators.logicalShiftRight - case BinaryExpr.Operator.UNSIGNED_RIGHT_SHIFT => Operators.arithmeticShiftRight - case BinaryExpr.Operator.XOR => Operators.xor - case BinaryExpr.Operator.NOT_EQUALS => Operators.notEquals - case BinaryExpr.Operator.PLUS => Operators.addition - case BinaryExpr.Operator.MINUS => Operators.subtraction - case BinaryExpr.Operator.MULTIPLY => Operators.multiplication - case BinaryExpr.Operator.REMAINDER => Operators.modulo - } - - val args = - astsForExpression(expr.getLeft, expectedType) ++ 
astsForExpression(expr.getRight, expectedType) - - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(args.headOption.flatMap(_.rootType)) - .orElse(args.lastOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = newOperatorCallNode( - operatorName, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - callAst(callNode, args) - } - - private def astForCastExpr(expr: CastExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - typeInfoCalc - .fullName(expr.getType) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = newOperatorCallNode( - Operators.cast, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - val typeNode = NewTypeRef() - .code(expr.getType.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr)) - .typeFullName(typeFullName) - val typeAst = Ast(typeNode) - - val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) - - callAst(callNode, Seq(typeAst) ++ exprAst) - } - - private def astsForAssignExpr(expr: AssignExpr, expectedExprType: ExpectedType): Seq[Ast] = { - val operatorName = expr.getOperator match { - case Operator.ASSIGN => Operators.assignment - case Operator.PLUS => Operators.assignmentPlus - case Operator.MINUS => Operators.assignmentMinus - case Operator.MULTIPLY => Operators.assignmentMultiplication - case Operator.DIVIDE => Operators.assignmentDivision - case Operator.BINARY_AND => Operators.assignmentAnd - case Operator.BINARY_OR => Operators.assignmentOr - case Operator.XOR => Operators.assignmentXor - case Operator.REMAINDER => Operators.assignmentModulo - case Operator.LEFT_SHIFT => Operators.assignmentShiftLeft - case Operator.SIGNED_RIGHT_SHIFT => Operators.assignmentArithmeticShiftRight - case Operator.UNSIGNED_RIGHT_SHIFT => Operators.assignmentLogicalShiftRight - } - - val 
maybeResolvedType = Try(expr.getTarget.calculateResolvedType()).toOption - val expectedType = maybeResolvedType - .map { resolvedType => - ExpectedType(typeInfoCalc.fullName(resolvedType), Some(resolvedType)) - } - .getOrElse(expectedExprType) // resolved target type should be more accurate - val targetAst = astsForExpression(expr.getTarget, expectedType) - val argsAsts = astsForExpression(expr.getValue, expectedType) - val valueType = argsAsts.headOption.flatMap(_.rootType) - - val typeFullName = - targetAst.headOption - .flatMap(_.rootType) - .orElse(valueType) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val code = s"${targetAst.rootCodeOrEmpty} ${expr.getOperator.asString} ${argsAsts.rootCodeOrEmpty}" - - val callNode = newOperatorCallNode(operatorName, code, Some(typeFullName), line(expr), column(expr)) - - if (partialConstructorQueue.isEmpty) { - val assignAst = callAst(callNode, targetAst ++ argsAsts) - Seq(assignAst) - } else { - if (partialConstructorQueue.size > 1) { - logger.debug("BUG: Received multiple partial constructors from assignment. Dropping all but the first.") - } - val partialConstructor = partialConstructorQueue.head - partialConstructorQueue.clear() - - targetAst.flatMap(_.root).toList match { - case List(identifier: NewIdentifier) => - // In this case we have a simple assign. No block needed. - // e.g. Foo f = new Foo(); - val initAst = completeInitForConstructor(partialConstructor, Ast(identifier.copy)) - Seq(callAst(callNode, targetAst ++ argsAsts), initAst) - - case _ => - // In this case the left hand side is more complex than an identifier, so - // we need to contain the constructor in a block. - // e.g. 
items[10] = new Foo(); - val valueAst = partialConstructor.blockAst - Seq(callAst(callNode, targetAst ++ Seq(valueAst))) - } - } - } - - private def localsForVarDecl(varDecl: VariableDeclarationExpr): List[NewLocal] = { - varDecl.getVariables.asScala.map { variable => - val name = variable.getName.toString - val typeFullName = - tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten - .orElse(scope.lookupType(variable.getTypeAsString)) - .getOrElse(TypeConstants.Any) - val code = s"${variable.getType} $name" - NewLocal() - .name(name) - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(varDecl)) - .columnNumber(column(varDecl)) - }.toList - } - - private def copyAstForVarDeclInit(targetAst: Ast): Ast = { - targetAst.root match { - case Some(identifier: NewIdentifier) => Ast(identifier.copy) - - case Some(fieldAccess: NewCall) if fieldAccess.name == Operators.fieldAccess => - val maybeIdentifier = targetAst.nodes.collectFirst { case node if node.isInstanceOf[NewIdentifier] => node } - val maybeField = targetAst.nodes.collectFirst { case node if node.isInstanceOf[NewFieldIdentifier] => node } - - (maybeIdentifier, maybeField) match { - case (Some(identifier), Some(fieldIdentifier)) => - val args = List(identifier, fieldIdentifier).map(node => Ast(node.copy)) - callAst(fieldAccess.copy, args) - - case _ => - logger.debug(s"Attempting to copy field access without required children: ${fieldAccess.code}") - Ast() + + assignments.toList + end assignmentsForVarDecl + + private def completeInitForConstructor( + partialConstructor: PartialConstructor, + targetAst: Ast + ): Ast = + val initNode = partialConstructor.initNode + val args = partialConstructor.initArgs + + targetAst.root match + case Some(identifier: NewIdentifier) => + scope.lookupVariable(identifier.name).variableNode.foreach { variableNode => + diffGraph.addEdge(identifier, variableNode, EdgeTypes.REF) + } + + case _ => // Nothing to do in this case + callAst(initNode, 
args.toList, Some(targetAst)) + + private def astsForVariableDecl(varDecl: VariableDeclarationExpr): Seq[Ast] = + val locals = localsForVarDecl(varDecl) + val localAsts = locals.map { Ast(_) } + + locals.foreach { local => + scope.addLocal(local) } - case Some(root) => - logger.debug(s"Attempting to copy unhandled root type for var decl init: $root") - Ast() - - case None => - Ast() - } - } - - private def assignmentsForVarDecl( - variables: Iterable[VariableDeclarator], - lineNumber: Option[Integer], - columnNumber: Option[Integer] - ): Seq[Ast] = { - val variablesWithInitializers = - variables.filter(_.getInitializer.toScala.isDefined) - val assignments = variablesWithInitializers.flatMap { variable => - val name = variable.getName.toString - val initializer = variable.getInitializer.toScala.get // Won't crash because of filter - val initializerTypeFullName = variable.getInitializer.toScala.flatMap(expressionReturnTypeFullName) - val javaParserVarType = variable.getTypeAsString - val variableTypeFullName = - tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten - // TODO: Surely the variable being declared can't already be in scope? - .orElse(scope.lookupVariable(name).typeFullName) - .orElse(scope.lookupType(javaParserVarType)) - - val typeFullName = - variableTypeFullName.orElse(initializerTypeFullName) - - // Need the actual resolvedType here for when the RHS is a lambda expression. 
- val resolvedExpectedType = - tryWithSafeStackOverflow(symbolSolver.toResolvedType(variable.getType, classOf[ResolvedType])).toOption - val initializerAsts = astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType)) - - val typeName = typeFullName - .map(TypeNodePass.fullToShortName) - .getOrElse(s"${Defines.UnresolvedNamespace}.${variable.getTypeAsString}") - val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}" - - val callNode = newOperatorCallNode(Operators.assignment, code, typeFullName, lineNumber, columnNumber) - - val targetAst = scope.lookupVariable(name).getVariable() match { - // TODO: This definitely feels like a bug. Why is the found member not being used for anything? - case Some(ScopeMember(_, false)) => - val thisType = scope.enclosingTypeDeclFullName - fieldAccessAst(NameConstants.This, thisType, name, typeFullName, line(variable), column(variable)) - - case maybeCorrespNode => - val identifier = identifierNode(variable, name, name, typeFullName.getOrElse(TypeConstants.Any)) - Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) - } - - // Since all partial constructors will be dealt with here, don't pass them up. 
- val declAst = callAst(callNode, Seq(targetAst) ++ initializerAsts) - - val constructorAsts = partialConstructorQueue.map(completeInitForConstructor(_, copyAstForVarDeclInit(targetAst))) - partialConstructorQueue.clear() - - Seq(declAst) ++ constructorAsts - } - - assignments.toList - } - - private def completeInitForConstructor(partialConstructor: PartialConstructor, targetAst: Ast): Ast = { - val initNode = partialConstructor.initNode - val args = partialConstructor.initArgs - - targetAst.root match { - case Some(identifier: NewIdentifier) => - scope.lookupVariable(identifier.name).variableNode.foreach { variableNode => - diffGraph.addEdge(identifier, variableNode, EdgeTypes.REF) + val assignments = + assignmentsForVarDecl(varDecl.getVariables.asScala, line(varDecl), column(varDecl)) + + localAsts ++ assignments + + private def astForClassExpr(expr: ClassExpr): Ast = + val someTypeFullName = Some(TypeConstants.Class) + val callNode = newOperatorCallNode( + Operators.fieldAccess, + expr.toString, + someTypeFullName, + line(expr), + column(expr) + ) + + val identifierType = typeInfoCalc.fullName(expr.getType) + val identifier = identifierNode( + expr, + expr.getTypeAsString, + expr.getTypeAsString, + identifierType.getOrElse("ANY") + ) + val idAst = Ast(identifier) + + val fieldIdentifier = NewFieldIdentifier() + .canonicalName("class") + .code("class") + .lineNumber(line(expr)) + .columnNumber(column(expr)) + val fieldIdAst = Ast(fieldIdentifier) + + callAst(callNode, Seq(idAst, fieldIdAst)) + end astForClassExpr + + private def astForConditionalExpr(expr: ConditionalExpr, expectedType: ExpectedType): Ast = + val condAst = astsForExpression(expr.getCondition, ExpectedType.Boolean) + val thenAst = astsForExpression(expr.getThenExpr, expectedType) + val elseAst = astsForExpression(expr.getElseExpr, expectedType) + + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(thenAst.headOption.flatMap(_.rootType)) + 
.orElse(elseAst.headOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = + newOperatorCallNode( + Operators.conditional, + expr.toString, + Some(typeFullName), + line(expr), + column(expr) + ) + + callAst(callNode, condAst ++ thenAst ++ elseAst) + end astForConditionalExpr + + private def astForEnclosedExpression(expr: EnclosedExpr, expectedType: ExpectedType): Seq[Ast] = + astsForExpression(expr.getInner, expectedType) + + private def astForFieldAccessExpr(expr: FieldAccessExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = + newOperatorCallNode( + Operators.fieldAccess, + expr.toString, + Some(typeFullName), + line(expr), + column(expr) + ) + + val fieldIdentifier = expr.getName + val identifierAsts = astsForExpression(expr.getScope, ExpectedType.empty) + val fieldIdentifierNode = NewFieldIdentifier() + .canonicalName(fieldIdentifier.toString) + .lineNumber(line(fieldIdentifier)) + .columnNumber(column(fieldIdentifier)) + .code(fieldIdentifier.toString) + val fieldIdAst = Ast(fieldIdentifierNode) + + callAst(callNode, identifierAsts ++ Seq(fieldIdAst)) + end astForFieldAccessExpr + + private def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = + val booleanTypeFullName = Some(TypeConstants.Boolean) + val callNode = + newOperatorCallNode( + Operators.instanceOf, + expr.toString, + booleanTypeFullName, + line(expr), + column(expr) + ) + + val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) + val typeFullName = typeInfoCalc.fullName(expr.getType).getOrElse(TypeConstants.Any) + val typeNode = + NewTypeRef() + .code(expr.getType.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr.getType)) + .typeFullName(typeFullName) + val typeAst = Ast(typeNode) + + callAst(callNode, exprAst ++ Seq(typeAst)) + end astForInstanceOfExpr + + private def 
fieldAccessAst( + identifierName: String, + identifierType: Option[String], + fieldIdentifierName: String, + returnType: Option[String], + lineNo: Option[Integer], + columnNo: Option[Integer] + ): Ast = + val typeFullName = + identifierType.orElse(Some(TypeConstants.Any)).map(typeInfoCalc.registerType) + val identifier = newIdentifierNode(identifierName, typeFullName.getOrElse("ANY")) + val maybeCorrespNode = scope.lookupVariable(identifierName).variableNode + + val fieldIdentifier = NewFieldIdentifier() + .code(fieldIdentifierName) + .canonicalName(fieldIdentifierName) + .lineNumber(lineNo) + .columnNumber(columnNo) + + val fieldAccessCode = s"$identifierName.$fieldIdentifierName" + val fieldAccess = + newOperatorCallNode( + Operators.fieldAccess, + fieldAccessCode, + returnType.orElse(Some(TypeConstants.Any)), + lineNo, + columnNo + ) + + val identifierAst = Ast(identifier) + val fieldIdentAst = Ast(fieldIdentifier) + + callAst(fieldAccess, Seq(identifierAst, fieldIdentAst)) + .withRefEdges(identifier, maybeCorrespNode.toList) + end fieldAccessAst + + private def astForNameExpr(nameExpr: NameExpr, expectedType: ExpectedType): Ast = + val name = nameExpr.getName.toString + val typeFullName = expressionReturnTypeFullName(nameExpr) + .orElse(expectedType.fullName) + .map(typeInfoCalc.registerType) + + tryWithSafeStackOverflow(nameExpr.resolve()) match + case Success(value) if value.isField => + val identifierName = if value.asField.isStatic then + // A static field represented by a NameExpr must belong to the class in which it's used. Static fields + // from other classes are represented by a FieldAccessExpr instead. + scope.enclosingTypeDecl.map(_.name).getOrElse( + s"${Defines.UnresolvedNamespace}.$name" + ) + else + NameConstants.This + + val identifierTypeFullName = + value match + case fieldDecl: ResolvedFieldDeclaration => + // TODO It is not quite correct to use the declaring classes type. 
+ // Instead we should take the using classes type which is either the same or a + // sub class of the declaring class. + typeInfoCalc.fullName(fieldDecl.declaringType()) + + fieldAccessAst( + identifierName, + identifierTypeFullName, + name, + typeFullName, + line(nameExpr), + column(nameExpr) + ) + + case _ => + val identifier = + identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any)) + + val variableOption = scope + .lookupVariable(name) + .variableNode + .collect { + case parameter: NewMethodParameterIn => parameter + + case local: NewLocal => local + } + + variableOption.foldLeft(Ast(identifier))((ast, variableNode) => + ast.withRefEdge(identifier, variableNode) + ) + end match + end astForNameExpr + + private def argumentTypesForMethodLike( + maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration] + ): Option[List[String]] = + maybeResolvedMethodLike.toOption + .flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty())) + + private def initNode( + namespaceName: Option[String], + argumentTypes: Option[List[String]], + argsSize: Int, + code: String, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): NewCall = + val initSignature = argumentTypes match + case Some(tpe) => composeMethodLikeSignature(TypeConstants.Void, tpe) + case _ if argsSize == 0 => composeMethodLikeSignature(TypeConstants.Void, Nil) + case _ => composeUnresolvedSignature(argsSize) + val namespace = namespaceName.getOrElse(Defines.UnresolvedNamespace) + val initMethodFullName = + composeMethodFullName(namespace, Defines.ConstructorMethodName, initSignature) + NewCall() + .name(Defines.ConstructorMethodName) + .methodFullName(initMethodFullName) + .signature(initSignature) + .typeFullName(TypeConstants.Void) + .code(code) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + end initNode + + /** The below representation for constructor invocations and object creations was chosen 
for the + * sake of consistency with the Java frontend. It follows the bytecode approach of splitting a + * constructor call into separate `alloc` and `init` calls. + * + * There are two cases to consider. The first is a constructor invocation in an assignment, for + * example: + * + * Foo f = new Foo(42); + * + * is represented as + * + * Foo f = .alloc() f.init(42); + * + * The second case is a constructor invocation not in an assignment, for example as an argument + * to a method call. In this case, the representation does not stay as close to Java as in case + * 1. In particular, a new BLOCK is introduced to contain the constructor invocation. For + * example: + * + * foo(new Foo(42)); + * + * is represented as + * + * foo({ Foo temp = alloc(); temp.init(42); temp }) + * + * This is not valid Java code, but this representation is a decent compromise between staying + * faithful to Java and being consistent with the Java bytecode frontend. + */ + private def astForObjectCreationExpr( + expr: ObjectCreationExpr, + expectedType: ExpectedType + ): Ast = + val maybeResolvedExpr = tryWithSafeStackOverflow(expr.resolve()) + val argumentAsts = argAstsForCall(expr, maybeResolvedExpr, expr.getArguments) + + val typeFullName = + tryWithSafeStackOverflow(typeInfoCalc.fullName(expr.getType)).toOption.flatten + .orElse(scope.lookupType(expr.getTypeAsString)) + .orElse(expectedType.fullName) + + val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr) + + val allocNode = newOperatorCallNode( + Operators.alloc, + expr.toString, + typeFullName.orElse(Some(TypeConstants.Any)), + line(expr), + column(expr) + ) + + val initCall = initNode( + typeFullName.orElse(Some(TypeConstants.Any)), + argumentTypes, + argumentAsts.size, + expr.toString, + line(expr) + ) + + // Assume that a block ast is required, since there isn't enough information to decide otherwise. + // This simplifies logic elsewhere, and unnecessary blocks will be garbage collected soon. 
+ val blockAst = blockAstForConstructorInvocation( + line(expr), + column(expr), + allocNode, + initCall, + argumentAsts + ) + + expr.getParentNode.toScala match + case Some(parent) + if parent.isInstanceOf[VariableDeclarator] || parent.isInstanceOf[AssignExpr] => + val partialConstructor = PartialConstructor(initCall, argumentAsts, blockAst) + partialConstructorQueue.append(partialConstructor) + Ast(allocNode) + + case _ => + blockAst + end astForObjectCreationExpr + + private var tempConstCount = 0 + private def blockAstForConstructorInvocation( + lineNumber: Option[Integer], + columnNumber: Option[Integer], + allocNode: NewCall, + initNode: NewCall, + args: Seq[Ast] + ): Ast = + val blockNode = NewBlock() + .lineNumber(lineNumber) + .columnNumber(columnNumber) + .typeFullName(allocNode.typeFullName) + + val tempName = "$obj" ++ tempConstCount.toString + tempConstCount += 1 + val identifier = newIdentifierNode(tempName, allocNode.typeFullName) + val identifierAst = Ast(identifier) + + val allocAst = Ast(allocNode) + + val assignmentNode = newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(allocNode.typeFullName) + ) + + val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst)) + + val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order) + val identifierForInit = identifierWithDefaultOrder.copy + val initWithDefaultOrder = initNode.order(PropertyDefaults.Order) + val initAst = callAst(initWithDefaultOrder, args, Some(Ast(identifierForInit))) + + val returnAst = Ast(identifierWithDefaultOrder.copy) + + Ast(blockNode) + .withChild(assignmentAst) + .withChild(initAst) + .withChild(returnAst) + end blockAstForConstructorInvocation + + private def astForThisExpr(expr: ThisExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) + + val identifier = + identifierNode(expr, expr.toString, expr.toString, typeFullName.getOrElse("ANY")) 
+ val thisParam = scope.lookupVariable(NameConstants.This).variableNode + + thisParam.foreach { thisNode => + diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF) } - case _ => // Nothing to do in this case - } - - callAst(initNode, args.toList, Some(targetAst)) - } - - private def astsForVariableDecl(varDecl: VariableDeclarationExpr): Seq[Ast] = { - val locals = localsForVarDecl(varDecl) - val localAsts = locals.map { Ast(_) } - - locals.foreach { local => - scope.addLocal(local) - } - - val assignments = - assignmentsForVarDecl(varDecl.getVariables.asScala, line(varDecl), column(varDecl)) - - localAsts ++ assignments - } - - private def astForClassExpr(expr: ClassExpr): Ast = { - val someTypeFullName = Some(TypeConstants.Class) - val callNode = newOperatorCallNode(Operators.fieldAccess, expr.toString, someTypeFullName, line(expr), column(expr)) - - val identifierType = typeInfoCalc.fullName(expr.getType) - val identifier = identifierNode(expr, expr.getTypeAsString, expr.getTypeAsString, identifierType.getOrElse("ANY")) - val idAst = Ast(identifier) - - val fieldIdentifier = NewFieldIdentifier() - .canonicalName("class") - .code("class") - .lineNumber(line(expr)) - .columnNumber(column(expr)) - val fieldIdAst = Ast(fieldIdentifier) - - callAst(callNode, Seq(idAst, fieldIdAst)) - } - - private def astForConditionalExpr(expr: ConditionalExpr, expectedType: ExpectedType): Ast = { - val condAst = astsForExpression(expr.getCondition, ExpectedType.Boolean) - val thenAst = astsForExpression(expr.getThenExpr, expectedType) - val elseAst = astsForExpression(expr.getElseExpr, expectedType) - - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(thenAst.headOption.flatMap(_.rootType)) - .orElse(elseAst.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = - newOperatorCallNode(Operators.conditional, expr.toString, Some(typeFullName), line(expr), column(expr)) - - callAst(callNode, condAst ++ 
thenAst ++ elseAst) - } - - private def astForEnclosedExpression(expr: EnclosedExpr, expectedType: ExpectedType): Seq[Ast] = { - astsForExpression(expr.getInner, expectedType) - } - - private def astForFieldAccessExpr(expr: FieldAccessExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = - newOperatorCallNode(Operators.fieldAccess, expr.toString, Some(typeFullName), line(expr), column(expr)) - - val fieldIdentifier = expr.getName - val identifierAsts = astsForExpression(expr.getScope, ExpectedType.empty) - val fieldIdentifierNode = NewFieldIdentifier() - .canonicalName(fieldIdentifier.toString) - .lineNumber(line(fieldIdentifier)) - .columnNumber(column(fieldIdentifier)) - .code(fieldIdentifier.toString) - val fieldIdAst = Ast(fieldIdentifierNode) - - callAst(callNode, identifierAsts ++ Seq(fieldIdAst)) - } - - private def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = { - val booleanTypeFullName = Some(TypeConstants.Boolean) - val callNode = - newOperatorCallNode(Operators.instanceOf, expr.toString, booleanTypeFullName, line(expr), column(expr)) - - val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) - val typeFullName = typeInfoCalc.fullName(expr.getType).getOrElse(TypeConstants.Any) - val typeNode = - NewTypeRef() - .code(expr.getType.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr.getType)) - .typeFullName(typeFullName) - val typeAst = Ast(typeNode) - - callAst(callNode, exprAst ++ Seq(typeAst)) - } - - private def fieldAccessAst( - identifierName: String, - identifierType: Option[String], - fieldIdentifierName: String, - returnType: Option[String], - lineNo: Option[Integer], - columnNo: Option[Integer] - ): Ast = { - val typeFullName = identifierType.orElse(Some(TypeConstants.Any)).map(typeInfoCalc.registerType) - val identifier = newIdentifierNode(identifierName, 
typeFullName.getOrElse("ANY")) - val maybeCorrespNode = scope.lookupVariable(identifierName).variableNode - - val fieldIdentifier = NewFieldIdentifier() - .code(fieldIdentifierName) - .canonicalName(fieldIdentifierName) - .lineNumber(lineNo) - .columnNumber(columnNo) - - val fieldAccessCode = s"$identifierName.$fieldIdentifierName" - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - fieldAccessCode, - returnType.orElse(Some(TypeConstants.Any)), - lineNo, - columnNo - ) - - val identifierAst = Ast(identifier) - val fieldIdentAst = Ast(fieldIdentifier) - - callAst(fieldAccess, Seq(identifierAst, fieldIdentAst)) - .withRefEdges(identifier, maybeCorrespNode.toList) - } - - private def astForNameExpr(nameExpr: NameExpr, expectedType: ExpectedType): Ast = { - val name = nameExpr.getName.toString - val typeFullName = expressionReturnTypeFullName(nameExpr) - .orElse(expectedType.fullName) - .map(typeInfoCalc.registerType) - - tryWithSafeStackOverflow(nameExpr.resolve()) match { - case Success(value) if value.isField => - val identifierName = if (value.asField.isStatic) { - // A static field represented by a NameExpr must belong to the class in which it's used. Static fields - // from other classes are represented by a FieldAccessExpr instead. 
- scope.enclosingTypeDecl.map(_.name).getOrElse(s"${Defines.UnresolvedNamespace}.$name") - } else { - NameConstants.This + Ast(identifier) + + private def astForExplicitConstructorInvocation(stmt: ExplicitConstructorInvocationStmt): Ast = + val maybeResolved = tryWithSafeStackOverflow(stmt.resolve()) + val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments) + val argTypes = argumentTypesForMethodLike(maybeResolved) + + val typeFullName = maybeResolved.toOption + .map(_.declaringType()) + .flatMap(typeInfoCalc.fullName) + + val callRoot = initNode( + typeFullName.orElse(Some(TypeConstants.Any)), + argTypes, + args.size, + stmt.toString, + line(stmt), + column(stmt) + ) + + val thisNode = + newIdentifierNode(NameConstants.This, typeFullName.getOrElse(TypeConstants.Any)) + scope.lookupVariable(NameConstants.This).variableNode.foreach { thisParam => + diffGraph.addEdge(thisNode, thisParam, EdgeTypes.REF) } + val thisAst = Ast(thisNode) + + callAst(callRoot, args, Some(thisAst)) + end astForExplicitConstructorInvocation + + private def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = + // TODO: Implement missing handlers + // case _: MethodReferenceExpr => Seq() + // case _: PatternExpr => Seq() + // case _: SuperExpr => Seq() + // case _: SwitchExpr => Seq() + // case _: TypeExpr => Seq() + expression match + case _: AnnotationExpr => Seq() + case x: ArrayAccessExpr => Seq(astForArrayAccessExpr(x, expectedType)) + case x: ArrayCreationExpr => Seq(astForArrayCreationExpr(x, expectedType)) + case x: ArrayInitializerExpr => Seq(astForArrayInitializerExpr(x, expectedType)) + case x: AssignExpr => astsForAssignExpr(x, expectedType) + case x: BinaryExpr => Seq(astForBinaryExpr(x, expectedType)) + case x: CastExpr => Seq(astForCastExpr(x, expectedType)) + case x: ClassExpr => Seq(astForClassExpr(x)) + case x: ConditionalExpr => Seq(astForConditionalExpr(x, expectedType)) + case x: EnclosedExpr => astForEnclosedExpression(x, 
expectedType) + case x: FieldAccessExpr => Seq(astForFieldAccessExpr(x, expectedType)) + case x: InstanceOfExpr => Seq(astForInstanceOfExpr(x)) + case x: LambdaExpr => Seq(astForLambdaExpr(x, expectedType)) + case x: LiteralExpr => Seq(astForLiteralExpr(x)) + case x: MethodCallExpr => Seq(astForMethodCall(x, expectedType)) + case x: NameExpr => Seq(astForNameExpr(x, expectedType)) + case x: ObjectCreationExpr => Seq(astForObjectCreationExpr(x, expectedType)) + case x: SuperExpr => Seq(astForSuperExpr(x, expectedType)) + case x: ThisExpr => Seq(astForThisExpr(x, expectedType)) + case x: UnaryExpr => Seq(astForUnaryExpr(x, expectedType)) + case x: VariableDeclarationExpr => astsForVariableDecl(x) + case x => Seq(unknownAst(x)) + + private def unknownAst(node: Node): Ast = Ast(unknownNode(node, node.toString)) + + private def someWithDotSuffix(prefix: String): Option[String] = Some(s"$prefix.") + + private def codeForScopeExpr( + scopeExpr: Expression, + isScopeForStaticCall: Boolean + ): Option[String] = + scopeExpr match + case scope: NameExpr => someWithDotSuffix(scope.getNameAsString) + + case fieldAccess: FieldAccessExpr => + val maybeScopeString = + codeForScopeExpr(fieldAccess.getScope, isScopeForStaticCall = false) + val name = fieldAccess.getNameAsString + maybeScopeString + .map { scopeString => + s"$scopeString$name" + } + .orElse(Some(name)) + .flatMap(someWithDotSuffix) + + case _: SuperExpr => someWithDotSuffix(NameConstants.Super) + + case _: ThisExpr => someWithDotSuffix(NameConstants.This) + + case scopeMethodCall: MethodCallExpr => + codePrefixForMethodCall(scopeMethodCall) match + case "" => Some("") + case prefix => + val argumentsCode = getArgumentCodeString(scopeMethodCall.getArguments) + someWithDotSuffix( + s"$prefix${scopeMethodCall.getNameAsString}($argumentsCode)" + ) + + case objectCreationExpr: ObjectCreationExpr => + val typeName = objectCreationExpr.getTypeAsString + val argumentsString = 
getArgumentCodeString(objectCreationExpr.getArguments) + someWithDotSuffix(s"new $typeName($argumentsString)") + + case _ => None + + private def codePrefixForMethodCall(call: MethodCallExpr): String = + tryWithSafeStackOverflow(call.resolve()) match + case Success(resolvedCall) => + call.getScope.toScala + .flatMap(codeForScopeExpr(_, resolvedCall.isStatic)) + .getOrElse(if resolvedCall.isStatic then "" else s"${NameConstants.This}.") + + case _ => + // If the call is unresolvable, we cannot make a good guess about what the prefix should be + "" + + private def createObjectNode( + typeFullName: Option[String], + call: MethodCallExpr, + dispatchType: String + ): Option[NewIdentifier] = + val maybeScope = call.getScope.toScala + + Option.when(maybeScope.isDefined || dispatchType == DispatchTypes.DYNAMIC_DISPATCH) { + val name = maybeScope.map(_.toString).getOrElse(NameConstants.This) + identifierNode(call, name, name, typeFullName.getOrElse("ANY")) + } + + private def nextLambdaName(): String = + s"$LambdaNamePrefix${lambdaKeyPool.next}" + + private def nextIndexName(): String = + s"$IndexNamePrefix${indexKeyPool.next}" + + private def nextIterableName(): String = + s"$IterableNamePrefix${iterableKeyPool.next}" + + private def genericParamTypeMapForLambda(expectedType: ExpectedType) + : ResolvedTypeParametersMap = + expectedType.resolvedType + // This should always be true for correct code + .collect { case r: ResolvedReferenceType => r } + .map(_.typeParametersMap()) + .getOrElse(new ResolvedTypeParametersMap.Builder().build()) + + private def buildParamListForLambda( + expr: LambdaExpr, + maybeBoundMethod: Option[ResolvedMethodDeclaration], + expectedTypeParamTypes: ResolvedTypeParametersMap + ): Seq[Ast] = + val lambdaParameters = expr.getParameters.asScala.toList + val paramTypesList = maybeBoundMethod match + case Some(resolvedMethod) => + val resolvedParameters = + (0 until resolvedMethod.getNumberOfParams).map(resolvedMethod.getParam) + + // Substitute 
generic typeParam with the expected type if it can be found; leave unchanged otherwise. + resolvedParameters.map(param => Try(param.getType)).map { + case Success(resolvedType: ResolvedTypeVariable) => + val typ = expectedTypeParamTypes.getValue(resolvedType.asTypeParameter) + typeInfoCalc.fullName(typ) + + case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) + + case Failure(_) => None + } + + case None => + // Unless types are explicitly specified in the lambda definition, + // this will yield the erased types which is why the actual lambda + // expression parameters are only used as a fallback. + lambdaParameters + .map(_.getType) + .map(typeInfoCalc.fullName) + + if paramTypesList.sizeIs != lambdaParameters.size then + logger.debug( + s"Found different number lambda params and param types for $expr. Some parameters will be missing." + ) + + val parameterNodes = lambdaParameters + .zip(paramTypesList) + .zipWithIndex + .map { case ((param, maybeType), idx) => + val name = param.getNameAsString + val typeFullName = maybeType.getOrElse(TypeConstants.Any) + val code = s"$typeFullName $name" + val evalStrat = + if param.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE + else EvaluationStrategies.BY_SHARING + val paramNode = NewMethodParameterIn() + .name(name) + .index(idx + 1) + .order(idx + 1) + .code(code) + .evaluationStrategy(evalStrat) + .typeFullName(typeFullName) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + typeInfoCalc.registerType(typeFullName) + paramNode + } - val identifierTypeFullName = - value match { - case fieldDecl: ResolvedFieldDeclaration => - // TODO It is not quite correct to use the declaring classes type. - // Instead we should take the using classes type which is either the same or a - // sub class of the declaring class. 
- typeInfoCalc.fullName(fieldDecl.declaringType()) - } - - fieldAccessAst(identifierName, identifierTypeFullName, name, typeFullName, line(nameExpr), column(nameExpr)) - - case _ => - val identifier = identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any)) - - val variableOption = scope - .lookupVariable(name) - .variableNode - .collect { - case parameter: NewMethodParameterIn => parameter - - case local: NewLocal => local - } - - variableOption.foldLeft(Ast(identifier))((ast, variableNode) => ast.withRefEdge(identifier, variableNode)) - } - - } - - private def argumentTypesForMethodLike( - maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration] - ): Option[List[String]] = { - maybeResolvedMethodLike.toOption - .flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty())) - } - - private def initNode( - namespaceName: Option[String], - argumentTypes: Option[List[String]], - argsSize: Int, - code: String, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): NewCall = { - val initSignature = argumentTypes match { - case Some(tpe) => composeMethodLikeSignature(TypeConstants.Void, tpe) - case _ if argsSize == 0 => composeMethodLikeSignature(TypeConstants.Void, Nil) - case _ => composeUnresolvedSignature(argsSize) - } - val namespace = namespaceName.getOrElse(Defines.UnresolvedNamespace) - val initMethodFullName = composeMethodFullName(namespace, Defines.ConstructorMethodName, initSignature) - NewCall() - .name(Defines.ConstructorMethodName) - .methodFullName(initMethodFullName) - .signature(initSignature) - .typeFullName(TypeConstants.Void) - .code(code) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - } - - /** The below representation for constructor invocations and object creations was chosen for the sake of consistency - * with the Java frontend. 
It follows the bytecode approach of splitting a constructor call into separate `alloc` and - * `init` calls. - * - * There are two cases to consider. The first is a constructor invocation in an assignment, for example: - * - * Foo f = new Foo(42); - * - * is represented as - * - * Foo f = .alloc() f.init(42); - * - * The second case is a constructor invocation not in an assignment, for example as an argument to a method call. In - * this case, the representation does not stay as close to Java as in case - * 1. In particular, a new BLOCK is introduced to contain the constructor invocation. For example: - * - * foo(new Foo(42)); - * - * is represented as - * - * foo({ Foo temp = alloc(); temp.init(42); temp }) - * - * This is not valid Java code, but this representation is a decent compromise between staying faithful to Java and - * being consistent with the Java bytecode frontend. - */ - private def astForObjectCreationExpr(expr: ObjectCreationExpr, expectedType: ExpectedType): Ast = { - val maybeResolvedExpr = tryWithSafeStackOverflow(expr.resolve()) - val argumentAsts = argAstsForCall(expr, maybeResolvedExpr, expr.getArguments) - - val typeFullName = tryWithSafeStackOverflow(typeInfoCalc.fullName(expr.getType)).toOption.flatten - .orElse(scope.lookupType(expr.getTypeAsString)) - .orElse(expectedType.fullName) - - val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr) - - val allocNode = newOperatorCallNode( - Operators.alloc, - expr.toString, - typeFullName.orElse(Some(TypeConstants.Any)), - line(expr), - column(expr) - ) - - val initCall = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), - argumentTypes, - argumentAsts.size, - expr.toString, - line(expr) - ) - - // Assume that a block ast is required, since there isn't enough information to decide otherwise. - // This simplifies logic elsewhere, and unnecessary blocks will be garbage collected soon. 
- val blockAst = blockAstForConstructorInvocation(line(expr), column(expr), allocNode, initCall, argumentAsts) - - expr.getParentNode.toScala match { - case Some(parent) if parent.isInstanceOf[VariableDeclarator] || parent.isInstanceOf[AssignExpr] => - val partialConstructor = PartialConstructor(initCall, argumentAsts, blockAst) - partialConstructorQueue.append(partialConstructor) - Ast(allocNode) - - case _ => - blockAst - } - } - - private var tempConstCount = 0 - private def blockAstForConstructorInvocation( - lineNumber: Option[Integer], - columnNumber: Option[Integer], - allocNode: NewCall, - initNode: NewCall, - args: Seq[Ast] - ): Ast = { - val blockNode = NewBlock() - .lineNumber(lineNumber) - .columnNumber(columnNumber) - .typeFullName(allocNode.typeFullName) - - val tempName = "$obj" ++ tempConstCount.toString - tempConstCount += 1 - val identifier = newIdentifierNode(tempName, allocNode.typeFullName) - val identifierAst = Ast(identifier) - - val allocAst = Ast(allocNode) - - val assignmentNode = newOperatorCallNode(Operators.assignment, PropertyDefaults.Code, Some(allocNode.typeFullName)) - - val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst)) - - val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order) - val identifierForInit = identifierWithDefaultOrder.copy - val initWithDefaultOrder = initNode.order(PropertyDefaults.Order) - val initAst = callAst(initWithDefaultOrder, args, Some(Ast(identifierForInit))) - - val returnAst = Ast(identifierWithDefaultOrder.copy) - - Ast(blockNode) - .withChild(assignmentAst) - .withChild(initAst) - .withChild(returnAst) - } - - private def astForThisExpr(expr: ThisExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - - val identifier = identifierNode(expr, expr.toString, expr.toString, typeFullName.getOrElse("ANY")) - val thisParam = scope.lookupVariable(NameConstants.This).variableNode - - 
thisParam.foreach { thisNode => - diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF) - } - - Ast(identifier) - } - - private def astForExplicitConstructorInvocation(stmt: ExplicitConstructorInvocationStmt): Ast = { - val maybeResolved = tryWithSafeStackOverflow(stmt.resolve()) - val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments) - val argTypes = argumentTypesForMethodLike(maybeResolved) - - val typeFullName = maybeResolved.toOption - .map(_.declaringType()) - .flatMap(typeInfoCalc.fullName) - - val callRoot = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), - argTypes, - args.size, - stmt.toString, - line(stmt), - column(stmt) - ) - - val thisNode = newIdentifierNode(NameConstants.This, typeFullName.getOrElse(TypeConstants.Any)) - scope.lookupVariable(NameConstants.This).variableNode.foreach { thisParam => - diffGraph.addEdge(thisNode, thisParam, EdgeTypes.REF) - } - val thisAst = Ast(thisNode) - - callAst(callRoot, args, Some(thisAst)) - } - - private def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = { - // TODO: Implement missing handlers - // case _: MethodReferenceExpr => Seq() - // case _: PatternExpr => Seq() - // case _: SuperExpr => Seq() - // case _: SwitchExpr => Seq() - // case _: TypeExpr => Seq() - expression match { - case _: AnnotationExpr => Seq() - case x: ArrayAccessExpr => Seq(astForArrayAccessExpr(x, expectedType)) - case x: ArrayCreationExpr => Seq(astForArrayCreationExpr(x, expectedType)) - case x: ArrayInitializerExpr => Seq(astForArrayInitializerExpr(x, expectedType)) - case x: AssignExpr => astsForAssignExpr(x, expectedType) - case x: BinaryExpr => Seq(astForBinaryExpr(x, expectedType)) - case x: CastExpr => Seq(astForCastExpr(x, expectedType)) - case x: ClassExpr => Seq(astForClassExpr(x)) - case x: ConditionalExpr => Seq(astForConditionalExpr(x, expectedType)) - case x: EnclosedExpr => astForEnclosedExpression(x, expectedType) - case x: FieldAccessExpr => 
Seq(astForFieldAccessExpr(x, expectedType)) - case x: InstanceOfExpr => Seq(astForInstanceOfExpr(x)) - case x: LambdaExpr => Seq(astForLambdaExpr(x, expectedType)) - case x: LiteralExpr => Seq(astForLiteralExpr(x)) - case x: MethodCallExpr => Seq(astForMethodCall(x, expectedType)) - case x: NameExpr => Seq(astForNameExpr(x, expectedType)) - case x: ObjectCreationExpr => Seq(astForObjectCreationExpr(x, expectedType)) - case x: SuperExpr => Seq(astForSuperExpr(x, expectedType)) - case x: ThisExpr => Seq(astForThisExpr(x, expectedType)) - case x: UnaryExpr => Seq(astForUnaryExpr(x, expectedType)) - case x: VariableDeclarationExpr => astsForVariableDecl(x) - case x => Seq(unknownAst(x)) - } - } - - private def unknownAst(node: Node): Ast = Ast(unknownNode(node, node.toString)) - - private def someWithDotSuffix(prefix: String): Option[String] = Some(s"$prefix.") - - private def codeForScopeExpr(scopeExpr: Expression, isScopeForStaticCall: Boolean): Option[String] = { - scopeExpr match { - case scope: NameExpr => someWithDotSuffix(scope.getNameAsString) - - case fieldAccess: FieldAccessExpr => - val maybeScopeString = codeForScopeExpr(fieldAccess.getScope, isScopeForStaticCall = false) - val name = fieldAccess.getNameAsString - maybeScopeString - .map { scopeString => - s"$scopeString$name" - } - .orElse(Some(name)) - .flatMap(someWithDotSuffix) - - case _: SuperExpr => someWithDotSuffix(NameConstants.Super) - - case _: ThisExpr => someWithDotSuffix(NameConstants.This) - - case scopeMethodCall: MethodCallExpr => - codePrefixForMethodCall(scopeMethodCall) match { - case "" => Some("") - case prefix => - val argumentsCode = getArgumentCodeString(scopeMethodCall.getArguments) - someWithDotSuffix(s"$prefix${scopeMethodCall.getNameAsString}($argumentsCode)") + parameterNodes.foreach { paramNode => + scope.addParameter(paramNode) } - case objectCreationExpr: ObjectCreationExpr => - val typeName = objectCreationExpr.getTypeAsString - val argumentsString = 
getArgumentCodeString(objectCreationExpr.getArguments) - someWithDotSuffix(s"new $typeName($argumentsString)") - - case _ => None - } - } - - private def codePrefixForMethodCall(call: MethodCallExpr): String = { - tryWithSafeStackOverflow(call.resolve()) match { - case Success(resolvedCall) => - call.getScope.toScala - .flatMap(codeForScopeExpr(_, resolvedCall.isStatic)) - .getOrElse(if (resolvedCall.isStatic) "" else s"${NameConstants.This}.") - - case _ => - // If the call is unresolvable, we cannot make a good guess about what the prefix should be - "" - } - } - - private def createObjectNode( - typeFullName: Option[String], - call: MethodCallExpr, - dispatchType: String - ): Option[NewIdentifier] = { - val maybeScope = call.getScope.toScala - - Option.when(maybeScope.isDefined || dispatchType == DispatchTypes.DYNAMIC_DISPATCH) { - val name = maybeScope.map(_.toString).getOrElse(NameConstants.This) - identifierNode(call, name, name, typeFullName.getOrElse("ANY")) - } - } - - private def nextLambdaName(): String = { - s"$LambdaNamePrefix${lambdaKeyPool.next}" - } - - private def nextIndexName(): String = { - s"$IndexNamePrefix${indexKeyPool.next}" - } - - private def nextIterableName(): String = { - s"$IterableNamePrefix${iterableKeyPool.next}" - } - - private def genericParamTypeMapForLambda(expectedType: ExpectedType): ResolvedTypeParametersMap = { - expectedType.resolvedType - // This should always be true for correct code - .collect { case r: ResolvedReferenceType => r } - .map(_.typeParametersMap()) - .getOrElse(new ResolvedTypeParametersMap.Builder().build()) - } - - private def buildParamListForLambda( - expr: LambdaExpr, - maybeBoundMethod: Option[ResolvedMethodDeclaration], - expectedTypeParamTypes: ResolvedTypeParametersMap - ): Seq[Ast] = { - val lambdaParameters = expr.getParameters.asScala.toList - val paramTypesList = maybeBoundMethod match { - case Some(resolvedMethod) => - val resolvedParameters = (0 until 
resolvedMethod.getNumberOfParams).map(resolvedMethod.getParam) - - // Substitute generic typeParam with the expected type if it can be found; leave unchanged otherwise. - resolvedParameters.map(param => Try(param.getType)).map { - case Success(resolvedType: ResolvedTypeVariable) => - val typ = expectedTypeParamTypes.getValue(resolvedType.asTypeParameter) - typeInfoCalc.fullName(typ) - - case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) - - case Failure(_) => None + parameterNodes.map(Ast(_)) + end buildParamListForLambda + + private def getLambdaReturnType( + maybeResolvedLambdaType: Option[ResolvedType], + maybeBoundMethod: Option[ResolvedMethodDeclaration], + expectedTypeParamTypes: ResolvedTypeParametersMap + ): Option[String] = + val maybeBoundMethodReturnType = maybeBoundMethod.flatMap { boundMethod => + Try(boundMethod.getReturnType).collect { + case returnType: ResolvedTypeVariable => + expectedTypeParamTypes.getValue(returnType.asTypeParameter) + case other => other + }.toOption } - case None => - // Unless types are explicitly specified in the lambda definition, - // this will yield the erased types which is why the actual lambda - // expression parameters are only used as a fallback. - lambdaParameters - .map(_.getType) - .map(typeInfoCalc.fullName) - } - - if (paramTypesList.sizeIs != lambdaParameters.size) { - logger.debug(s"Found different number lambda params and param types for $expr. 
Some parameters will be missing.") - } - - val parameterNodes = lambdaParameters - .zip(paramTypesList) - .zipWithIndex - .map { case ((param, maybeType), idx) => - val name = param.getNameAsString - val typeFullName = maybeType.getOrElse(TypeConstants.Any) - val code = s"$typeFullName $name" - val evalStrat = - if (param.getType.isPrimitiveType) EvaluationStrategies.BY_VALUE else EvaluationStrategies.BY_SHARING - val paramNode = NewMethodParameterIn() - .name(name) - .index(idx + 1) - .order(idx + 1) - .code(code) - .evaluationStrategy(evalStrat) - .typeFullName(typeFullName) - .lineNumber(line(expr)) - .columnNumber(column(expr)) - typeInfoCalc.registerType(typeFullName) - paramNode - } - - parameterNodes.foreach { paramNode => - scope.addParameter(paramNode) - } - - parameterNodes.map(Ast(_)) - } - - private def getLambdaReturnType( - maybeResolvedLambdaType: Option[ResolvedType], - maybeBoundMethod: Option[ResolvedMethodDeclaration], - expectedTypeParamTypes: ResolvedTypeParametersMap - ): Option[String] = { - val maybeBoundMethodReturnType = maybeBoundMethod.flatMap { boundMethod => - Try(boundMethod.getReturnType).collect { - case returnType: ResolvedTypeVariable => expectedTypeParamTypes.getValue(returnType.asTypeParameter) - case other => other - }.toOption - } - - val returnType = maybeBoundMethodReturnType.orElse(maybeResolvedLambdaType) - returnType.flatMap(typeInfoCalc.fullName) - } - - private def closureBindingsForCapturedNodes(lambdaMethodName: String): List[ClosureBindingEntry] = { - scope.capturedVariables.map { capturedNode => - val closureBindingId = s"$filename:$lambdaMethodName:${capturedNode.name}" - val closureBindingNode = - newClosureBindingNode(closureBindingId, capturedNode.name, EvaluationStrategies.BY_SHARING) - passes.ClosureBindingEntry(capturedNode, closureBindingNode) - } - } - - private def localsForCapturedNodes(closureBindingEntries: List[ClosureBindingEntry]): List[NewLocal] = { - val localsForCaptured = - 
closureBindingEntries.map { case ClosureBindingEntry(node, binding) => - val local = NewLocal() - .name(node.name) - .code(node.name) - .closureBindingId(binding.closureBindingId) - .typeFullName(node.typeFullName) - local - } - localsForCaptured.foreach { local => scope.addLocal(local) } - localsForCaptured - } - - private def astForLambdaBody( - body: Statement, - localsForCapturedVars: Seq[NewLocal], - returnType: Option[String] - ): Ast = { - body match { - case block: BlockStmt => astForBlockStatement(block, prefixAsts = localsForCapturedVars.map(Ast(_))) - - case stmt => - val blockAst = Ast(NewBlock().lineNumber(line(body))) - val bodyAst = if (returnType.contains(TypeConstants.Void)) { - astsForStatement(stmt) - } else { - val returnNode = - NewReturn() - .code(s"return ${body.toString}") - .lineNumber(line(body)) - val returnArgs = astsForStatement(stmt) - Seq(returnAst(returnNode, returnArgs)) + val returnType = maybeBoundMethodReturnType.orElse(maybeResolvedLambdaType) + returnType.flatMap(typeInfoCalc.fullName) + + private def closureBindingsForCapturedNodes(lambdaMethodName: String) + : List[ClosureBindingEntry] = + scope.capturedVariables.map { capturedNode => + val closureBindingId = s"$filename:$lambdaMethodName:${capturedNode.name}" + val closureBindingNode = + newClosureBindingNode( + closureBindingId, + capturedNode.name, + EvaluationStrategies.BY_SHARING + ) + passes.ClosureBindingEntry(capturedNode, closureBindingNode) } - blockAst - .withChildren(localsForCapturedVars.map(Ast(_))) - .withChildren(bodyAst) - } - } - - private def lambdaMethodSignature(returnType: Option[String], parameters: Seq[Ast]): String = { - val maybeParameterTypes = toOptionList(parameters.map(_.rootType)) - val containsEmptyType = maybeParameterTypes.exists(_.contains(ParameterDefaults.TypeFullName)) - - (returnType, maybeParameterTypes) match { - case (Some(returnTpe), Some(parameterTpes)) if !containsEmptyType => - composeMethodLikeSignature(returnTpe, parameterTpes) 
- - case _ => composeUnresolvedSignature(parameters.size) - } - } - - private def createLambdaMethodNode( - lambdaName: String, - parameters: Seq[Ast], - returnType: Option[String] - ): NewMethod = { - val enclosingTypeName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val signature = lambdaMethodSignature(returnType, parameters) - val lambdaFullName = composeMethodFullName(enclosingTypeName, lambdaName, signature) - - NewMethod() - .name(lambdaName) - .fullName(lambdaFullName) - .signature(signature) - .filename(filename) - .code("") - } - - private def addClosureBindingsToDiffGraph( - bindingEntries: Iterable[ClosureBindingEntry], - methodRef: NewMethodRef - ): Unit = { - bindingEntries.foreach { case ClosureBindingEntry(nodeTypeInfo, closureBinding) => - diffGraph.addNode(closureBinding) - diffGraph.addEdge(closureBinding, nodeTypeInfo.node, EdgeTypes.REF) - diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE) - } - } - - private def createAndPushLambdaMethod( - expr: LambdaExpr, - lambdaMethodName: String, - implementedInfo: LambdaImplementedInfo, - localsForCaptured: Seq[NewLocal], - expectedLambdaType: ExpectedType - ): NewMethod = { - val implementedMethod = implementedInfo.implementedMethod - val implementedInterface = implementedInfo.implementedInterface - - // We need to get this information from the expected type as the JavaParser - // symbol solver returns the erased types when resolving the lambda itself. 
- val expectedTypeParamTypes = genericParamTypeMapForLambda(expectedLambdaType) - val parametersWithoutThis = buildParamListForLambda(expr, implementedMethod, expectedTypeParamTypes) - - val returnType = getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) - - val lambdaMethodBody = astForLambdaBody(expr.getBody, localsForCaptured, returnType) - - val thisParam = lambdaMethodBody.nodes - .collect { case identifier: NewIdentifier => identifier } - .find { identifier => identifier.name == NameConstants.This || identifier.name == NameConstants.Super } - .map { _ => - val typeFullName = scope.enclosingTypeDeclFullName - Ast(thisNodeForMethod(typeFullName, line(expr))) - } - .toList - - val parameters = thisParam ++ parametersWithoutThis - - val lambdaMethodNode = createLambdaMethodNode(lambdaMethodName, parametersWithoutThis, returnType) - val returnNode = newMethodReturnNode(returnType.getOrElse(TypeConstants.Any), None, line(expr), column(expr)) - val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) - val staticModifier = Option.when(thisParam.isEmpty)(newModifierNode(ModifierTypes.STATIC)) - val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) - - val modifiers = List(virtualModifier, staticModifier, privateModifier).flatten.map(Ast(_)) - - val lambdaParameterNamesToNodes = - parameters - .flatMap(_.root) - .collect { case param: NewMethodParameterIn => param } - .map { param => param.name -> param } - .toMap - - val identifiersMatchingParams = lambdaMethodBody.nodes - .collect { case identifier: NewIdentifier => identifier } - .filter { identifier => lambdaParameterNamesToNodes.contains(identifier.name) } - - val lambdaMethodAstWithoutRefs = - Ast(lambdaMethodNode) - .withChildren(parameters) - .withChild(lambdaMethodBody) - .withChild(Ast(returnNode)) - .withChildren(modifiers) - - val lambdaMethodAst = identifiersMatchingParams.foldLeft(lambdaMethodAstWithoutRefs)((ast, identifier) => - 
ast.withRefEdge(identifier, lambdaParameterNamesToNodes(identifier.name)) - ) - - scope.addLambdaMethod(lambdaMethodAst) - - lambdaMethodNode - } - - private def createAndPushLambdaTypeDecl( - lambdaMethodNode: NewMethod, - implementedInfo: LambdaImplementedInfo - ): NewTypeDecl = { - val inheritsFromTypeFullName = - implementedInfo.implementedInterface - .flatMap(typeInfoCalc.fullName) - .orElse(Some(TypeConstants.Object)) - .toList - - typeInfoCalc.registerType(lambdaMethodNode.fullName) - val lambdaTypeDeclNode = - NewTypeDecl() - .fullName(lambdaMethodNode.fullName) - .name(lambdaMethodNode.name) - .inheritsFromTypeFullName(inheritsFromTypeFullName) - scope.addLocalDecl(Ast(lambdaTypeDeclNode)) - - lambdaTypeDeclNode - } - - private def getLambdaImplementedInfo(expr: LambdaExpr, expectedType: ExpectedType): LambdaImplementedInfo = { - val maybeImplementedType = { - val maybeResolved = tryWithSafeStackOverflow(expr.calculateResolvedType()) - maybeResolved.toOption - .orElse(expectedType.resolvedType) - .collect { case refType: ResolvedReferenceType => refType } - } - - val maybeImplementedInterface = maybeImplementedType.flatMap(_.getTypeDeclaration.toScala) - - if (maybeImplementedInterface.isEmpty) { - val location = s"$filename:${line(expr)}:${column(expr)}" - logger.debug( - s"Could not resolve the interface implemented by a lambda. Type info may be missing: $location. Type info may be missing." - ) - } - - val maybeBoundMethod = maybeImplementedInterface.flatMap { interface => - interface.getDeclaredMethods.asScala - .filter(_.isAbstract) - .filterNot { method => - // Filter out java.lang.Object methods re-declared by the interface as these are also considered abstract. - // See https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for details. 
- Try(method.getSignature) match { - case Success(signature) => ObjectMethodSignatures.contains(signature) - case Failure(_) => - false // If the signature could not be calculated, it's probably not a standard object method. - } + private def localsForCapturedNodes(closureBindingEntries: List[ClosureBindingEntry]) + : List[NewLocal] = + val localsForCaptured = + closureBindingEntries.map { case ClosureBindingEntry(node, binding) => + val local = NewLocal() + .name(node.name) + .code(node.name) + .closureBindingId(binding.closureBindingId) + .typeFullName(node.typeFullName) + local + } + localsForCaptured.foreach { local => scope.addLocal(local) } + localsForCaptured + + private def astForLambdaBody( + body: Statement, + localsForCapturedVars: Seq[NewLocal], + returnType: Option[String] + ): Ast = + body match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = localsForCapturedVars.map(Ast(_))) + + case stmt => + val blockAst = Ast(NewBlock().lineNumber(line(body))) + val bodyAst = if returnType.contains(TypeConstants.Void) then + astsForStatement(stmt) + else + val returnNode = + NewReturn() + .code(s"return ${body.toString}") + .lineNumber(line(body)) + val returnArgs = astsForStatement(stmt) + Seq(returnAst(returnNode, returnArgs)) + + blockAst + .withChildren(localsForCapturedVars.map(Ast(_))) + .withChildren(bodyAst) + + private def lambdaMethodSignature(returnType: Option[String], parameters: Seq[Ast]): String = + val maybeParameterTypes = toOptionList(parameters.map(_.rootType)) + val containsEmptyType = + maybeParameterTypes.exists(_.contains(ParameterDefaults.TypeFullName)) + + (returnType, maybeParameterTypes) match + case (Some(returnTpe), Some(parameterTpes)) if !containsEmptyType => + composeMethodLikeSignature(returnTpe, parameterTpes) + + case _ => composeUnresolvedSignature(parameters.size) + + private def createLambdaMethodNode( + lambdaName: String, + parameters: Seq[Ast], + returnType: Option[String] + ): NewMethod = + val 
enclosingTypeName = + scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val signature = lambdaMethodSignature(returnType, parameters) + val lambdaFullName = composeMethodFullName(enclosingTypeName, lambdaName, signature) + + NewMethod() + .name(lambdaName) + .fullName(lambdaFullName) + .signature(signature) + .filename(filename) + .code("") + + private def addClosureBindingsToDiffGraph( + bindingEntries: Iterable[ClosureBindingEntry], + methodRef: NewMethodRef + ): Unit = + bindingEntries.foreach { case ClosureBindingEntry(nodeTypeInfo, closureBinding) => + diffGraph.addNode(closureBinding) + diffGraph.addEdge(closureBinding, nodeTypeInfo.node, EdgeTypes.REF) + diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE) } - .headOption - } - - LambdaImplementedInfo(maybeImplementedType, maybeBoundMethod) - } - - // TODO: All of this will be thrown out, probably - private def astForLambdaExpr(expr: LambdaExpr, expectedType: ExpectedType): Ast = { - scope.pushMethodScope(NewMethod(), expectedType) - - val lambdaMethodName = nextLambdaName() - - val closureBindingsForCapturedVars = closureBindingsForCapturedNodes(lambdaMethodName) - val localsForCaptured = localsForCapturedNodes(closureBindingsForCapturedVars) - val implementedInfo = getLambdaImplementedInfo(expr, expectedType) - val lambdaMethodNode = - createAndPushLambdaMethod(expr, lambdaMethodName, implementedInfo, localsForCaptured, expectedType) - val typeNameLookup = lambdaMethodNode.fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - val methodRef = - NewMethodRef() - .methodFullName(lambdaMethodNode.fullName) - .typeFullName(lambdaMethodNode.fullName) - .code(lambdaMethodNode.fullName) - .dynamicTypeHintFullName(packagesJarMappings.getOrElse(typeNameLookup, mutable.Set.empty).toSeq) - - addClosureBindingsToDiffGraph(closureBindingsForCapturedVars, methodRef) - - val interfaceBinding = implementedInfo.implementedMethod.map { implementedMethod => - 
newBindingNode(implementedMethod.getName, lambdaMethodNode.signature, lambdaMethodNode.fullName) - } - - val bindingTable = getLambdaBindingTable( - LambdaBindingInfo(lambdaMethodNode.fullName, implementedInfo.implementedInterface, interfaceBinding) - ) - - val lambdaTypeDeclNode = createAndPushLambdaTypeDecl(lambdaMethodNode, implementedInfo) - createBindingNodes(lambdaTypeDeclNode, bindingTable) - - scope.popScope() - Ast(methodRef) - } - - private def astForLiteralExpr(expr: LiteralExpr): Ast = { - val typeFullName = expressionReturnTypeFullName(expr).getOrElse(TypeConstants.Any) - val literalNode = - NewLiteral() - .code(expr.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr)) - .typeFullName(typeFullName) - Ast(literalNode) - } - - private def getExpectedParamType(maybeResolvedCall: Try[ResolvedMethodLikeDeclaration], idx: Int): ExpectedType = { - maybeResolvedCall.toOption - .map { methodDecl => - val paramCount = methodDecl.getNumberOfParams - - val resolvedType = if (idx < paramCount) { - Some(methodDecl.getParam(idx).getType) - } else if (paramCount > 0 && methodDecl.getParam(paramCount - 1).isVariadic) { - Some(methodDecl.getParam(paramCount - 1).getType) - } else { - None + + private def createAndPushLambdaMethod( + expr: LambdaExpr, + lambdaMethodName: String, + implementedInfo: LambdaImplementedInfo, + localsForCaptured: Seq[NewLocal], + expectedLambdaType: ExpectedType + ): NewMethod = + val implementedMethod = implementedInfo.implementedMethod + val implementedInterface = implementedInfo.implementedInterface + + // We need to get this information from the expected type as the JavaParser + // symbol solver returns the erased types when resolving the lambda itself. 
+ val expectedTypeParamTypes = genericParamTypeMapForLambda(expectedLambdaType) + val parametersWithoutThis = + buildParamListForLambda(expr, implementedMethod, expectedTypeParamTypes) + + val returnType = + getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) + + val lambdaMethodBody = astForLambdaBody(expr.getBody, localsForCaptured, returnType) + + val thisParam = lambdaMethodBody.nodes + .collect { case identifier: NewIdentifier => identifier } + .find { identifier => + identifier.name == NameConstants.This || identifier.name == NameConstants.Super + } + .map { _ => + val typeFullName = scope.enclosingTypeDeclFullName + Ast(thisNodeForMethod(typeFullName, line(expr))) + } + .toList + + val parameters = thisParam ++ parametersWithoutThis + + val lambdaMethodNode = + createLambdaMethodNode(lambdaMethodName, parametersWithoutThis, returnType) + val returnNode = newMethodReturnNode( + returnType.getOrElse(TypeConstants.Any), + None, + line(expr), + column(expr) + ) + val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) + val staticModifier = Option.when(thisParam.isEmpty)(newModifierNode(ModifierTypes.STATIC)) + val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) + + val modifiers = List(virtualModifier, staticModifier, privateModifier).flatten.map(Ast(_)) + + val lambdaParameterNamesToNodes = + parameters + .flatMap(_.root) + .collect { case param: NewMethodParameterIn => param } + .map { param => param.name -> param } + .toMap + + val identifiersMatchingParams = lambdaMethodBody.nodes + .collect { case identifier: NewIdentifier => identifier } + .filter { identifier => lambdaParameterNamesToNodes.contains(identifier.name) } + + val lambdaMethodAstWithoutRefs = + Ast(lambdaMethodNode) + .withChildren(parameters) + .withChild(lambdaMethodBody) + .withChild(Ast(returnNode)) + .withChildren(modifiers) + + val lambdaMethodAst = + identifiersMatchingParams.foldLeft(lambdaMethodAstWithoutRefs)((ast, 
identifier) => + ast.withRefEdge(identifier, lambdaParameterNamesToNodes(identifier.name)) + ) + + scope.addLambdaMethod(lambdaMethodAst) + + lambdaMethodNode + end createAndPushLambdaMethod + + private def createAndPushLambdaTypeDecl( + lambdaMethodNode: NewMethod, + implementedInfo: LambdaImplementedInfo + ): NewTypeDecl = + val inheritsFromTypeFullName = + implementedInfo.implementedInterface + .flatMap(typeInfoCalc.fullName) + .orElse(Some(TypeConstants.Object)) + .toList + + typeInfoCalc.registerType(lambdaMethodNode.fullName) + val lambdaTypeDeclNode = + NewTypeDecl() + .fullName(lambdaMethodNode.fullName) + .name(lambdaMethodNode.name) + .inheritsFromTypeFullName(inheritsFromTypeFullName) + scope.addLocalDecl(Ast(lambdaTypeDeclNode)) + + lambdaTypeDeclNode + end createAndPushLambdaTypeDecl + + private def getLambdaImplementedInfo( + expr: LambdaExpr, + expectedType: ExpectedType + ): LambdaImplementedInfo = + val maybeImplementedType = + val maybeResolved = tryWithSafeStackOverflow(expr.calculateResolvedType()) + maybeResolved.toOption + .orElse(expectedType.resolvedType) + .collect { case refType: ResolvedReferenceType => refType } + + val maybeImplementedInterface = maybeImplementedType.flatMap(_.getTypeDeclaration.toScala) + + if maybeImplementedInterface.isEmpty then + val location = s"$filename:${line(expr)}:${column(expr)}" + logger.debug( + s"Could not resolve the interface implemented by a lambda. Type info may be missing: $location. Type info may be missing." + ) + + val maybeBoundMethod = maybeImplementedInterface.flatMap { interface => + interface.getDeclaredMethods.asScala + .filter(_.isAbstract) + .filterNot { method => + // Filter out java.lang.Object methods re-declared by the interface as these are also considered abstract. + // See https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for details. 
+ Try(method.getSignature) match + case Success(signature) => ObjectMethodSignatures.contains(signature) + case Failure(_) => + false // If the signature could not be calculated, it's probably not a standard object method. + } + .headOption } - val typeName = resolvedType.flatMap(typeInfoCalc.fullName) - ExpectedType(typeName, resolvedType) - } - .getOrElse(ExpectedType.empty) - } - - private def dispatchTypeForCall(maybeDecl: Try[ResolvedMethodDeclaration], maybeScope: Option[Expression]): String = { - maybeScope match { - case Some(_: SuperExpr) => - DispatchTypes.STATIC_DISPATCH - case _ => - maybeDecl match { - case Success(decl) => - if (decl.isStatic) DispatchTypes.STATIC_DISPATCH else DispatchTypes.DYNAMIC_DISPATCH - - case _ => - DispatchTypes.DYNAMIC_DISPATCH + LambdaImplementedInfo(maybeImplementedType, maybeBoundMethod) + end getLambdaImplementedInfo + + // TODO: All of this will be thrown out, probably + private def astForLambdaExpr(expr: LambdaExpr, expectedType: ExpectedType): Ast = + scope.pushMethodScope(NewMethod(), expectedType) + + val lambdaMethodName = nextLambdaName() + + val closureBindingsForCapturedVars = closureBindingsForCapturedNodes(lambdaMethodName) + val localsForCaptured = localsForCapturedNodes(closureBindingsForCapturedVars) + val implementedInfo = getLambdaImplementedInfo(expr, expectedType) + val lambdaMethodNode = + createAndPushLambdaMethod( + expr, + lambdaMethodName, + implementedInfo, + localsForCaptured, + expectedType + ) + val typeNameLookup = + lambdaMethodNode.fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + val methodRef = + NewMethodRef() + .methodFullName(lambdaMethodNode.fullName) + .typeFullName(lambdaMethodNode.fullName) + .code(lambdaMethodNode.fullName) + .dynamicTypeHintFullName(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).toSeq) + + addClosureBindingsToDiffGraph(closureBindingsForCapturedVars, methodRef) + + val interfaceBinding = 
implementedInfo.implementedMethod.map { implementedMethod => + newBindingNode( + implementedMethod.getName, + lambdaMethodNode.signature, + lambdaMethodNode.fullName + ) } - } - } - - private def targetTypeForCall(callExpr: MethodCallExpr): Option[String] = { - val maybeType = callExpr.getScope.toScala match { - case Some(callScope: ThisExpr) => - expressionReturnTypeFullName(callScope) - .orElse(scope.enclosingTypeDeclFullName) - - case Some(callScope: SuperExpr) => - expressionReturnTypeFullName(callScope) - .orElse(scope.enclosingTypeDecl.flatMap(_.inheritsFromTypeFullName.headOption)) - - case Some(scope) => expressionReturnTypeFullName(scope) - - case None => - tryWithSafeStackOverflow(callExpr.resolve()).toOption - .flatMap { methodDeclOption => - if (methodDeclOption.isStatic) typeInfoCalc.fullName(methodDeclOption.declaringType()) - else scope.enclosingTypeDeclFullName - } - .orElse(scope.enclosingTypeDeclFullName) - } - - maybeType.map(typeInfoCalc.registerType) - } - - private def argAstsForCall( - call: Node, - tryResolvedDecl: Try[ResolvedMethodLikeDeclaration], - args: NodeList[Expression] - ): Seq[Ast] = { - val hasVariadicParameter = tryResolvedDecl.map(_.hasVariadicParameter).getOrElse(false) - val paramCount = tryResolvedDecl.map(_.getNumberOfParams).getOrElse(-1) - - val argsAsts = args.asScala.zipWithIndex.flatMap { case (arg, idx) => - val expectedType = getExpectedParamType(tryResolvedDecl, idx) - astsForExpression(arg, expectedType) - }.toList - - tryResolvedDecl match { - case Success(_) if hasVariadicParameter => - val expectedVariadicTypeFullName = getExpectedParamType(tryResolvedDecl, paramCount - 1).fullName - val (regularArgs, varargs) = argsAsts.splitAt(paramCount - 1) - val arrayInitializer = newOperatorCallNode( - Operators.arrayInitializer, - Operators.arrayInitializer, - expectedVariadicTypeFullName, - line(call), - column(call) + + val bindingTable = getLambdaBindingTable( + LambdaBindingInfo( + lambdaMethodNode.fullName, + 
implementedInfo.implementedInterface, + interfaceBinding + ) ) - val arrayInitializerAst = callAst(arrayInitializer, varargs) - - regularArgs ++ Seq(arrayInitializerAst) - - case _ => argsAsts - } - } - - private def getArgumentCodeString(args: NodeList[Expression]): String = { - args.asScala - .map { - case _: LambdaExpr => "" - case other => other.toString - } - .mkString(", ") - } - - private def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = { - val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve()) - val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments) - - val expressionTypeFullName = expressionReturnTypeFullName(call).orElse(expectedReturnType.fullName) - - val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall) - val returnType = maybeResolvedCall - .map { resolvedCall => - typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty()) - } - .toOption - .flatten - .orElse(expressionTypeFullName) - val dispatchType = dispatchTypeForCall(maybeResolvedCall, call.getScope.toScala) - - val receiverTypeOption = targetTypeForCall(call) - val scopeAsts = call.getScope.toScala match { - case Some(scope) => astsForExpression(scope, ExpectedType(receiverTypeOption)) - - case None => - val objectNode = - createObjectNode(receiverTypeOption, call, dispatchType) - for { - obj <- objectNode - thisParam <- scope.lookupVariable(NameConstants.This).variableNode - } diffGraph.addEdge(obj, thisParam, EdgeTypes.REF) - objectNode.map(Ast(_)).toList - } - - val receiverType = receiverTypeOption.orElse(scopeAsts.rootType).filter(_ != TypeConstants.Any) - - val argumentsCode = getArgumentCodeString(call.getArguments) - val codePrefix = codePrefixForMethodCall(call) - val callCode = s"$codePrefix${call.getNameAsString}($argumentsCode)" - - val callName = call.getNameAsString - val namespace = receiverType.getOrElse(Defines.UnresolvedNamespace) - val signature = 
composeSignature(returnType, argumentTypes, argumentAsts.size) - val methodFullName = composeMethodFullName(namespace, callName, signature) - val typeFullNameStr = expressionTypeFullName.getOrElse(TypeConstants.Any) - val callRoot = NewCall() - .name(callName) - .methodFullName(methodFullName) - .signature(signature) - .code(callCode) - .dispatchType(dispatchType) - .lineNumber(line(call)) - .columnNumber(column(call)) - .typeFullName(typeFullNameStr) - callRoot.dynamicTypeHintFullName( - packagesJarMappings - .getOrElse(methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString("."), mutable.Set.empty) - .toSeq - ) - callAst(callRoot, argumentAsts, scopeAsts.headOption) - } - - private def astForSuperExpr(superExpr: SuperExpr, expectedType: ExpectedType): Ast = { - val typeFullName = - expressionReturnTypeFullName(superExpr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - typeInfoCalc.registerType(typeFullName) - - val identifier = identifierNode(superExpr, NameConstants.This, NameConstants.Super, typeFullName) - Ast(identifier) - } - - private def astsForParameterList(parameters: NodeList[Parameter]): Seq[Ast] = { - parameters.asScala.toList.zipWithIndex.map { case (param, idx) => - astForParameter(param, idx + 1) - } - } - - private def astForParameter(parameter: Parameter, childNum: Int): Ast = { - val maybeArraySuffix = if (parameter.isVarArgs) "[]" else "" - val typeFullName = - typeInfoCalc - .fullName(parameter.getType) - .orElse(scope.lookupType(parameter.getTypeAsString)) - .map(_ ++ maybeArraySuffix) - .getOrElse(s"${Defines.UnresolvedNamespace}.${parameter.getTypeAsString}") - val evalStrat = - if (parameter.getType.isPrimitiveType) EvaluationStrategies.BY_VALUE else EvaluationStrategies.BY_SHARING - typeInfoCalc.registerType(typeFullName) - val parameterNode = NewMethodParameterIn() - .name(parameter.getName.toString) - .code(parameter.toString) - .lineNumber(line(parameter)) - .columnNumber(column(parameter)) - 
.evaluationStrategy(evalStrat) - .typeFullName(typeFullName) - .index(childNum) - .order(childNum) - val annotationAsts = parameter.getAnnotations.asScala.map(astForAnnotationExpr) - val ast = Ast(parameterNode) - - scope.addParameter(parameterNode) - - ast.withChildren(annotationAsts) - } + val lambdaTypeDeclNode = createAndPushLambdaTypeDecl(lambdaMethodNode, implementedInfo) + createBindingNodes(lambdaTypeDeclNode, bindingTable) + + scope.popScope() + Ast(methodRef) + end astForLambdaExpr + + private def astForLiteralExpr(expr: LiteralExpr): Ast = + val typeFullName = expressionReturnTypeFullName(expr).getOrElse(TypeConstants.Any) + val literalNode = + NewLiteral() + .code(expr.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + .typeFullName(typeFullName) + Ast(literalNode) + + private def getExpectedParamType( + maybeResolvedCall: Try[ResolvedMethodLikeDeclaration], + idx: Int + ): ExpectedType = + maybeResolvedCall.toOption + .map { methodDecl => + val paramCount = methodDecl.getNumberOfParams + + val resolvedType = if idx < paramCount then + Some(methodDecl.getParam(idx).getType) + else if paramCount > 0 && methodDecl.getParam(paramCount - 1).isVariadic then + Some(methodDecl.getParam(paramCount - 1).getType) + else + None + + val typeName = resolvedType.flatMap(typeInfoCalc.fullName) + ExpectedType(typeName, resolvedType) + } + .getOrElse(ExpectedType.empty) + + private def dispatchTypeForCall( + maybeDecl: Try[ResolvedMethodDeclaration], + maybeScope: Option[Expression] + ): String = + maybeScope match + case Some(_: SuperExpr) => + DispatchTypes.STATIC_DISPATCH + case _ => + maybeDecl match + case Success(decl) => + if decl.isStatic then DispatchTypes.STATIC_DISPATCH + else DispatchTypes.DYNAMIC_DISPATCH + + case _ => + DispatchTypes.DYNAMIC_DISPATCH + + private def targetTypeForCall(callExpr: MethodCallExpr): Option[String] = + val maybeType = callExpr.getScope.toScala match + case Some(callScope: ThisExpr) => + 
expressionReturnTypeFullName(callScope) + .orElse(scope.enclosingTypeDeclFullName) + + case Some(callScope: SuperExpr) => + expressionReturnTypeFullName(callScope) + .orElse(scope.enclosingTypeDecl.flatMap(_.inheritsFromTypeFullName.headOption)) + + case Some(scope) => expressionReturnTypeFullName(scope) + + case None => + tryWithSafeStackOverflow(callExpr.resolve()).toOption + .flatMap { methodDeclOption => + if methodDeclOption.isStatic then + typeInfoCalc.fullName(methodDeclOption.declaringType()) + else scope.enclosingTypeDeclFullName + } + .orElse(scope.enclosingTypeDeclFullName) + + maybeType.map(typeInfoCalc.registerType) + end targetTypeForCall + + private def argAstsForCall( + call: Node, + tryResolvedDecl: Try[ResolvedMethodLikeDeclaration], + args: NodeList[Expression] + ): Seq[Ast] = + val hasVariadicParameter = tryResolvedDecl.map(_.hasVariadicParameter).getOrElse(false) + val paramCount = tryResolvedDecl.map(_.getNumberOfParams).getOrElse(-1) + + val argsAsts = args.asScala.zipWithIndex.flatMap { case (arg, idx) => + val expectedType = getExpectedParamType(tryResolvedDecl, idx) + astsForExpression(arg, expectedType) + }.toList + + tryResolvedDecl match + case Success(_) if hasVariadicParameter => + val expectedVariadicTypeFullName = + getExpectedParamType(tryResolvedDecl, paramCount - 1).fullName + val (regularArgs, varargs) = argsAsts.splitAt(paramCount - 1) + val arrayInitializer = newOperatorCallNode( + Operators.arrayInitializer, + Operators.arrayInitializer, + expectedVariadicTypeFullName, + line(call), + column(call) + ) + + val arrayInitializerAst = callAst(arrayInitializer, varargs) + + regularArgs ++ Seq(arrayInitializerAst) + + case _ => argsAsts + end argAstsForCall + + private def getArgumentCodeString(args: NodeList[Expression]): String = + args.asScala + .map { + case _: LambdaExpr => "" + case other => other.toString + } + .mkString(", ") -} + private def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = + 
val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve()) + val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments) + + val expressionTypeFullName = + expressionReturnTypeFullName(call).orElse(expectedReturnType.fullName) + + val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall) + val returnType = maybeResolvedCall + .map { resolvedCall => + typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty()) + } + .toOption + .flatten + .orElse(expressionTypeFullName) + val dispatchType = dispatchTypeForCall(maybeResolvedCall, call.getScope.toScala) + + val receiverTypeOption = targetTypeForCall(call) + val scopeAsts = call.getScope.toScala match + case Some(scope) => astsForExpression(scope, ExpectedType(receiverTypeOption)) + + case None => + val objectNode = + createObjectNode(receiverTypeOption, call, dispatchType) + for + obj <- objectNode + thisParam <- scope.lookupVariable(NameConstants.This).variableNode + do diffGraph.addEdge(obj, thisParam, EdgeTypes.REF) + objectNode.map(Ast(_)).toList + + val receiverType = + receiverTypeOption.orElse(scopeAsts.rootType).filter(_ != TypeConstants.Any) + + val argumentsCode = getArgumentCodeString(call.getArguments) + val codePrefix = codePrefixForMethodCall(call) + val callCode = s"$codePrefix${call.getNameAsString}($argumentsCode)" + + val callName = call.getNameAsString + val namespace = receiverType.getOrElse(Defines.UnresolvedNamespace) + val signature = composeSignature(returnType, argumentTypes, argumentAsts.size) + val methodFullName = composeMethodFullName(namespace, callName, signature) + val typeFullNameStr = expressionTypeFullName.getOrElse(TypeConstants.Any) + val callRoot = NewCall() + .name(callName) + .methodFullName(methodFullName) + .signature(signature) + .code(callCode) + .dispatchType(dispatchType) + .lineNumber(line(call)) + .columnNumber(column(call)) + .typeFullName(typeFullNameStr) + callRoot.dynamicTypeHintFullName( + packagesJarMappings 
+ .getOrElse( + methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString("."), + mutable.Set.empty + ) + .toSeq + ) + callAst(callRoot, argumentAsts, scopeAsts.headOption) + end astForMethodCall + + private def astForSuperExpr(superExpr: SuperExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(superExpr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + typeInfoCalc.registerType(typeFullName) + + val identifier = + identifierNode(superExpr, NameConstants.This, NameConstants.Super, typeFullName) + Ast(identifier) + + private def astsForParameterList(parameters: NodeList[Parameter]): Seq[Ast] = + parameters.asScala.toList.zipWithIndex.map { case (param, idx) => + astForParameter(param, idx + 1) + } + + private def astForParameter(parameter: Parameter, childNum: Int): Ast = + val maybeArraySuffix = if parameter.isVarArgs then "[]" else "" + val typeFullName = + typeInfoCalc + .fullName(parameter.getType) + .orElse(scope.lookupType(parameter.getTypeAsString)) + .map(_ ++ maybeArraySuffix) + .getOrElse(s"${Defines.UnresolvedNamespace}.${parameter.getTypeAsString}") + val evalStrat = + if parameter.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE + else EvaluationStrategies.BY_SHARING + typeInfoCalc.registerType(typeFullName) + val parameterNode = NewMethodParameterIn() + .name(parameter.getName.toString) + .code(parameter.toString) + .lineNumber(line(parameter)) + .columnNumber(column(parameter)) + .evaluationStrategy(evalStrat) + .typeFullName(typeFullName) + .index(childNum) + .order(childNum) + val annotationAsts = parameter.getAnnotations.asScala.map(astForAnnotationExpr) + val ast = Ast(parameterNode) + + scope.addParameter(parameterNode) + + ast.withChildren(annotationAsts) + end astForParameter +end AstCreator diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala 
b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala index 9d2fadee..50898871 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala @@ -4,46 +4,44 @@ import better.files.File import io.appthreat.x2cpg.passes.frontend.XConfigFileCreationPass import io.shiftleft.codepropertygraph.Cpg -class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg) { +class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // JAVA_INTERNAL - extensionFilter(".properties"), - // HTML - pathRegexFilter(".*resources/templates.*.html"), - // JSP - extensionFilter(".jsp"), - // Velocity files, see https://velocity.apache.org - extensionFilter(".vm"), - // For Terraform secrets - extensionFilter(".tf"), - extensionFilter(".tfvars"), - // PLAY - pathEndFilter("routes"), - pathEndFilter("application.conf"), - // SERVLET - pathEndFilter("web.xml"), - // JSF - pathEndFilter("faces-config.xml"), - // STRUTS - pathEndFilter("struts.xml"), - // DIRECT WEB REMOTING - pathEndFilter("dwr.xml"), - // MYBATIS - mybatisFilter, - // BUILD SYSTEM - pathEndFilter("build.gradle"), - pathEndFilter("build.gradle.kts"), - // ANDROID - pathEndFilter("AndroidManifest.xml"), - // Bom - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json") - ) + override val configFileFilters: List[File => Boolean] = List( + // JAVA_INTERNAL + extensionFilter(".properties"), + // HTML + pathRegexFilter(".*resources/templates.*.html"), + // JSP + extensionFilter(".jsp"), + // Velocity files, see https://velocity.apache.org + extensionFilter(".vm"), + // For Terraform secrets + extensionFilter(".tf"), + extensionFilter(".tfvars"), + // PLAY + pathEndFilter("routes"), + 
pathEndFilter("application.conf"), + // SERVLET + pathEndFilter("web.xml"), + // JSF + pathEndFilter("faces-config.xml"), + // STRUTS + pathEndFilter("struts.xml"), + // DIRECT WEB REMOTING + pathEndFilter("dwr.xml"), + // MYBATIS + mybatisFilter, + // BUILD SYSTEM + pathEndFilter("build.gradle"), + pathEndFilter("build.gradle.kts"), + // ANDROID + pathEndFilter("AndroidManifest.xml"), + // Bom + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json") + ) - private def mybatisFilter(file: File): Boolean = { - file.canonicalPath.contains("batis") && file.extension.contains(".xml") - } - -} + private def mybatisFilter(file: File): Boolean = + file.canonicalPath.contains("batis") && file.extension.contains(".xml") +end ConfigFileCreationPass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala index 69f6c76e..afbe9725 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala @@ -4,18 +4,17 @@ import io.appthreat.x2cpg.Defines import io.appthreat.x2cpg.passes.frontend.XTypeHintCallLinker import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.util.regex.Pattern -class JavaTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { +class JavaTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override protected def calls: Iterator[Call] = { - cpg.call - .nameNot(".*", ".*") - .filter(c => - calleeNames(c).nonEmpty && c.callee.fullNameNot(Pattern.quote(Defines.UnresolvedNamespace) + ".*").isEmpty - ) - } - -} + override protected def 
calls: Iterator[Call] = + cpg.call + .nameNot(".*", ".*") + .filter(c => + calleeNames(c).nonEmpty && c.callee.fullNameNot( + Pattern.quote(Defines.UnresolvedNamespace) + ".*" + ).isEmpty + ) diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala index d439b52f..7499821a 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala @@ -1,70 +1,71 @@ package io.appthreat.javasrc2cpg.passes import io.appthreat.x2cpg.Defines -import io.appthreat.x2cpg.passes.frontend._ +import io.appthreat.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate.DiffGraphBuilder class JavaTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[Method](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[Method] = - new JavaTypeRecovery(cpg, state) -} + extends XTypeRecoveryPass[Method](cpg, config): + override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[Method] = + new JavaTypeRecovery(cpg, state) -private class JavaTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[Method](cpg, state) { +private class JavaTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) + extends XTypeRecovery[Method](cpg, state): - override def compilationUnit: Iterator[Method] = cpg.method.isExternal(false).iterator + override def 
compilationUnit: Iterator[Method] = cpg.method.isExternal(false).iterator - override def generateRecoveryForCompilationUnitTask( - unit: Method, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[Method] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForJavaFile(cpg, unit, builder, state.copy(config = newConfig)) - } -} + override def generateRecoveryForCompilationUnitTask( + unit: Method, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[Method] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForJavaFile(cpg, unit, builder, state.copy(config = newConfig)) -private class RecoverForJavaFile(cpg: Cpg, cu: Method, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[Method](cpg, cu, builder, state) { +private class RecoverForJavaFile( + cpg: Cpg, + cu: Method, + builder: DiffGraphBuilder, + state: XTypeRecoveryState +) extends RecoverForXCompilationUnit[Method](cpg, cu, builder, state): - private def javaNodeToLocalKey(n: AstNode): Option[LocalKey] = n match { - case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) - case _ => SBKey.fromNodeToLocalKey(n) - } + private def javaNodeToLocalKey(n: AstNode): Option[LocalKey] = n match + case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) + case _ => SBKey.fromNodeToLocalKey(n) - override protected val symbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) + override protected val symbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) - override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) + override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) - override protected def isConstructor(name: String): Boolean = !name.isBlank && name.charAt(0).isUpper + override protected def 
isConstructor(name: String): Boolean = + !name.isBlank && name.charAt(0).isUpper - override protected def postVisitImports(): Unit = { - symbolTable.view.foreach { case (k, ts) => - val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) - if (tss.isEmpty) - symbolTable.remove(k) - else - symbolTable.put(k, tss) - } - } + override protected def postVisitImports(): Unit = + symbolTable.view.foreach { case (k, ts) => + val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) + if tss.isEmpty then + symbolTable.remove(k) + else + symbolTable.put(k, tss) + } - // There seems to be issues with inferring these, often due to situations where super and this are confused on name - // and code properties. - override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = if (i.name != "this") { - super.storeIdentifierTypeInfo(i, types) - } + // There seems to be issues with inferring these, often due to situations where super and this are confused on name + // and code properties. 
+ override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + if i.name != "this" then + super.storeIdentifierTypeInfo(i, types) - override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = - if (types.nonEmpty) { - state.changesWereMade.compareAndSet(false, true) - val signedTypes = types.map { - case t if t.endsWith(c.signature) => t - case t => s"$t:${c.signature}" - } - builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) - } - -} + override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = + if types.nonEmpty then + state.changesWereMade.compareAndSet(false, true) + val signedTypes = types.map { + case t if t.endsWith(c.signature) => t + case t => s"$t:${c.signature}" + } + builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) +end RecoverForJavaFile diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala index 95b02ab2..f9d84d28 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala @@ -7,7 +7,7 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.ModifierTypes import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} import io.shiftleft.passes.ConcurrentWriterCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.LoggerFactory import scala.jdk.OptionConverters.RichOptional @@ -15,96 +15,106 @@ import io.appthreat.x2cpg.Defines.UnresolvedNamespace import io.shiftleft.codepropertygraph.generated.nodes.Call.PropertyNames import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants -class 
TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg) { - - private val cache = new GuavaCache(CacheBuilder.newBuilder().build[String, Option[Method]]()) - private val resolvedMethodIndex = cpg.method - .filterNot(_.fullName.startsWith(Defines.UnresolvedNamespace)) - .filterNot(_.signature.startsWith(Defines.UnresolvedSignature)) - .groupBy(_.name) - - private case class NameParts(typeDecl: Option[String], signature: String) - - override def generateParts(): Array[Call] = { - cpg.call - .filter(_.signature.startsWith(Defines.UnresolvedSignature)) - .filterNot { _.name.startsWith(UnresolvedNamespace) } - .toArray - } - - private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = { - // An erroneous `this` argument is added for unresolved calls to static methods. - val argSizeMod = if (method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC)) 1 else 0 - lazy val methodNameParts = getNameParts(method.name, method.fullName) - - val parameterSizesMatch = - (method.parameter.size == (call.argument.size - argSizeMod)) - - lazy val argTypesMatch = doArgumentTypesMatch(method: Method, call: Call, skipCallThis = argSizeMod == 1) - - lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) - - parameterSizesMatch && argTypesMatch && typeDeclMatches - } - - /** Check if argument types match by comparing exact full names. An argument type of `ANY` always matches. 
- * - * TODO: Take inheritance hierarchies into account - */ - private def doArgumentTypesMatch(method: Method, call: Call, skipCallThis: Boolean): Boolean = { - val callArgs = if (skipCallThis) call.argument.toList.tail else call.argument.toList - - val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => - val maybeArgumentType = Option(argument.property(PropertyNames.TypeFullName)) - .map(_.toString()) - .getOrElse(TypeConstants.Any) - - val argMatches = maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName - - !argMatches - } - - !hasDifferingArg - } - - private def getNameParts(name: String, fullName: String): NameParts = { - val Array(qualifiedName, signature) = fullName.split(":", 2) - - val typeDeclName = qualifiedName.stripSuffix(name) match { - case "" => None - - case typeDeclName => Some(typeDeclName) - } - - NameParts(typeDeclName, signature) - } - - private def getReplacementMethod(call: Call): Option[Method] = { - val argTypes = - call.argument.flatMap(arg => Option(arg.property(PropertyNames.TypeFullName)).map(_.toString)).mkString(":") - val callKey = - s"${call.methodFullName}:$argTypes" - cache.get(callKey).toScala.getOrElse { - val callNameParts = getNameParts(call.name, call.methodFullName) - resolvedMethodIndex.get(call.name).flatMap { candidateMethods => - val candidateMethodsIter = candidateMethods.iterator - val uniqueMatchingMethod = - candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)).flatMap { method => - val otherMatchingMethod = candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)) - // Only return a resulting method if exactly one matching method is found. 
- Option.when(otherMatchingMethod.isEmpty)(method) - } - cache.put(callKey, uniqueMatchingMethod) - uniqueMatchingMethod - } - } - } - - override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = { - getReplacementMethod(call).foreach { replacementMethod => - diffGraph.setNodeProperty(call, PropertyNames.MethodFullName, replacementMethod.fullName) - diffGraph.setNodeProperty(call, PropertyNames.Signature, replacementMethod.signature) - diffGraph.setNodeProperty(call, PropertyNames.TypeFullName, replacementMethod.methodReturn.typeFullName) - } - } -} +class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg): + + private val cache = new GuavaCache(CacheBuilder.newBuilder().build[String, Option[Method]]()) + private val resolvedMethodIndex = cpg.method + .filterNot(_.fullName.startsWith(Defines.UnresolvedNamespace)) + .filterNot(_.signature.startsWith(Defines.UnresolvedSignature)) + .groupBy(_.name) + + private case class NameParts(typeDecl: Option[String], signature: String) + + override def generateParts(): Array[Call] = + cpg.call + .filter(_.signature.startsWith(Defines.UnresolvedSignature)) + .filterNot { _.name.startsWith(UnresolvedNamespace) } + .toArray + + private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = + // An erroneous `this` argument is added for unresolved calls to static methods. + val argSizeMod = + if method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC) then 1 else 0 + lazy val methodNameParts = getNameParts(method.name, method.fullName) + + val parameterSizesMatch = + (method.parameter.size == (call.argument.size - argSizeMod)) + + lazy val argTypesMatch = + doArgumentTypesMatch(method: Method, call: Call, skipCallThis = argSizeMod == 1) + + lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) + + parameterSizesMatch && argTypesMatch && typeDeclMatches + + /** Check if argument types match by comparing exact full names. 
An argument type of `ANY` + * always matches. + * + * TODO: Take inheritance hierarchies into account + */ + private def doArgumentTypesMatch(method: Method, call: Call, skipCallThis: Boolean): Boolean = + val callArgs = if skipCallThis then call.argument.toList.tail else call.argument.toList + + val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => + val maybeArgumentType = Option(argument.property(PropertyNames.TypeFullName)) + .map(_.toString()) + .getOrElse(TypeConstants.Any) + + val argMatches = + maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName + + !argMatches + } + + !hasDifferingArg + + private def getNameParts(name: String, fullName: String): NameParts = + val Array(qualifiedName, signature) = fullName.split(":", 2) + + val typeDeclName = qualifiedName.stripSuffix(name) match + case "" => None + + case typeDeclName => Some(typeDeclName) + + NameParts(typeDeclName, signature) + + private def getReplacementMethod(call: Call): Option[Method] = + val argTypes = + call.argument.flatMap(arg => + Option(arg.property(PropertyNames.TypeFullName)).map(_.toString) + ).mkString(":") + val callKey = + s"${call.methodFullName}:$argTypes" + cache.get(callKey).toScala.getOrElse { + val callNameParts = getNameParts(call.name, call.methodFullName) + resolvedMethodIndex.get(call.name).flatMap { candidateMethods => + val candidateMethodsIter = candidateMethods.iterator + val uniqueMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)).flatMap { + method => + val otherMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)) + // Only return a resulting method if exactly one matching method is found. 
+ Option.when(otherMatchingMethod.isEmpty)(method) + } + cache.put(callKey, uniqueMatchingMethod) + uniqueMatchingMethod + } + } + end getReplacementMethod + + override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = + getReplacementMethod(call).foreach { replacementMethod => + diffGraph.setNodeProperty( + call, + PropertyNames.MethodFullName, + replacementMethod.fullName + ) + diffGraph.setNodeProperty(call, PropertyNames.Signature, replacementMethod.signature) + diffGraph.setNodeProperty( + call, + PropertyNames.TypeFullName, + replacementMethod.methodReturn.typeFullName + ) + } +end TypeInferencePass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala index ba0e8739..e64b58c6 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala @@ -5,93 +5,79 @@ import io.appthreat.javasrc2cpg.scope.Scope.ScopeVariable import Scope.* import JavaScopeElement.* import io.shiftleft.codepropertygraph.generated.nodes.{ - NewImport, - NewMethod, - NewNamespaceBlock, - NewTypeDecl, - NewTypeParameter + NewImport, + NewMethod, + NewNamespaceBlock, + NewTypeDecl, + NewTypeParameter } import scala.collection.mutable import io.appthreat.x2cpg.Ast import io.shiftleft.codepropertygraph.generated.nodes.AstNodeNew -trait JavaScopeElement { - private val variables = mutable.Map[String, ScopeVariable]() - private val types = mutable.Map[String, String]() - private var wildcardImports: WildcardImports = NoWildcard +trait JavaScopeElement: + private val variables = mutable.Map[String, ScopeVariable]() + private val types = mutable.Map[String, String]() + private var wildcardImports: WildcardImports = NoWildcard - def addVariableToScope(variable: 
ScopeVariable): Unit = { - variables.put(variable.name, variable) - } + def addVariableToScope(variable: ScopeVariable): Unit = + variables.put(variable.name, variable) - def lookupVariable(name: String): Option[ScopeVariable] = { - variables.get(name) - } + def lookupVariable(name: String): Option[ScopeVariable] = + variables.get(name) - def addTypeToScope(name: String, typeFullName: String): Unit = { - types.put(name, typeFullName) - } + def addTypeToScope(name: String, typeFullName: String): Unit = + types.put(name, typeFullName) - def lookupType(name: String, includeWildcards: Boolean): Option[String] = { - types.get(name) match { - case None if includeWildcards => getNameWithWildcardPrefix(name) - case result => result - } - } + def lookupType(name: String, includeWildcards: Boolean): Option[String] = + types.get(name) match + case None if includeWildcards => getNameWithWildcardPrefix(name) + case result => result - def getNameWithWildcardPrefix(name: String): Option[String] = { - wildcardImports match { - case SingleWildcard(prefix) => Some(s"$prefix.$name") + def getNameWithWildcardPrefix(name: String): Option[String] = + wildcardImports match + case SingleWildcard(prefix) => Some(s"$prefix.$name") - case _ => None - } - } + case _ => None - def addWildcardImport(prefix: String): Unit = { - wildcardImports match { - case NoWildcard => wildcardImports = SingleWildcard(prefix) + def addWildcardImport(prefix: String): Unit = + wildcardImports match + case NoWildcard => wildcardImports = SingleWildcard(prefix) - case SingleWildcard(_) => wildcardImports = MultipleWildcards + case SingleWildcard(_) => wildcardImports = MultipleWildcards - case MultipleWildcards => // Already MultipleWildcards, so change nothing - } - } + case MultipleWildcards => // Already MultipleWildcards, so change nothing + // TODO: Refactor and remove this + def getVariables(): List[ScopeVariable] = variables.values.toList +end JavaScopeElement - // TODO: Refactor and remove this - def 
getVariables(): List[ScopeVariable] = variables.values.toList -} - -private object JavaScopeElement { - sealed trait WildcardImports - case object NoWildcard extends WildcardImports - case class SingleWildcard(prefix: String) extends WildcardImports - case object MultipleWildcards extends WildcardImports -} +private object JavaScopeElement: + sealed trait WildcardImports + case object NoWildcard extends WildcardImports + case class SingleWildcard(prefix: String) extends WildcardImports + case object MultipleWildcards extends WildcardImports -class NamespaceScope(val namespace: NewNamespaceBlock) extends JavaScopeElement with TypeDeclContainer +class NamespaceScope(val namespace: NewNamespaceBlock) extends JavaScopeElement + with TypeDeclContainer class BlockScope extends JavaScopeElement class MethodScope(val method: NewMethod, val returnType: ExpectedType) extends JavaScopeElement -class TypeDeclScope(val typeDecl: NewTypeDecl) extends JavaScopeElement with TypeDeclContainer { - private val memberInitializers = mutable.ListBuffer[Ast]() - - // TODO: Refactor and remove this. - def addMemberInitializers(initializers: Seq[Ast]): Unit = { - memberInitializers.appendAll(initializers) - } - - // TODO: Refactor and remove this. - def getMemberInitializers(): List[Ast] = { - memberInitializers.toList.flatMap { ast => - ast.root match { - case Some(root: AstNodeNew) => - Some(ast.subTreeCopy(root)) - - case _ => None - } - } - } -} +class TypeDeclScope(val typeDecl: NewTypeDecl) extends JavaScopeElement with TypeDeclContainer: + private val memberInitializers = mutable.ListBuffer[Ast]() + + // TODO: Refactor and remove this. + def addMemberInitializers(initializers: Seq[Ast]): Unit = + memberInitializers.appendAll(initializers) + + // TODO: Refactor and remove this. 
+ def getMemberInitializers(): List[Ast] = + memberInitializers.toList.flatMap { ast => + ast.root match + case Some(root: AstNodeNew) => + Some(ast.subTreeCopy(root)) + + case _ => None + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala index b23a575b..3ea7aa2d 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala @@ -13,11 +13,11 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewMember import io.shiftleft.codepropertygraph.generated.nodes.NewNode import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew -import Scope._ +import Scope.* import io.appthreat.javasrc2cpg.util.NameConstants import scala.collection.mutable -import io.appthreat.x2cpg.utils.ListUtils._ +import io.appthreat.x2cpg.utils.ListUtils.* import io.shiftleft.codepropertygraph.generated.nodes.NewNamespaceBlock import io.appthreat.x2cpg.Ast @@ -31,223 +31,214 @@ case class NodeTypeInfo( isField: Boolean = false, isStatic: Boolean = false ) -class Scope { - private var scopeStack: List[JavaScopeElement] = Nil - - def pushBlockScope(): Unit = { - scopeStack = new BlockScope() :: scopeStack - } - - def pushMethodScope(method: NewMethod, returnType: ExpectedType): Unit = { - scopeStack = new MethodScope(method, returnType) :: scopeStack - } - - def pushTypeDeclScope(typeDecl: NewTypeDecl): Unit = { - scopeStack = new TypeDeclScope(typeDecl) :: scopeStack - } - - def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = { - scopeStack = new NamespaceScope(namespace) :: scopeStack - } - - def popScope(): JavaScopeElement = { - val scope = scopeStack.head - scopeStack = scopeStack.tail - scope - } - - def addParameter(parameter: NewMethodParameterIn): Unit = { - 
addVariable(ScopeParameter(parameter)) - } - - def addLocal(local: NewLocal): Unit = { - addVariable(ScopeLocal(local)) - } - - def addMember(member: NewMember, isStatic: Boolean): Unit = { - addVariable(ScopeMember(member, isStatic)) - } - - def addStaticImport(importNode: NewImport): Unit = { - addVariable(ScopeStaticImport(importNode)) - } - - private def addVariable(variable: ScopeVariable): Unit = { - scopeStack.head.addVariableToScope(variable) - } - - def addType(name: String, typeFullName: String): Unit = { - scopeStack.head.addTypeToScope(name, typeFullName) - } - - def addWildcardImport(prefix: String): Unit = { - scopeStack.head.addWildcardImport(prefix) - } - - def lookupVariable(name: String): VariableLookupResult = { - scopeStack.takeUntil(_.lookupVariable(name).isDefined) match { - case Nil => NotInScope - - case foundSubstack => - // Know the last will contain the lookup variable since that is the condition to terminate - // the takeUntil - val variable = foundSubstack.last.lookupVariable(name).get - lookupResultFromFound(foundSubstack, variable) - } - } - - def lookupType(simpleName: String): Option[String] = { - lookupType(simpleName, includeWildcards = true) - } - - private def lookupType(simpleName: String, includeWildcards: Boolean): Option[String] = { - scopeStack.iterator - .map(_.lookupType(simpleName, includeWildcards)) - .collectFirst { case Some(typeFullName) => typeFullName } - } - - def lookupVariableOrType(name: String): Option[String] = { - lookupVariable(name).typeFullName.orElse(lookupType(name, includeWildcards = false)) - } - - private def lookupResultFromFound( - foundSubstack: List[JavaScopeElement], - variable: ScopeVariable - ): VariableLookupResult = { - val typeDecls = foundSubstack.collect { case td: TypeDeclScope => td } - - val isCaptureChain = typeDecls.size > 1 || (typeDecls.nonEmpty && !foundSubstack.last.isInstanceOf[TypeDeclScope]) - - if (isCaptureChain) { - CapturedVariable(typeDecls.map(_.typeDecl), variable) - } 
else { - SimpleVariable(variable) - } - } - - def enclosingTypeDecl: Option[NewTypeDecl] = scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => - typeDeclScope.typeDecl - } - - def enclosingTypeDeclFullName: Option[String] = enclosingTypeDecl.map(_.fullName) - - def enclosingMethodReturnType: Option[ExpectedType] = scopeStack.collectFirst { case methodScope: MethodScope => - methodScope.returnType - } - - def addLocalDecl(decl: Ast): Unit = { - scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => typeDeclContainer.registerTypeDecl(decl) } - } - - // TODO: The below section of todos are all methods that have been added for simple compatibility with the old - // scope. The plan is to refactor the code to handle these directly in the AstCreator to make the code easier - // to reason about, so these should be removed when that happens. - - // TODO: Refactor and remove this - def addMemberInitializers(initializers: Seq[Ast]): Unit = { - scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => typeDeclScope.addMemberInitializers(initializers) } - } - - // TODO: Refactor and remove this - def memberInitializers: List[Ast] = { - scopeStack - .collectFirst { case typeDeclScope: TypeDeclScope => typeDeclScope.getMemberInitializers() } - .getOrElse(Nil) - } - - // TODO: Refactor and remove this - def localDeclsInScope: List[Ast] = { - scopeStack - .collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registeredTypeDecls - } - .getOrElse(Nil) - } - - // TODO: Refactor and remove this - def lambdaMethodsInScope: List[Ast] = { - scopeStack - .collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registeredLambdaMethods - } - .getOrElse(Nil) - } - - // TODO: Refactor and remove this - def addLambdaMethod(method: Ast): Unit = { - scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registerLambdaMethod(method) - } - } - - // TODO: Refactor and remove this - 
def capturedVariables: List[ScopeVariable] = { - scopeStack - .flatMap(_.getVariables()) - .collect { - case local: ScopeLocal => local - - case parameter: ScopeParameter if parameter.name != NameConstants.This => parameter - } - } -} - -object Scope { - type NewScopeNode = NewBlock | NewMethod | NewTypeDecl | NewNamespaceBlock - type NewVariableNode = NewLocal | NewMethodParameterIn | NewMember | NewImport - - sealed trait ScopeVariable { - def node: NewVariableNode - def typeFullName: String - def name: String - } - final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable { - val typeFullName: String = node.typeFullName - val name = node.name - } - final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable { - val typeFullName: String = node.typeFullName - val name = node.name - } - final case class ScopeMember(override val node: NewMember, isStatic: Boolean) extends ScopeVariable { - val typeFullName: String = node.typeFullName - val name = node.name - } - final case class ScopeStaticImport(override val node: NewImport) extends ScopeVariable { - val typeFullName: String = node.importedEntity.get - val name = node.importedAs.get - } - - sealed trait VariableLookupResult { - def typeFullName: Option[String] = None - def variableNode: Option[NewVariableNode] = None - - // TODO: Added for convenience, but when proper capture logic is implemented the found - // variable cases will have to be handled separately which would render this unnecessary. 
- def getVariable(): Option[ScopeVariable] = None +class Scope: + private var scopeStack: List[JavaScopeElement] = Nil + + def pushBlockScope(): Unit = + scopeStack = new BlockScope() :: scopeStack + + def pushMethodScope(method: NewMethod, returnType: ExpectedType): Unit = + scopeStack = new MethodScope(method, returnType) :: scopeStack + + def pushTypeDeclScope(typeDecl: NewTypeDecl): Unit = + scopeStack = new TypeDeclScope(typeDecl) :: scopeStack + + def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = + scopeStack = new NamespaceScope(namespace) :: scopeStack + + def popScope(): JavaScopeElement = + val scope = scopeStack.head + scopeStack = scopeStack.tail + scope + + def addParameter(parameter: NewMethodParameterIn): Unit = + addVariable(ScopeParameter(parameter)) + + def addLocal(local: NewLocal): Unit = + addVariable(ScopeLocal(local)) + + def addMember(member: NewMember, isStatic: Boolean): Unit = + addVariable(ScopeMember(member, isStatic)) + + def addStaticImport(importNode: NewImport): Unit = + addVariable(ScopeStaticImport(importNode)) + + private def addVariable(variable: ScopeVariable): Unit = + scopeStack.head.addVariableToScope(variable) + + def addType(name: String, typeFullName: String): Unit = + scopeStack.head.addTypeToScope(name, typeFullName) + + def addWildcardImport(prefix: String): Unit = + scopeStack.head.addWildcardImport(prefix) + + def lookupVariable(name: String): VariableLookupResult = + scopeStack.takeUntil(_.lookupVariable(name).isDefined) match + case Nil => NotInScope + + case foundSubstack => + // Know the last will contain the lookup variable since that is the condition to terminate + // the takeUntil + val variable = foundSubstack.last.lookupVariable(name).get + lookupResultFromFound(foundSubstack, variable) + + def lookupType(simpleName: String): Option[String] = + lookupType(simpleName, includeWildcards = true) + + private def lookupType(simpleName: String, includeWildcards: Boolean): Option[String] = + 
scopeStack.iterator + .map(_.lookupType(simpleName, includeWildcards)) + .collectFirst { case Some(typeFullName) => typeFullName } + + def lookupVariableOrType(name: String): Option[String] = + lookupVariable(name).typeFullName.orElse(lookupType(name, includeWildcards = false)) + + private def lookupResultFromFound( + foundSubstack: List[JavaScopeElement], + variable: ScopeVariable + ): VariableLookupResult = + val typeDecls = foundSubstack.collect { case td: TypeDeclScope => td } + + val isCaptureChain = + typeDecls.size > 1 || (typeDecls.nonEmpty && !foundSubstack.last.isInstanceOf[ + TypeDeclScope + ]) + + if isCaptureChain then + CapturedVariable(typeDecls.map(_.typeDecl), variable) + else + SimpleVariable(variable) + + def enclosingTypeDecl: Option[NewTypeDecl] = + scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.typeDecl + } + + def enclosingTypeDeclFullName: Option[String] = enclosingTypeDecl.map(_.fullName) + + def enclosingMethodReturnType: Option[ExpectedType] = + scopeStack.collectFirst { case methodScope: MethodScope => + methodScope.returnType + } + + def addLocalDecl(decl: Ast): Unit = + scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registerTypeDecl(decl) + } + + // TODO: The below section of todos are all methods that have been added for simple compatibility with the old + // scope. The plan is to refactor the code to handle these directly in the AstCreator to make the code easier + // to reason about, so these should be removed when that happens. 
+ + // TODO: Refactor and remove this + def addMemberInitializers(initializers: Seq[Ast]): Unit = + scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.addMemberInitializers(initializers) + } + + // TODO: Refactor and remove this + def memberInitializers: List[Ast] = + scopeStack + .collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.getMemberInitializers() + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def localDeclsInScope: List[Ast] = + scopeStack + .collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registeredTypeDecls + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def lambdaMethodsInScope: List[Ast] = + scopeStack + .collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registeredLambdaMethods + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def addLambdaMethod(method: Ast): Unit = + scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registerLambdaMethod(method) + } // TODO: Refactor and remove this - def asNodeInfoOption: Option[NodeTypeInfo] = None - } - case object NotInScope extends VariableLookupResult - sealed trait FoundVariable(variable: ScopeVariable) extends VariableLookupResult { - override val typeFullName: Option[String] = Some(variable.typeFullName) - override val variableNode: Option[NewVariableNode] = Some(variable.node) - override def getVariable(): Option[ScopeVariable] = Some(variable) - - override def asNodeInfoOption: Option[NodeTypeInfo] = { - val nodeTypeInfo = variable match { - case ScopeMember(memberNode, isStatic) => - NodeTypeInfo(memberNode, memberNode.name, Some(memberNode.typeFullName), true, isStatic) - - case variable => NodeTypeInfo(variable.node, variable.name, Some(variable.typeFullName), false, false) - } - - Some(nodeTypeInfo) - } - } - final case class SimpleVariable(variable: ScopeVariable) extends FoundVariable(variable) - - final case class 
CapturedVariable(typeDeclChain: List[NewTypeDecl], variable: ScopeVariable) - extends FoundVariable(variable) -} + def capturedVariables: List[ScopeVariable] = + scopeStack + .flatMap(_.getVariables()) + .collect { + case local: ScopeLocal => local + + case parameter: ScopeParameter if parameter.name != NameConstants.This => parameter + } +end Scope + +object Scope: + type NewScopeNode = NewBlock | NewMethod | NewTypeDecl | NewNamespaceBlock + type NewVariableNode = NewLocal | NewMethodParameterIn | NewMember | NewImport + + sealed trait ScopeVariable: + def node: NewVariableNode + def typeFullName: String + def name: String + final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeMember(override val node: NewMember, isStatic: Boolean) + extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeStaticImport(override val node: NewImport) extends ScopeVariable: + val typeFullName: String = node.importedEntity.get + val name = node.importedAs.get + + sealed trait VariableLookupResult: + def typeFullName: Option[String] = None + def variableNode: Option[NewVariableNode] = None + + // TODO: Added for convenience, but when proper capture logic is implemented the found + // variable cases will have to be handled separately which would render this unnecessary. 
+ def getVariable(): Option[ScopeVariable] = None + + // TODO: Refactor and remove this + def asNodeInfoOption: Option[NodeTypeInfo] = None + case object NotInScope extends VariableLookupResult + sealed trait FoundVariable(variable: ScopeVariable) extends VariableLookupResult: + override val typeFullName: Option[String] = Some(variable.typeFullName) + override val variableNode: Option[NewVariableNode] = Some(variable.node) + override def getVariable(): Option[ScopeVariable] = Some(variable) + + override def asNodeInfoOption: Option[NodeTypeInfo] = + val nodeTypeInfo = variable match + case ScopeMember(memberNode, isStatic) => + NodeTypeInfo( + memberNode, + memberNode.name, + Some(memberNode.typeFullName), + true, + isStatic + ) + + case variable => NodeTypeInfo( + variable.node, + variable.name, + Some(variable.typeFullName), + false, + false + ) + + Some(nodeTypeInfo) + end asNodeInfoOption + end FoundVariable + final case class SimpleVariable(variable: ScopeVariable) extends FoundVariable(variable) + + final case class CapturedVariable(typeDeclChain: List[NewTypeDecl], variable: ScopeVariable) + extends FoundVariable(variable) +end Scope diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala index 60f97ca8..23a3341c 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala @@ -4,20 +4,17 @@ import io.appthreat.x2cpg.Ast import scala.collection.mutable -trait TypeDeclContainer { - private val typeDeclsToAdd = mutable.ListBuffer[Ast]() - private val lambdaMethods = mutable.ListBuffer[Ast]() +trait TypeDeclContainer: + private val typeDeclsToAdd = mutable.ListBuffer[Ast]() + private val lambdaMethods = mutable.ListBuffer[Ast]() - def 
registerTypeDecl(typeDecl: Ast) = { - typeDeclsToAdd.append(typeDecl) - } + def registerTypeDecl(typeDecl: Ast) = + typeDeclsToAdd.append(typeDecl) - def registeredTypeDecls: List[Ast] = typeDeclsToAdd.toList + def registeredTypeDecls: List[Ast] = typeDeclsToAdd.toList - // TODO: Refactor and remove this - def registerLambdaMethod(lambdaMethod: Ast) = { - lambdaMethods.append(lambdaMethod) - } + // TODO: Refactor and remove this + def registerLambdaMethod(lambdaMethod: Ast) = + lambdaMethods.append(lambdaMethod) - def registeredLambdaMethods: List[Ast] = lambdaMethods.toList -} + def registeredLambdaMethods: List[Ast] = lambdaMethods.toList diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala index d169b9cb..99a164d9 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala @@ -18,65 +18,62 @@ class EagerSourceTypeSolver( sourceParser: SourceParser, combinedTypeSolver: SimpleCombinedTypeSolver, symbolSolver: JavaSymbolSolver -) extends TypeSolver { +) extends TypeSolver: - private val logger = LoggerFactory.getLogger(this.getClass) - private var parent: TypeSolver = _ + private val logger = LoggerFactory.getLogger(this.getClass) + private var parent: TypeSolver = _ - private val foundTypes: Map[String, SymbolReference[ResolvedReferenceTypeDeclaration]] = { - filenames - .flatMap(sourceParser.parseTypesFile) - .flatMap { cu => - symbolSolver.inject(cu) - cu.findAll(classOf[TypeDeclaration[_]]) - .asScala - .map { typeDeclaration => - val name = typeDeclaration.getFullyQualifiedName.toScala match { - case Some(fullyQualifiedName) => fullyQualifiedName - case None => - val name = 
typeDeclaration.getNameAsString - logger.debug(s"Could not find fully qualified name for typeDecl $name") - name + private val foundTypes: Map[String, SymbolReference[ResolvedReferenceTypeDeclaration]] = + filenames + .flatMap(sourceParser.parseTypesFile) + .flatMap { cu => + symbolSolver.inject(cu) + cu.findAll(classOf[TypeDeclaration[_]]) + .asScala + .map { typeDeclaration => + val name = typeDeclaration.getFullyQualifiedName.toScala match + case Some(fullyQualifiedName) => fullyQualifiedName + case None => + val name = typeDeclaration.getNameAsString + logger.debug( + s"Could not find fully qualified name for typeDecl $name" + ) + name + TypeSizeReducer.simplifyType(typeDeclaration) + val resolvedSymbol = Try( + SymbolReference.solved( + JavaParserFacade.get(combinedTypeSolver).getTypeDeclaration( + typeDeclaration + ) + ): SymbolReference[ResolvedReferenceTypeDeclaration] + ).getOrElse(SymbolReference.unsolved()) + name -> resolvedSymbol + } + .toList } - TypeSizeReducer.simplifyType(typeDeclaration) - val resolvedSymbol = Try( - SymbolReference.solved( - JavaParserFacade.get(combinedTypeSolver).getTypeDeclaration(typeDeclaration) - ): SymbolReference[ResolvedReferenceTypeDeclaration] - ).getOrElse(SymbolReference.unsolved()) - name -> resolvedSymbol - } - .toList - } - .toMap - } + .toMap - override def getParent: TypeSolver = parent + override def getParent: TypeSolver = parent - override def setParent(parent: TypeSolver): Unit = { - if (parent == null) { - logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") - } else if (this.parent != null) { - logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") - } else if (parent == this) { - logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") - } else { - this.parent = parent - } - } + override def setParent(parent: TypeSolver): Unit = + if parent == null then + logger.debug(s"Cannot set parent of type solver to null. 
setParent will be ignored.") + else if this.parent != null then + logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") + else if parent == this then + logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") + else + this.parent = parent - override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = { - foundTypes.getOrElse(name, SymbolReference.unsolved()) - } -} + override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = + foundTypes.getOrElse(name, SymbolReference.unsolved()) +end EagerSourceTypeSolver -object EagerSourceTypeSolver { - def apply( - filenames: Array[String], - sourceParser: SourceParser, - combinedTypeSolver: SimpleCombinedTypeSolver, - symbolSolver: JavaSymbolSolver - ): EagerSourceTypeSolver = { - new EagerSourceTypeSolver(filenames, sourceParser, combinedTypeSolver, symbolSolver) - } -} +object EagerSourceTypeSolver: + def apply( + filenames: Array[String], + sourceParser: SourceParser, + combinedTypeSolver: SimpleCombinedTypeSolver, + symbolSolver: JavaSymbolSolver + ): EagerSourceTypeSolver = + new EagerSourceTypeSolver(filenames, sourceParser, combinedTypeSolver, symbolSolver) diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala index f6400d21..fc790c06 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala @@ -1,51 +1,45 @@ package io.appthreat.javasrc2cpg.typesolvers import better.files.File -import JmodClassPath._ +import JmodClassPath.* import javassist.ClassPath -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Try import 
java.io.InputStream import java.net.URL import java.util.jar.{JarEntry, JarFile} -class JmodClassPath(jmodPath: String) extends ClassPath { - private val jarfile = new JarFile(jmodPath) - private val jarfileURL = File(jmodPath).url.toString - private val entries = getEntriesMap(jarfile) - - private def entryToClassName(entry: JarEntry): String = { - entry.getName.stripPrefix(JmodClassesPrefix).stripSuffix(".class").replace('/', '.') - } - - private def getEntriesMap(jarfile: JarFile): Map[String, JarEntry] = { - jarfile - .entries() - .asScala - .filter(_.getName.startsWith(JmodClassesPrefix)) - .filter(_.getName.endsWith(".class")) - .map { entry => entryToClassName(entry) -> entry } - .toMap - } - - override def find(classname: String): URL = { - val jarname = classname.replace('.', '/') + ".class" - - if (entries.contains(classname)) { - Try(new URL(s"jmod:${jarfileURL}!/${jarname}")).getOrElse(null) - } else { null } - } - - override def openClassfile(classname: String): InputStream = { - entries.get(classname) match { - case None => null - - case Some(entry) => jarfile.getInputStream(entry) - } - } -} - -object JmodClassPath { - val JmodClassesPrefix: String = "classes/" -} +class JmodClassPath(jmodPath: String) extends ClassPath: + private val jarfile = new JarFile(jmodPath) + private val jarfileURL = File(jmodPath).url.toString + private val entries = getEntriesMap(jarfile) + + private def entryToClassName(entry: JarEntry): String = + entry.getName.stripPrefix(JmodClassesPrefix).stripSuffix(".class").replace('/', '.') + + private def getEntriesMap(jarfile: JarFile): Map[String, JarEntry] = + jarfile + .entries() + .asScala + .filter(_.getName.startsWith(JmodClassesPrefix)) + .filter(_.getName.endsWith(".class")) + .map { entry => entryToClassName(entry) -> entry } + .toMap + + override def find(classname: String): URL = + val jarname = classname.replace('.', '/') + ".class" + + if entries.contains(classname) then + Try(new 
URL(s"jmod:${jarfileURL}!/${jarname}")).getOrElse(null) + else null + + override def openClassfile(classname: String): InputStream = + entries.get(classname) match + case None => null + + case Some(entry) => jarfile.getInputStream(entry) +end JmodClassPath + +object JmodClassPath: + val JmodClassesPrefix: String = "classes/" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala index 038bccae..e44ed1d4 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala @@ -3,11 +3,12 @@ package io.appthreat.javasrc2cpg.typesolvers import javassist.{ClassPool, CtClass} import scala.annotation.nowarn -/** The NonCachingClassPool is meant to be used in conjuction with a type solver that already caches resolved types. - * This means that caching the intermediate ctClasses is just extra memory use. +/** The NonCachingClassPool is meant to be used in conjuction with a type solver that already caches + * resolved types. This means that caching the intermediate ctClasses is just extra memory use. * - * NonCachingClassPool extends ClassPool(useDefaultPath = false) to avoid adding the system path to the search list. + * NonCachingClassPool extends ClassPool(useDefaultPath = false) to avoid adding the system path to + * the search list. 
*/ -class NonCachingClassPool extends ClassPool(false) { - @nowarn override def cacheCtClass(className: String, ctClass: CtClass, dynamic: Boolean): Unit = () -} +class NonCachingClassPool extends ClassPool(false): + @nowarn override def cacheCtClass(className: String, ctClass: CtClass, dynamic: Boolean): Unit = + () diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala index 28858f62..b1d881e5 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala @@ -11,88 +11,84 @@ import org.slf4j.LoggerFactory import scala.collection.mutable import scala.jdk.OptionConverters.RichOptional -class SimpleCombinedTypeSolver extends TypeSolver { +class SimpleCombinedTypeSolver extends TypeSolver: - private val logger = LoggerFactory.getLogger(this.getClass) - private var parent: TypeSolver = _ - // Ideally all types would be cached in the SimpleCombinedTypeSolver to avoid unnecessary unresolved types - // from being cached. The EagerSourceTypeSolver preloads all types, however, so separating caching and - // non-caching solvers avoids caching types twice. - private val cachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() - private val nonCachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() + private val logger = LoggerFactory.getLogger(this.getClass) + private var parent: TypeSolver = _ + // Ideally all types would be cached in the SimpleCombinedTypeSolver to avoid unnecessary unresolved types + // from being cached. The EagerSourceTypeSolver preloads all types, however, so separating caching and + // non-caching solvers avoids caching types twice. 
+ private val cachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() + private val nonCachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() - private val typeCache = new GuavaCache( - CacheBuilder.newBuilder().build[String, SymbolReference[ResolvedReferenceTypeDeclaration]]() - ) + private val typeCache = new GuavaCache( + CacheBuilder.newBuilder().build[String, SymbolReference[ResolvedReferenceTypeDeclaration]]() + ) - def addCachingTypeSolver(typeSolver: TypeSolver): Unit = { - cachingTypeSolvers.append(typeSolver) - typeSolver.setParent(this) - } + def addCachingTypeSolver(typeSolver: TypeSolver): Unit = + cachingTypeSolvers.append(typeSolver) + typeSolver.setParent(this) - def addNonCachingTypeSolver(typeSolver: TypeSolver): Unit = { - nonCachingTypeSolvers.prepend(typeSolver) - typeSolver.setParent(this) - } + def addNonCachingTypeSolver(typeSolver: TypeSolver): Unit = + nonCachingTypeSolvers.prepend(typeSolver) + typeSolver.setParent(this) - override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = { - typeCache.get(name).toScala match { - case Some(result) => result + override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = + typeCache.get(name).toScala match + case Some(result) => result - case None => - findSolvedTypeWithSolvers(cachingTypeSolvers, name) - .getOrElse { - val result = findSolvedTypeWithSolvers(nonCachingTypeSolvers, name).getOrElse(SymbolReference.unsolved()) - typeCache.put(name, result) - result - } - } - } + case None => + findSolvedTypeWithSolvers(cachingTypeSolvers, name) + .getOrElse { + val result = findSolvedTypeWithSolvers( + nonCachingTypeSolvers, + name + ).getOrElse(SymbolReference.unsolved()) + typeCache.put(name, result) + result + } - private def findSolvedTypeWithSolvers( - typeSolvers: mutable.ArrayBuffer[TypeSolver], - className: String - ): Option[SymbolReference[ResolvedReferenceTypeDeclaration]] = { - 
typeSolvers.iterator - .map { typeSolver => - try { - val result = typeSolver.tryToSolveType(className): SymbolReference[ResolvedReferenceTypeDeclaration] - Option.when(result.isSolved())(result) - } catch { - case _: UnsolvedSymbolException => None - case _: StackOverflowError => None - case _: IllegalArgumentException => - // RecordDeclarations aren't handled by JavaParser yet - None - case unhandled: Throwable => - logger.debug("Caught unhandled exception", unhandled) - None - } - } - .collectFirst { case Some(symbolReference) => - symbolReference - } - } + private def findSolvedTypeWithSolvers( + typeSolvers: mutable.ArrayBuffer[TypeSolver], + className: String + ): Option[SymbolReference[ResolvedReferenceTypeDeclaration]] = + typeSolvers.iterator + .map { typeSolver => + try + val result = typeSolver.tryToSolveType(className): SymbolReference[ + ResolvedReferenceTypeDeclaration + ] + Option.when(result.isSolved())(result) + catch + case _: UnsolvedSymbolException => None + case _: StackOverflowError => None + case _: IllegalArgumentException => + // RecordDeclarations aren't handled by JavaParser yet + None + case unhandled: Throwable => + logger.debug("Caught unhandled exception", unhandled) + None + } + .collectFirst { case Some(symbolReference) => + symbolReference + } - override def solveType(name: String): ResolvedReferenceTypeDeclaration = { - val result = tryToSolveType(name) - if (result.isSolved) - result.getCorrespondingDeclaration - else - throw new UnsolvedSymbolException(name) - } + override def solveType(name: String): ResolvedReferenceTypeDeclaration = + val result = tryToSolveType(name) + if result.isSolved then + result.getCorrespondingDeclaration + else + throw new UnsolvedSymbolException(name) - override def getParent: TypeSolver = parent + override def getParent: TypeSolver = parent - override def setParent(parent: TypeSolver): Unit = { - if (parent == null) { - logger.debug(s"Cannot set parent of type solver to null. 
setParent will be ignored.") - } else if (this.parent != null) { - logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") - } else if (parent == this) { - logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") - } else { - this.parent = parent - } - } -} + override def setParent(parent: TypeSolver): Unit = + if parent == null then + logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") + else if this.parent != null then + logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") + else if parent == this then + logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") + else + this.parent = parent +end SimpleCombinedTypeSolver diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala index 5d96e846..91ce7b4e 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala @@ -3,11 +3,11 @@ package io.appthreat.javasrc2cpg.typesolvers import com.github.javaparser.ast.`type`.{PrimitiveType, Type} import com.github.javaparser.resolution.SymbolResolver import com.github.javaparser.resolution.declarations.{ - ResolvedDeclaration, - ResolvedTypeDeclaration, - ResolvedTypeParameterDeclaration + ResolvedDeclaration, + ResolvedTypeDeclaration, + ResolvedTypeParameterDeclaration } -import com.github.javaparser.resolution.types._ +import com.github.javaparser.resolution.types.* import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap import com.github.javaparser.resolution.logic.InferenceVariableType import com.github.javaparser.resolution.model.typesystem.{LazyType, 
NullType} @@ -15,274 +15,258 @@ import TypeInfoCalculator.{TypeConstants, TypeNameConstants} import io.appthreat.x2cpg.datastructures.Global import org.slf4j.LoggerFactory -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional import scala.util.Try -class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver) { - private val logger = LoggerFactory.getLogger(this.getClass) - private val emptyTypeParamValues = ResolvedTypeParametersMap.empty() +class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver): + private val logger = LoggerFactory.getLogger(this.getClass) + private val emptyTypeParamValues = ResolvedTypeParametersMap.empty() - def name(typ: ResolvedType): Option[String] = { - nameOrFullName(typ, emptyTypeParamValues, fullyQualified = false) - } + def name(typ: ResolvedType): Option[String] = + nameOrFullName(typ, emptyTypeParamValues, fullyQualified = false) - def name(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = { - nameOrFullName(typ, typeParamValues, fullyQualified = false) - } + def name(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = + nameOrFullName(typ, typeParamValues, fullyQualified = false) - def fullName(typ: ResolvedType): Option[String] = { - nameOrFullName(typ, emptyTypeParamValues, fullyQualified = true).map(registerType) - } + def fullName(typ: ResolvedType): Option[String] = + nameOrFullName(typ, emptyTypeParamValues, fullyQualified = true).map(registerType) - def fullName(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = { - nameOrFullName(typ, typeParamValues, fullyQualified = true).map(registerType) - } + def fullName(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = + nameOrFullName(typ, typeParamValues, fullyQualified = true).map(registerType) - private def typesSubstituted( - trySubstitutedType: 
Try[ResolvedType], - typeParamDecl: ResolvedTypeParameterDeclaration - ): Boolean = { - trySubstitutedType - .map { substitutedType => - // substitutedType.isTypeVariable can crash with an UnsolvedSymbolException if it is an instance of LazyType, - // in which case the type hasn't been successfully substituted. - val substitutionOccurred = - !(substitutedType.isTypeVariable && substitutedType.asTypeParameter() == typeParamDecl) - // There's a potential infinite loop that can occur when a type variable is substituted with a wildcard type - // bounded by that type variable. - val isSimilarWildcardSubstition = substitutedType match { - case wc: ResolvedWildcard => Try(wc.getBoundedType.asTypeParameter()).toOption.contains(typeParamDecl) - case _ => false - } + private def typesSubstituted( + trySubstitutedType: Try[ResolvedType], + typeParamDecl: ResolvedTypeParameterDeclaration + ): Boolean = + trySubstitutedType + .map { substitutedType => + // substitutedType.isTypeVariable can crash with an UnsolvedSymbolException if it is an instance of LazyType, + // in which case the type hasn't been successfully substituted. + val substitutionOccurred = + !(substitutedType.isTypeVariable && substitutedType.asTypeParameter() == typeParamDecl) + // There's a potential infinite loop that can occur when a type variable is substituted with a wildcard type + // bounded by that type variable. 
+ val isSimilarWildcardSubstition = substitutedType match + case wc: ResolvedWildcard => + Try(wc.getBoundedType.asTypeParameter()).toOption.contains(typeParamDecl) + case _ => false - substitutionOccurred && !isSimilarWildcardSubstition - } - .getOrElse(false) - } + substitutionOccurred && !isSimilarWildcardSubstition + } + .getOrElse(false) - private def nameOrFullName( - typ: ResolvedType, - typeParamValues: ResolvedTypeParametersMap, - fullyQualified: Boolean - ): Option[String] = { - typ match { - case refType: ResolvedReferenceType => - nameOrFullName(refType.getTypeDeclaration.get, fullyQualified) - case lazyType: LazyType => - lazyType match { - case _ if lazyType.isReferenceType => - nameOrFullName(lazyType.asReferenceType(), typeParamValues, fullyQualified) - case _ if lazyType.isTypeVariable => - nameOrFullName(lazyType.asTypeVariable(), typeParamValues, fullyQualified) - case _ if lazyType.isArray => - nameOrFullName(lazyType.asArrayType(), typeParamValues, fullyQualified) - case _ if lazyType.isPrimitive => - nameOrFullName(lazyType.asPrimitive(), typeParamValues, fullyQualified) - case _ if lazyType.isWildcard => - nameOrFullName(lazyType.asWildcard(), typeParamValues, fullyQualified) - } - case tpe @ (_: ResolvedVoidType | _: ResolvedPrimitiveType) => - Some(tpe.describe()) - case arrayType: ResolvedArrayType => - nameOrFullName(arrayType.getComponentType, typeParamValues, fullyQualified).map(_ + "[]") - case nullType: NullType => - Some(nullType.describe()) - case typeVariable: ResolvedTypeVariable => - val typeParamDecl = typeVariable.asTypeParameter() - val substitutedType = Try(typeParamValues.getValue(typeParamDecl)) + private def nameOrFullName( + typ: ResolvedType, + typeParamValues: ResolvedTypeParametersMap, + fullyQualified: Boolean + ): Option[String] = + typ match + case refType: ResolvedReferenceType => + nameOrFullName(refType.getTypeDeclaration.get, fullyQualified) + case lazyType: LazyType => + lazyType match + case _ if 
lazyType.isReferenceType => + nameOrFullName(lazyType.asReferenceType(), typeParamValues, fullyQualified) + case _ if lazyType.isTypeVariable => + nameOrFullName(lazyType.asTypeVariable(), typeParamValues, fullyQualified) + case _ if lazyType.isArray => + nameOrFullName(lazyType.asArrayType(), typeParamValues, fullyQualified) + case _ if lazyType.isPrimitive => + nameOrFullName(lazyType.asPrimitive(), typeParamValues, fullyQualified) + case _ if lazyType.isWildcard => + nameOrFullName(lazyType.asWildcard(), typeParamValues, fullyQualified) + case tpe @ (_: ResolvedVoidType | _: ResolvedPrimitiveType) => + Some(tpe.describe()) + case arrayType: ResolvedArrayType => + nameOrFullName(arrayType.getComponentType, typeParamValues, fullyQualified).map( + _ + "[]" + ) + case nullType: NullType => + Some(nullType.describe()) + case typeVariable: ResolvedTypeVariable => + val typeParamDecl = typeVariable.asTypeParameter() + val substitutedType = Try(typeParamValues.getValue(typeParamDecl)) - if (typesSubstituted(substitutedType, typeParamDecl)) { - nameOrFullName(substitutedType.get, typeParamValues, fullyQualified) - } else { - val extendsBoundOption = Try(typeParamDecl.getBounds.asScala.find(_.isExtends)).toOption.flatten - extendsBoundOption - .flatMap(bound => nameOrFullName(bound.getType, typeParamValues, fullyQualified)) - .orElse(objectType(fullyQualified)) - } - case lambdaConstraintType: ResolvedLambdaConstraintType => - nameOrFullName(lambdaConstraintType.getBound, typeParamValues, fullyQualified) - case wildcardType: ResolvedWildcard => - if (wildcardType.isBounded) { - nameOrFullName(wildcardType.getBoundedType, typeParamValues, fullyQualified) - } else { - objectType(fullyQualified) - } - case unionType: ResolvedUnionType => - Try(unionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten - .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) - .orElse(objectType(fullyQualified)) - case intersectionType: ResolvedIntersectionType => - 
Try(intersectionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten - .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) - .orElse(objectType(fullyQualified)) - case _: InferenceVariableType => - // From the JavaParser docs, the InferenceVariableType is: An element using during type inference. - // At this point JavaParser has failed to resolve the type. - None - } - } + if typesSubstituted(substitutedType, typeParamDecl) then + nameOrFullName(substitutedType.get, typeParamValues, fullyQualified) + else + val extendsBoundOption = + Try(typeParamDecl.getBounds.asScala.find(_.isExtends)).toOption.flatten + extendsBoundOption + .flatMap(bound => + nameOrFullName(bound.getType, typeParamValues, fullyQualified) + ) + .orElse(objectType(fullyQualified)) + case lambdaConstraintType: ResolvedLambdaConstraintType => + nameOrFullName(lambdaConstraintType.getBound, typeParamValues, fullyQualified) + case wildcardType: ResolvedWildcard => + if wildcardType.isBounded then + nameOrFullName(wildcardType.getBoundedType, typeParamValues, fullyQualified) + else + objectType(fullyQualified) + case unionType: ResolvedUnionType => + Try(unionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten + .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) + .orElse(objectType(fullyQualified)) + case intersectionType: ResolvedIntersectionType => + Try(intersectionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten + .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) + .orElse(objectType(fullyQualified)) + case _: InferenceVariableType => + // From the JavaParser docs, the InferenceVariableType is: An element using during type inference. + // At this point JavaParser has failed to resolve the type. 
+ None - private def objectType(fullyQualified: Boolean): Option[String] = { - // Return an option type for - if (fullyQualified) { - Some(TypeConstants.Object) - } else { - Some(TypeNameConstants.Object) - } - } + private def objectType(fullyQualified: Boolean): Option[String] = + // Return an option type for + if fullyQualified then + Some(TypeConstants.Object) + else + Some(TypeNameConstants.Object) - def name(typ: Type): Option[String] = { - nameOrFullName(typ, fullyQualified = false) - } + def name(typ: Type): Option[String] = + nameOrFullName(typ, fullyQualified = false) - def fullName(typ: Type): Option[String] = { - nameOrFullName(typ, fullyQualified = true).map(registerType) - } + def fullName(typ: Type): Option[String] = + nameOrFullName(typ, fullyQualified = true).map(registerType) - private def nameOrFullName(typ: Type, fullyQualified: Boolean): Option[String] = { - typ match { - case primitiveType: PrimitiveType => - Some(primitiveType.toString) - case _ => - // We are using symbolResolver.toResolvedType() instead of typ.resolve() because - // the resolve() is just a wrapper for a call to symbolResolver.toResolvedType() - // with a specific class given as argument to which the result is casted to. - // It appears to be that ClassOrInterfaceType.resolve() is using a too restrictive - // bound (ResolvedReferenceType.class) which invalidates an otherwise successful - // resolve. Since we anyway dont care about the type cast, we directly access the - // symbolResolver and specifiy the most generic type ResolvedType. 
- Try(symbolResolver.toResolvedType(typ, classOf[ResolvedType])).toOption - .flatMap(resolvedType => nameOrFullName(resolvedType, emptyTypeParamValues, fullyQualified)) - } - } + private def nameOrFullName(typ: Type, fullyQualified: Boolean): Option[String] = + typ match + case primitiveType: PrimitiveType => + Some(primitiveType.toString) + case _ => + // We are using symbolResolver.toResolvedType() instead of typ.resolve() because + // the resolve() is just a wrapper for a call to symbolResolver.toResolvedType() + // with a specific class given as argument to which the result is casted to. + // It appears to be that ClassOrInterfaceType.resolve() is using a too restrictive + // bound (ResolvedReferenceType.class) which invalidates an otherwise successful + // resolve. Since we anyway dont care about the type cast, we directly access the + // symbolResolver and specifiy the most generic type ResolvedType. + Try(symbolResolver.toResolvedType(typ, classOf[ResolvedType])).toOption + .flatMap(resolvedType => + nameOrFullName(resolvedType, emptyTypeParamValues, fullyQualified) + ) - def name(decl: ResolvedDeclaration): Option[String] = { - nameOrFullName(decl, fullyQualified = false) - } + def name(decl: ResolvedDeclaration): Option[String] = + nameOrFullName(decl, fullyQualified = false) - def fullName(decl: ResolvedDeclaration): Option[String] = { - nameOrFullName(decl, fullyQualified = true).map(registerType) - } + def fullName(decl: ResolvedDeclaration): Option[String] = + nameOrFullName(decl, fullyQualified = true).map(registerType) - private def nameOrFullName(decl: ResolvedDeclaration, fullyQualified: Boolean): Option[String] = { - decl match { - case typeDecl: ResolvedTypeDeclaration => - nameOrFullName(typeDecl, fullyQualified) - } - } + private def nameOrFullName(decl: ResolvedDeclaration, fullyQualified: Boolean): Option[String] = + decl match + case typeDecl: ResolvedTypeDeclaration => + nameOrFullName(typeDecl, fullyQualified) - private def 
nameOrFullName(typeDecl: ResolvedTypeDeclaration, fullyQualified: Boolean): Option[String] = { - typeDecl match { - case typeParamDecl: ResolvedTypeParameterDeclaration => - if (fullyQualified) { - val containFullName = - nameOrFullName(typeParamDecl.getContainer.asInstanceOf[ResolvedDeclaration], fullyQualified = true) - containFullName.map(_ + "." + typeParamDecl.getName) - } else { - Some(typeParamDecl.getName) - } - case _ => - val typeName = Option(typeDecl.getName).getOrElse(throw new RuntimeException("TODO Investigate")) + private def nameOrFullName( + typeDecl: ResolvedTypeDeclaration, + fullyQualified: Boolean + ): Option[String] = + typeDecl match + case typeParamDecl: ResolvedTypeParameterDeclaration => + if fullyQualified then + val containFullName = + nameOrFullName( + typeParamDecl.getContainer.asInstanceOf[ResolvedDeclaration], + fullyQualified = true + ) + containFullName.map(_ + "." + typeParamDecl.getName) + else + Some(typeParamDecl.getName) + case _ => + val typeName = Option(typeDecl.getName).getOrElse( + throw new RuntimeException("TODO Investigate") + ) - // TODO Sadly we need to use a try here in order to catch the exception emitted by - // the javaparser library instead of just returning an empty option. - // In almost all cases we get here the exception is thrown. Check impact on performance - // and hopefully find a better solution if necessary. - val isInnerTypeDecl = Try(typeDecl.containerType().isPresent).getOrElse(false) - if (isInnerTypeDecl) { - nameOrFullName(typeDecl.containerType().get, fullyQualified).map(_ + "$" + typeName) - } else { - if (fullyQualified) { - val packageName = typeDecl.getPackageName + // TODO Sadly we need to use a try here in order to catch the exception emitted by + // the javaparser library instead of just returning an empty option. + // In almost all cases we get here the exception is thrown. Check impact on performance + // and hopefully find a better solution if necessary. 
+ val isInnerTypeDecl = Try(typeDecl.containerType().isPresent).getOrElse(false) + if isInnerTypeDecl then + nameOrFullName(typeDecl.containerType().get, fullyQualified).map( + _ + "$" + typeName + ) + else if fullyQualified then + val packageName = typeDecl.getPackageName - if (packageName == null || packageName == "") { - Some(typeName) - } else { - Some(packageName + "." + typeName) - } - } else { - Some(typeName) - } - } - } - } + if packageName == null || packageName == "" then + Some(typeName) + else + Some(packageName + "." + typeName) + else + Some(typeName) - /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates TYPE nodes for each - * key in the map. Skip the `ANY` type, since this is created by default. TODO: I want the type registration not in - * here but for now it is the easiest. - */ - def registerType(typeName: String): String = { - if (typeName != "ANY") { - global.usedTypes.putIfAbsent(typeName, true) - } - typeName - } -} + /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates + * TYPE nodes for each key in the map. Skip the `ANY` type, since this is created by default. + * TODO: I want the type registration not in here but for now it is the easiest. 
+ */ + def registerType(typeName: String): String = + if typeName != "ANY" then + global.usedTypes.putIfAbsent(typeName, true) + typeName +end TypeInfoCalculator -object TypeInfoCalculator { - def isAutocastType(typeName: String): Boolean = { - NumericTypes.contains(typeName) - } +object TypeInfoCalculator: + def isAutocastType(typeName: String): Boolean = + NumericTypes.contains(typeName) - object TypeConstants { - val Byte: String = "byte" - val Short: String = "short" - val Int: String = "int" - val Long: String = "long" - val Float: String = "float" - val Double: String = "double" - val Char: String = "char" - val Boolean: String = "boolean" - val Object: String = "java.lang.Object" - val Class: String = "java.lang.Class" - val Iterator: String = "java.util.Iterator" - val Void: String = "void" - val Any: String = "ANY" - } + object TypeConstants: + val Byte: String = "byte" + val Short: String = "short" + val Int: String = "int" + val Long: String = "long" + val Float: String = "float" + val Double: String = "double" + val Char: String = "char" + val Boolean: String = "boolean" + val Object: String = "java.lang.Object" + val Class: String = "java.lang.Class" + val Iterator: String = "java.util.Iterator" + val Void: String = "void" + val Any: String = "ANY" - object TypeNameConstants { - val Object: String = "Object" - } + object TypeNameConstants: + val Object: String = "Object" - // The method signatures for all methods implemented by java.lang.Object, as returned by JavaParser. This is used - // to filter out Object methods when determining which functional interface method a lambda implements. See - // https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for more details. 
- val ObjectMethodSignatures: Set[String] = Set( - "wait(long, int)", - "equals(java.lang.Object)", - "clone()", - "toString()", - "wait()", - "hashCode()", - "getClass()", - "notify()", - "finalize()", - "wait(long)", - "notifyAll()", - "registerNatives()" - ) + // The method signatures for all methods implemented by java.lang.Object, as returned by JavaParser. This is used + // to filter out Object methods when determining which functional interface method a lambda implements. See + // https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for more details. + val ObjectMethodSignatures: Set[String] = Set( + "wait(long, int)", + "equals(java.lang.Object)", + "clone()", + "toString()", + "wait()", + "hashCode()", + "getClass()", + "notify()", + "finalize()", + "wait(long)", + "notifyAll()", + "registerNatives()" + ) - val NumericTypes: Set[String] = Set( - "byte", - "short", - "int", - "long", - "float", - "double", - "char", - "boolean", - "java.lang.Byte", - "java.lang.Short", - "java.lang.Integer", - "java.lang.Long", - "java.lang.Float", - "java.lang.Double", - "java.lang.Character", - "java.lang.Boolean" - ) + val NumericTypes: Set[String] = Set( + "byte", + "short", + "int", + "long", + "float", + "double", + "char", + "boolean", + "java.lang.Byte", + "java.lang.Short", + "java.lang.Integer", + "java.lang.Long", + "java.lang.Float", + "java.lang.Double", + "java.lang.Character", + "java.lang.Boolean" + ) - def apply(global: Global, symbolResolver: SymbolResolver): TypeInfoCalculator = { - new TypeInfoCalculator(global, symbolResolver) - } -} + def apply(global: Global, symbolResolver: SymbolResolver): TypeInfoCalculator = + new TypeInfoCalculator(global, symbolResolver) +end TypeInfoCalculator diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala index 
d18010a8..8d1c0bc7 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala @@ -3,16 +3,14 @@ package io.appthreat.javasrc2cpg.typesolvers import com.github.javaparser.ast.body.TypeDeclaration import com.github.javaparser.ast.stmt.BlockStmt -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -object TypeSizeReducer { - def simplifyType(typeDeclaration: TypeDeclaration[_]): Unit = { - typeDeclaration - .getMethods() - .asScala - .filter(method => method.getBody().isPresent()) - .foreach { method => - method.setBody(new BlockStmt()) - } - } -} +object TypeSizeReducer: + def simplifyType(typeDeclaration: TypeDeclaration[?]): Unit = + typeDeclaration + .getMethods() + .asScala + .filter(method => method.getBody().isPresent()) + .foreach { method => + method.setBody(new BlockStmt()) + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala index 7b6ac671..f56a611a 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala @@ -16,195 +16,180 @@ import scala.collection.mutable import scala.jdk.CollectionConverters.* import scala.util.{Failure, Success, Try, Using} -class JdkJarTypeSolver extends TypeSolver { - - private val logger = LoggerFactory.getLogger(this.getClass) - - private var parent: Option[TypeSolver] = None - private val classPool = new NonCachingClassPool() - - val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty - - // Populating this causes memory leaks - val 
packagesJarMappings: mutable.Map[String, mutable.Set[String]] = mutable.Map.empty - - private type RefType = ResolvedReferenceTypeDeclaration - - override def getParent(): TypeSolver = parent.get - - override def setParent(parent: TypeSolver): Unit = { - this.parent match { - case None => - this.parent = Some(parent) - - case Some(_) => - throw new RuntimeException("JdkJarTypeSolver parent may only be set once") - } - } - - override def tryToSolveType(javaParserName: String): SymbolReference[ResolvedReferenceTypeDeclaration] = { - val packagePrefix = packagePrefixForJavaParserName(javaParserName) - if (knownPackagePrefixes.contains(packagePrefix)) { - lookupType(javaParserName) - } else { - SymbolReference.unsolved() - } - } - - private def lookupType(javaParserName: String): SymbolReference[ResolvedReferenceTypeDeclaration] = { - val name = convertJavaParserNameToStandard(javaParserName) - Try(classPool.get(name)) match { - case Success(ctClass) => - val refType = ctClassToRefType(ctClass) - refTypeToSymbolReference(refType) - - case Failure(e) => - SymbolReference.unsolved() - } - } - - override def solveType(name: String): ResolvedReferenceTypeDeclaration = { - tryToSolveType(name) match { - case symbolReference if symbolReference.isSolved => - symbolReference.getCorrespondingDeclaration - - case _ => throw new UnsolvedSymbolException(name) - } - } - - private def ctClassToRefType(ctClass: CtClass): RefType = { - JavassistFactory.toTypeDeclaration(ctClass, getRoot) - } - - private def refTypeToSymbolReference(refType: RefType): SymbolReference[RefType] = { - SymbolReference.solved[RefType, RefType](refType) - } - - private def addPathToClassPool(archivePath: String): Try[ClassPath] = { - if (archivePath.isJarPath) { - Try(classPool.appendClassPath(archivePath)) - } else if (archivePath.isJmodPath) { - val classPath = new JmodClassPath(archivePath) - Try(classPool.appendClassPath(classPath)) - } else { - Failure(new IllegalArgumentException("$archivePath is not a 
path to a jar/jmod")) - } - } - - def withJars(archivePaths: Seq[String]): JdkJarTypeSolver = { - addArchives(archivePaths) - this - } - - def addArchives(archivePaths: Seq[String]): Unit = { - archivePaths.foreach { archivePath => - addPathToClassPool(archivePath) match { - case Success(_) => registerPackagesForJar(archivePath) - - case Failure(e) => - logger.debug(s"Could not load jar at path $archivePath", e.getMessage) - } - } - } - - private def registerPackagesForJar(archivePath: String): Unit = { - val entryNameConverter = if (archivePath.isJarPath) packagePrefixForJarEntry else packagePrefixForJmodEntry - try { - Using(new JarFile(archivePath)) { jarFile => - def jarPackages = jarFile - .entries() - .asIterator() - .asScala - .filter(entry => - !entry.isDirectory && !entry.getName - .startsWith("module-info") && (entry.getName.endsWith(ClassExtension) || entry.getName - .endsWith(JavaExtension) || entry.getName.endsWith(KtExtension)) - ) - knownPackagePrefixes ++= jarPackages.map(entry => entryNameConverter(entry.getName)) - } - } catch { - case ioException: IOException => - logger.debug(s"Could register classes for archive at $archivePath", ioException.getMessage) - } - } -} - -object JdkJarTypeSolver { - val ClassExtension: String = ".class" - val JavaExtension: String = ".java" - val KtExtension: String = ".kt" - val JmodClassPrefix: String = "classes/" - val JarExtension: String = ".jar" - val JmodExtension: String = ".jmod" - - extension (path: String) { - def isJarPath: Boolean = path.endsWith(JarExtension) - def isJmodPath: Boolean = path.endsWith(JmodExtension) - } - - def fromJdkPath(jdkPath: String): JdkJarTypeSolver = { - val jarPaths = SourceFiles.determine(jdkPath, Set(JarExtension, JmodExtension)) - if (jarPaths.isEmpty) { - throw new IllegalArgumentException(s"No .jar or .jmod files found at JDK path ${jdkPath}") - } - new JdkJarTypeSolver().withJars(jarPaths) - } - - /** Convert JavaParser class name foo.bar.qux.Baz to package prefix foo.bar 
Only use first 2 parts since this is - * sufficient to deterimine whether a class has been registered in most cases and, if not, the failure is just a slow - * lookup. - */ - def packagePrefixForJavaParserName(className: String): String = { - className.split("\\.").take(2).mkString(".") - } - - /** Convert Jar entry name foo/bar/qux/Baz.class to package prefix foo.bar Only use first 2 parts since this is - * sufficient to deterimine whether a class has been registered in most cases and, if not, the failure is just a slow - * lookup. - */ - def packagePrefixForJarEntry(entryName: String): String = { - entryName.split("/").take(2).mkString(".") - } - - /** Convert jmod entry name classes/foo/bar/qux/Baz.class to package prefix foo.bar Only use first 2 parts since this - * is sufficient to deterimine whether a class has been registered in most cases and, if not, the failure is just a - * slow lookup. - */ - def packagePrefixForJmodEntry(entryName: String): String = { - packagePrefixForJarEntry(entryName.stripPrefix(JmodClassPrefix)) - } - - /** A name is assumed to contain at least one subclass (e.g. ...Foo$Bar) if the last name part starts with a digit, or - * if the last 2 name parts start with capital letters. This heuristic is based on the class name format in the JDK - * jars, where names with subclasses have one of the forms: - * - java.lang.ClassLoader$2 - * - java.lang.ClassLoader$NativeLibrary - * - java.lang.ClassLoader$NativeLibrary$Unloader - */ - private def namePartsContainSubclass(nameParts: Array[String]): Boolean = { - nameParts.takeRight(2) match { - case Array() => false - - case Array(singlePart) => false - - case Array(secondLast, last) => - last.head.isDigit || (secondLast.head.isUpper && last.head.isUpper) - } - } - - /** JavaParser replaces the `$` in nested class names with a `.`. 
This method converts the JavaParser names to the - * standard format by replacing the `.` between name parts that start with a capital letter or a digit with a `$` - * since the jdk classes follow the standard practice of capitalising the first letter in class names but not package - * names. - */ - def convertJavaParserNameToStandard(className: String): String = { - className.split(".") match { - case nameParts if namePartsContainSubclass(nameParts) => - val (packagePrefix, classNames) = nameParts.partition(_.head.isLower) - s"${packagePrefix.mkString(".")}.${classNames.mkString("$")}" - - case _ => className - } - - } -} +class JdkJarTypeSolver extends TypeSolver: + + private val logger = LoggerFactory.getLogger(this.getClass) + + private var parent: Option[TypeSolver] = None + private val classPool = new NonCachingClassPool() + + val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty + + // Populating this causes memory leaks + val packagesJarMappings: mutable.Map[String, mutable.Set[String]] = mutable.Map.empty + + private type RefType = ResolvedReferenceTypeDeclaration + + override def getParent(): TypeSolver = parent.get + + override def setParent(parent: TypeSolver): Unit = + this.parent match + case None => + this.parent = Some(parent) + + case Some(_) => + throw new RuntimeException("JdkJarTypeSolver parent may only be set once") + + override def tryToSolveType(javaParserName: String) + : SymbolReference[ResolvedReferenceTypeDeclaration] = + val packagePrefix = packagePrefixForJavaParserName(javaParserName) + if knownPackagePrefixes.contains(packagePrefix) then + lookupType(javaParserName) + else + SymbolReference.unsolved() + + private def lookupType(javaParserName: String) + : SymbolReference[ResolvedReferenceTypeDeclaration] = + val name = convertJavaParserNameToStandard(javaParserName) + Try(classPool.get(name)) match + case Success(ctClass) => + val refType = ctClassToRefType(ctClass) + refTypeToSymbolReference(refType) + + case Failure(e) 
=> + SymbolReference.unsolved() + + override def solveType(name: String): ResolvedReferenceTypeDeclaration = + tryToSolveType(name) match + case symbolReference if symbolReference.isSolved => + symbolReference.getCorrespondingDeclaration + + case _ => throw new UnsolvedSymbolException(name) + + private def ctClassToRefType(ctClass: CtClass): RefType = + JavassistFactory.toTypeDeclaration(ctClass, getRoot) + + private def refTypeToSymbolReference(refType: RefType): SymbolReference[RefType] = + SymbolReference.solved[RefType, RefType](refType) + + private def addPathToClassPool(archivePath: String): Try[ClassPath] = + if archivePath.isJarPath then + Try(classPool.appendClassPath(archivePath)) + else if archivePath.isJmodPath then + val classPath = new JmodClassPath(archivePath) + Try(classPool.appendClassPath(classPath)) + else + Failure(new IllegalArgumentException("$archivePath is not a path to a jar/jmod")) + + def withJars(archivePaths: Seq[String]): JdkJarTypeSolver = + addArchives(archivePaths) + this + + def addArchives(archivePaths: Seq[String]): Unit = + archivePaths.foreach { archivePath => + addPathToClassPool(archivePath) match + case Success(_) => registerPackagesForJar(archivePath) + + case Failure(e) => + logger.debug(s"Could not load jar at path $archivePath", e.getMessage) + } + + private def registerPackagesForJar(archivePath: String): Unit = + val entryNameConverter = + if archivePath.isJarPath then packagePrefixForJarEntry else packagePrefixForJmodEntry + try + Using(new JarFile(archivePath)) { jarFile => + def jarPackages = jarFile + .entries() + .asIterator() + .asScala + .filter(entry => + !entry.isDirectory && !entry.getName + .startsWith("module-info") && (entry.getName.endsWith( + ClassExtension + ) || entry.getName + .endsWith(JavaExtension) || entry.getName.endsWith(KtExtension)) + ) + knownPackagePrefixes ++= jarPackages.map(entry => entryNameConverter(entry.getName)) + } + catch + case ioException: IOException => + logger.debug( + 
s"Could register classes for archive at $archivePath", + ioException.getMessage + ) + end try + end registerPackagesForJar +end JdkJarTypeSolver + +object JdkJarTypeSolver: + val ClassExtension: String = ".class" + val JavaExtension: String = ".java" + val KtExtension: String = ".kt" + val JmodClassPrefix: String = "classes/" + val JarExtension: String = ".jar" + val JmodExtension: String = ".jmod" + + extension (path: String) + def isJarPath: Boolean = path.endsWith(JarExtension) + def isJmodPath: Boolean = path.endsWith(JmodExtension) + + def fromJdkPath(jdkPath: String): JdkJarTypeSolver = + val jarPaths = SourceFiles.determine(jdkPath, Set(JarExtension, JmodExtension)) + if jarPaths.isEmpty then + throw new IllegalArgumentException( + s"No .jar or .jmod files found at JDK path ${jdkPath}" + ) + new JdkJarTypeSolver().withJars(jarPaths) + + /** Convert JavaParser class name foo.bar.qux.Baz to package prefix foo.bar Only use first 2 + * parts since this is sufficient to deterimine whether a class has been registered in most + * cases and, if not, the failure is just a slow lookup. + */ + def packagePrefixForJavaParserName(className: String): String = + className.split("\\.").take(2).mkString(".") + + /** Convert Jar entry name foo/bar/qux/Baz.class to package prefix foo.bar Only use first 2 + * parts since this is sufficient to deterimine whether a class has been registered in most + * cases and, if not, the failure is just a slow lookup. + */ + def packagePrefixForJarEntry(entryName: String): String = + entryName.split("/").take(2).mkString(".") + + /** Convert jmod entry name classes/foo/bar/qux/Baz.class to package prefix foo.bar Only use + * first 2 parts since this is sufficient to deterimine whether a class has been registered in + * most cases and, if not, the failure is just a slow lookup. 
+ */ + def packagePrefixForJmodEntry(entryName: String): String = + packagePrefixForJarEntry(entryName.stripPrefix(JmodClassPrefix)) + + /** A name is assumed to contain at least one subclass (e.g. ...Foo$Bar) if the last name part + * starts with a digit, or if the last 2 name parts start with capital letters. This heuristic + * is based on the class name format in the JDK jars, where names with subclasses have one of + * the forms: + * - java.lang.ClassLoader$2 + * - java.lang.ClassLoader$NativeLibrary + * - java.lang.ClassLoader$NativeLibrary$Unloader + */ + private def namePartsContainSubclass(nameParts: Array[String]): Boolean = + nameParts.takeRight(2) match + case Array() => false + + case Array(singlePart) => false + + case Array(secondLast, last) => + last.head.isDigit || (secondLast.head.isUpper && last.head.isUpper) + + /** JavaParser replaces the `$` in nested class names with a `.`. This method converts the + * JavaParser names to the standard format by replacing the `.` between name parts that start + * with a capital letter or a digit with a `$` since the jdk classes follow the standard + * practice of capitalising the first letter in class names but not package names. 
+ */ + def convertJavaParserNameToStandard(className: String): String = + className.split(".") match + case nameParts if namePartsContainSubclass(nameParts) => + val (packagePrefix, classNames) = nameParts.partition(_.head.isLower) + s"${packagePrefix.mkString(".")}.${classNames.mkString("$")}" + + case _ => className +end JdkJarTypeSolver diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala index 0dff1f3a..3e88a083 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala @@ -4,91 +4,99 @@ import scala.collection.mutable case class BindingTableEntry(name: String, signature: String, implementingMethodFullName: String) -class BindingTable() { - private val entries = mutable.Map.empty[String, BindingTableEntry] +class BindingTable(): + private val entries = mutable.Map.empty[String, BindingTableEntry] - def add(entry: BindingTableEntry): Unit = { - entries.put(entry.name + entry.signature, entry) - } + def add(entry: BindingTableEntry): Unit = + entries.put(entry.name + entry.signature, entry) - def getEntries: Iterable[BindingTableEntry] = { - entries.values - } -} + def getEntries: Iterable[BindingTableEntry] = + entries.values -trait BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap] { - def directParents(typeDecl: InputTypeDecl): collection.Seq[AstTypeDecl] +trait BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]: + def directParents(typeDecl: InputTypeDecl): collection.Seq[AstTypeDecl] - def allParentsWithTypeMap(typeDecl: InputTypeDecl): collection.Seq[(AstTypeDecl, TypeMap)] + def allParentsWithTypeMap(typeDecl: InputTypeDecl): collection.Seq[(AstTypeDecl, TypeMap)] - def directBindingTableEntries(typeDeclFullName: 
String, typeDecl: InputTypeDecl): collection.Seq[BindingTableEntry] + def directBindingTableEntries( + typeDeclFullName: String, + typeDecl: InputTypeDecl + ): collection.Seq[BindingTableEntry] - def getDeclaredMethods(typeDecl: AstTypeDecl): Iterable[(String, AstMethodDecl)] + def getDeclaredMethods(typeDecl: AstTypeDecl): Iterable[(String, AstMethodDecl)] - def getMethodSignature(methodDecl: AstMethodDecl, typeMap: TypeMap): String + def getMethodSignature(methodDecl: AstMethodDecl, typeMap: TypeMap): String - def getMethodSignatureForEmptyTypeMap(methodDecl: AstMethodDecl): String + def getMethodSignatureForEmptyTypeMap(methodDecl: AstMethodDecl): String - def typeDeclEquals(astTypeDecl: AstTypeDecl, inputTypeDecl: InputTypeDecl): Boolean -} + def typeDeclEquals(astTypeDecl: AstTypeDecl, inputTypeDecl: InputTypeDecl): Boolean -object BindingTable { +object BindingTable: - def createBindingTable[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]( - typeDeclFullName: String, - typeDecl: InputTypeDecl, - getBindingTable: AstTypeDecl => BindingTable, - adapter: BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap] - ): BindingTable = { - val bindingTable = new BindingTable() + def createBindingTable[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]( + typeDeclFullName: String, + typeDecl: InputTypeDecl, + getBindingTable: AstTypeDecl => BindingTable, + adapter: BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap] + ): BindingTable = + val bindingTable = new BindingTable() - // Take over all binding table entries for parent class/interface binding tables. 
- adapter.directParents(typeDecl).filterNot(adapter.typeDeclEquals(_, typeDecl)).foreach { parentTypeDecl => - val parentBindingTable = - try { - getBindingTable(parentTypeDecl) - } catch { - case e: StackOverflowError => - throw new RuntimeException(s"SOE getting binding table for $typeDeclFullName") + // Take over all binding table entries for parent class/interface binding tables. + adapter.directParents(typeDecl).filterNot(adapter.typeDeclEquals(_, typeDecl)).foreach { + parentTypeDecl => + val parentBindingTable = + try + getBindingTable(parentTypeDecl) + catch + case e: StackOverflowError => + throw new RuntimeException( + s"SOE getting binding table for $typeDeclFullName" + ) + parentBindingTable.getEntries.foreach { entry => + bindingTable.add(entry) + } } - parentBindingTable.getEntries.foreach { entry => - bindingTable.add(entry) - } - } - - // Create table entries for all methods declared in type declaration. - val directTableEntries = adapter.directBindingTableEntries(typeDeclFullName, typeDecl) - - // Add all table entries for method of type declaration to binding table. - // It is important that this happens after adding the inherited entries - // because later entries for the same slot (same name and signature) - // override previously added entries. - directTableEntries.foreach(bindingTable.add) - - // Override the bindings for generic base class methods if they are overriden. - // To do so we need to traverse all methods in all parent type and calculate - // their signature in the derived type declarations context, meaning with the - // concrete values for the generic type parameters. If this signature together - // with the name matches a direct table entry we have an override and replace - // the binding table entry for the erased! parent method signature. - // This become necessary because calls in the JVM executed via erased signatures. 
- adapter.allParentsWithTypeMap(typeDecl).foreach { case (parentTypeDecl, typeParameterInDerivedContext) => - directTableEntries.foreach { directTableEntry => - val parentMethods = adapter.getDeclaredMethods(parentTypeDecl) - parentMethods.foreach { case (parentName, parentMethodDecl) => - if (directTableEntry.name == parentName) { - val parentSigInDerivedContext = adapter.getMethodSignature(parentMethodDecl, typeParameterInDerivedContext) - if (directTableEntry.signature == parentSigInDerivedContext) { - val erasedParentMethodSig = adapter.getMethodSignatureForEmptyTypeMap(parentMethodDecl) - val tableEntry = BindingTableEntry - .apply(directTableEntry.name, erasedParentMethodSig, directTableEntry.implementingMethodFullName) - bindingTable.add(tableEntry) - } - } + + // Create table entries for all methods declared in type declaration. + val directTableEntries = adapter.directBindingTableEntries(typeDeclFullName, typeDecl) + + // Add all table entries for method of type declaration to binding table. + // It is important that this happens after adding the inherited entries + // because later entries for the same slot (same name and signature) + // override previously added entries. + directTableEntries.foreach(bindingTable.add) + + // Override the bindings for generic base class methods if they are overriden. + // To do so we need to traverse all methods in all parent type and calculate + // their signature in the derived type declarations context, meaning with the + // concrete values for the generic type parameters. If this signature together + // with the name matches a direct table entry we have an override and replace + // the binding table entry for the erased! parent method signature. + // This become necessary because calls in the JVM executed via erased signatures. 
+ adapter.allParentsWithTypeMap(typeDecl).foreach { + case (parentTypeDecl, typeParameterInDerivedContext) => + directTableEntries.foreach { directTableEntry => + val parentMethods = adapter.getDeclaredMethods(parentTypeDecl) + parentMethods.foreach { case (parentName, parentMethodDecl) => + if directTableEntry.name == parentName then + val parentSigInDerivedContext = adapter.getMethodSignature( + parentMethodDecl, + typeParameterInDerivedContext + ) + if directTableEntry.signature == parentSigInDerivedContext then + val erasedParentMethodSig = + adapter.getMethodSignatureForEmptyTypeMap(parentMethodDecl) + val tableEntry = BindingTableEntry + .apply( + directTableEntry.name, + erasedParentMethodSig, + directTableEntry.implementingMethodFullName + ) + bindingTable.add(tableEntry) + } + } } - } - } - bindingTable - } -} + bindingTable + end createBindingTable +end BindingTable diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala index 1c708209..4d0839ac 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala @@ -1,6 +1,9 @@ package io.appthreat.javasrc2cpg.util -import com.github.javaparser.resolution.declarations.{ResolvedMethodDeclaration, ResolvedReferenceTypeDeclaration} +import com.github.javaparser.resolution.declarations.{ + ResolvedMethodDeclaration, + ResolvedReferenceTypeDeclaration +} import com.github.javaparser.resolution.types.ResolvedReferenceType import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserAnnotationDeclaration @@ -10,21 +13,19 @@ import Util.{composeMethodFullName, getAllParents, 
safeGetAncestors} import io.shiftleft.codepropertygraph.generated.nodes.NewBinding import scala.jdk.OptionConverters.RichOptional -import scala.jdk.CollectionConverters._ - -object Shared { - def getDeclaredMethods(typeDecl: ResolvedReferenceTypeDeclaration): Iterable[ResolvedMethodDeclaration] = { - typeDecl match { - // Attempting to get declared methods for annotations throws an UnsupportedOperationException. - // See https://github.com/javaparser/javaparser/issues/1838 for details. - case _: JavaParserAnnotationDeclaration => Set.empty - case _: ReflectionAnnotationDeclaration => Set.empty - case _: JavassistAnnotationDeclaration => Set.empty - - case _ => typeDecl.getDeclaredMethods.asScala - } - } -} +import scala.jdk.CollectionConverters.* + +object Shared: + def getDeclaredMethods(typeDecl: ResolvedReferenceTypeDeclaration) + : Iterable[ResolvedMethodDeclaration] = + typeDecl match + // Attempting to get declared methods for annotations throws an UnsupportedOperationException. + // See https://github.com/javaparser/javaparser/issues/1838 for details. 
+ case _: JavaParserAnnotationDeclaration => Set.empty + case _: ReflectionAnnotationDeclaration => Set.empty + case _: JavassistAnnotationDeclaration => Set.empty + + case _ => typeDecl.getDeclaredMethods.asScala class BindingTableAdapterForJavaparser( methodSignatureImpl: (ResolvedMethodDeclaration, ResolvedTypeParametersMap) => String @@ -33,60 +34,56 @@ class BindingTableAdapterForJavaparser( ResolvedReferenceTypeDeclaration, ResolvedMethodDeclaration, ResolvedTypeParametersMap - ] { - - override def directParents( - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[ResolvedReferenceTypeDeclaration] = { - safeGetAncestors(typeDecl).map(_.getTypeDeclaration.get) - } - - override def allParentsWithTypeMap( - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = { - getAllParents(typeDecl).map { parentType => - (parentType.getTypeDeclaration.get, parentType.typeParametersMap()) - } - } - - override def directBindingTableEntries( - typeDeclFullName: String, - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[BindingTableEntry] = { - getDeclaredMethods(typeDecl) - .filter { case (_, methodDecl) => !methodDecl.isStatic } - .map { case (_, methodDecl) => - val signature = getMethodSignature(methodDecl, ResolvedTypeParametersMap.empty()) - BindingTableEntry.apply( - methodDecl.getName, - signature, - composeMethodFullName(typeDeclFullName, methodDecl.getName, signature) - ) - } - .toBuffer - } - - override def getDeclaredMethods( - typeDecl: ResolvedReferenceTypeDeclaration - ): Iterable[(String, ResolvedMethodDeclaration)] = { - Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) - } - - override def getMethodSignature(methodDecl: ResolvedMethodDeclaration, typeMap: ResolvedTypeParametersMap): String = { - methodSignatureImpl(methodDecl, typeMap) - } - - override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String 
= { - methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) - } - - override def typeDeclEquals( - astTypeDecl: ResolvedReferenceTypeDeclaration, - inputTypeDecl: ResolvedReferenceTypeDeclaration - ): Boolean = { - astTypeDecl.getQualifiedName == inputTypeDecl.getQualifiedName - } -} + ]: + + override def directParents( + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[ResolvedReferenceTypeDeclaration] = + safeGetAncestors(typeDecl).map(_.getTypeDeclaration.get) + + override def allParentsWithTypeMap( + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = + getAllParents(typeDecl).map { parentType => + (parentType.getTypeDeclaration.get, parentType.typeParametersMap()) + } + + override def directBindingTableEntries( + typeDeclFullName: String, + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[BindingTableEntry] = + getDeclaredMethods(typeDecl) + .filter { case (_, methodDecl) => !methodDecl.isStatic } + .map { case (_, methodDecl) => + val signature = getMethodSignature(methodDecl, ResolvedTypeParametersMap.empty()) + BindingTableEntry.apply( + methodDecl.getName, + signature, + composeMethodFullName(typeDeclFullName, methodDecl.getName, signature) + ) + } + .toBuffer + + override def getDeclaredMethods( + typeDecl: ResolvedReferenceTypeDeclaration + ): Iterable[(String, ResolvedMethodDeclaration)] = + Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) + + override def getMethodSignature( + methodDecl: ResolvedMethodDeclaration, + typeMap: ResolvedTypeParametersMap + ): String = + methodSignatureImpl(methodDecl, typeMap) + + override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = + methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) + + override def typeDeclEquals( + astTypeDecl: ResolvedReferenceTypeDeclaration, + inputTypeDecl: ResolvedReferenceTypeDeclaration + ): 
Boolean = + astTypeDecl.getQualifiedName == inputTypeDecl.getQualifiedName +end BindingTableAdapterForJavaparser case class LambdaBindingInfo( fullName: String, @@ -101,49 +98,48 @@ class BindingTableAdapterForLambdas( ResolvedReferenceTypeDeclaration, ResolvedMethodDeclaration, ResolvedTypeParametersMap - ] { - - override def directParents(lambdaBindingInfo: LambdaBindingInfo): collection.Seq[ResolvedReferenceTypeDeclaration] = { - lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList - } - - override def allParentsWithTypeMap( - lambdaBindingInfo: LambdaBindingInfo - ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = { - val nonDirectParents = - lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList.flatMap(getAllParents) - (lambdaBindingInfo.implementedType.toList ++ nonDirectParents).map { typ => - (typ.getTypeDeclaration.get, typ.typeParametersMap()) - } - } - - override def directBindingTableEntries( - typeDeclFullName: String, - lambdaBindingInfo: LambdaBindingInfo - ): collection.Seq[BindingTableEntry] = { - lambdaBindingInfo.directBinding.map { binding => - BindingTableEntry(binding.name, binding.signature, binding.methodFullName) - }.toList - } - - override def getDeclaredMethods( - typeDecl: ResolvedReferenceTypeDeclaration - ): Iterable[(String, ResolvedMethodDeclaration)] = { - Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) - } - - override def getMethodSignature(methodDecl: ResolvedMethodDeclaration, typeMap: ResolvedTypeParametersMap): String = { - methodSignatureImpl(methodDecl, typeMap) - } - - override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = { - methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) - } - - override def typeDeclEquals( - astTypeDecl: ResolvedReferenceTypeDeclaration, - inputTypeDecl: LambdaBindingInfo - ): Boolean = { - astTypeDecl.getQualifiedName == 
inputTypeDecl.fullName - } -} + ]: + + override def directParents(lambdaBindingInfo: LambdaBindingInfo) + : collection.Seq[ResolvedReferenceTypeDeclaration] = + lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList + + override def allParentsWithTypeMap( + lambdaBindingInfo: LambdaBindingInfo + ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = + val nonDirectParents = + lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList.flatMap( + getAllParents + ) + (lambdaBindingInfo.implementedType.toList ++ nonDirectParents).map { typ => + (typ.getTypeDeclaration.get, typ.typeParametersMap()) + } + + override def directBindingTableEntries( + typeDeclFullName: String, + lambdaBindingInfo: LambdaBindingInfo + ): collection.Seq[BindingTableEntry] = + lambdaBindingInfo.directBinding.map { binding => + BindingTableEntry(binding.name, binding.signature, binding.methodFullName) + }.toList + + override def getDeclaredMethods( + typeDecl: ResolvedReferenceTypeDeclaration + ): Iterable[(String, ResolvedMethodDeclaration)] = + Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) + + override def getMethodSignature( + methodDecl: ResolvedMethodDeclaration, + typeMap: ResolvedTypeParametersMap + ): String = + methodSignatureImpl(methodDecl, typeMap) + + override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = + methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) + + override def typeDeclEquals( + astTypeDecl: ResolvedReferenceTypeDeclaration, + inputTypeDecl: LambdaBindingInfo + ): Boolean = + astTypeDecl.getQualifiedName == inputTypeDecl.fullName +end BindingTableAdapterForLambdas diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala index bd77f6cd..8f39cacf 100644 --- 
a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala @@ -8,80 +8,77 @@ import org.slf4j.LoggerFactory import java.nio.file.Paths import scala.util.{Failure, Success, Try} -object Delombok { +object Delombok: - sealed trait DelombokMode - // Don't run delombok at all. - object DelombokMode { - case object NoDelombok extends DelombokMode - case object Default extends DelombokMode - case object TypesOnly extends DelombokMode - case object RunDelombok extends DelombokMode - } + sealed trait DelombokMode + // Don't run delombok at all. + object DelombokMode: + case object NoDelombok extends DelombokMode + case object Default extends DelombokMode + case object TypesOnly extends DelombokMode + case object RunDelombok extends DelombokMode - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - private def systemJavaPath: String = { - sys.env - .get("JAVA_HOME") - .flatMap { javaHome => - val javaExecutable = File(javaHome, "bin", "java") - Option.when(javaExecutable.exists && javaExecutable.isExecutable) { - javaExecutable.canonicalPath - } - } - .getOrElse("java") - } + private def systemJavaPath: String = + sys.env + .get("JAVA_HOME") + .flatMap { javaHome => + val javaExecutable = File(javaHome, "bin", "java") + Option.when(javaExecutable.exists && javaExecutable.isExecutable) { + javaExecutable.canonicalPath + } + } + .getOrElse("java") - private def delombokToTempDirCommand(tempDir: File, analysisJavaHome: Option[String]) = { - val javaPath = analysisJavaHome.getOrElse(systemJavaPath) - val classPathArg = Try(File.newTemporaryFile("classpath").deleteOnExit()) match { - case Success(file) => - // Write classpath to a file to work around Windows length limits. 
- file.write(System.getProperty("java.class.path")) - s"@${file.canonicalPath}" + private def delombokToTempDirCommand(tempDir: File, analysisJavaHome: Option[String]) = + val javaPath = analysisJavaHome.getOrElse(systemJavaPath) + val classPathArg = Try(File.newTemporaryFile("classpath").deleteOnExit()) match + case Success(file) => + // Write classpath to a file to work around Windows length limits. + file.write(System.getProperty("java.class.path")) + s"@${file.canonicalPath}" - case Failure(t) => - logger.debug( - s"Failed to create classpath file for delombok execution. Results may be missing on Windows systems", - t - ) - System.getProperty("java.class.path") - } - s"$javaPath -cp $classPathArg lombok.launch.Main delombok . -d ${tempDir.canonicalPath}" - } + case Failure(t) => + logger.debug( + s"Failed to create classpath file for delombok execution. Results may be missing on Windows systems", + t + ) + System.getProperty("java.class.path") + s"$javaPath -cp $classPathArg lombok.launch.Main delombok . -d ${tempDir.canonicalPath}" - def run(projectDir: String, analysisJavaHome: Option[String]): String = { - Try(File.newTemporaryDirectory(prefix = "delombok").deleteOnExit()) match { - case Success(tempDir) => - ExternalCommand.run(delombokToTempDirCommand(tempDir, analysisJavaHome), cwd = projectDir) match { - case Success(_) => - tempDir.path.toAbsolutePath.toString + def run(projectDir: String, analysisJavaHome: Option[String]): String = + Try(File.newTemporaryDirectory(prefix = "delombok").deleteOnExit()) match + case Success(tempDir) => + ExternalCommand.run( + delombokToTempDirCommand(tempDir, analysisJavaHome), + cwd = projectDir + ) match + case Success(_) => + tempDir.path.toAbsolutePath.toString - case Failure(t) => - logger.debug(s"Executing delombok failed", t) - logger.debug( - "Creating AST with original source instead. Some methods and type information will be missing." 
- ) - projectDir - } + case Failure(t) => + logger.debug(s"Executing delombok failed", t) + logger.debug( + "Creating AST with original source instead. Some methods and type information will be missing." + ) + projectDir - case Failure(e) => - logger.debug(s"Failed to create temporary directory for delomboked source. Methods and types may be missing", e) - projectDir - } - } + case Failure(e) => + logger.debug( + s"Failed to create temporary directory for delomboked source. Methods and types may be missing", + e + ) + projectDir - def parseDelombokModeOption(delombokModeStr: Option[String]): DelombokMode = { - delombokModeStr.map(_.toLowerCase) match { - case None => Default - case Some("no-delombok") => NoDelombok - case Some("default") => Default - case Some("types-only") => TypesOnly - case Some("run-delombok") => RunDelombok - case Some(value) => - logger.debug(s"Found unrecognised delombok mode `$value`. Using default instead.") - Default - } - } -} + def parseDelombokModeOption(delombokModeStr: Option[String]): DelombokMode = + delombokModeStr.map(_.toLowerCase) match + case None => Default + case Some("no-delombok") => NoDelombok + case Some("default") => Default + case Some("types-only") => TypesOnly + case Some("run-delombok") => RunDelombok + case Some(value) => + logger.debug(s"Found unrecognised delombok mode `$value`. 
Using default instead.") + Default +end Delombok diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala index 34500108..d1020e2f 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala @@ -1,7 +1,6 @@ package io.appthreat.javasrc2cpg.util -object NameConstants { - val Super: String = "super" - val This: String = "this" - val WildcardImportName: String = "*" -} +object NameConstants: + val Super: String = "super" + val This: String = "this" + val WildcardImportName: String = "*" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala index f4b46221..32849232 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala @@ -2,7 +2,7 @@ package io.appthreat.javasrc2cpg.util import better.files.File import Delombok.DelombokMode -import Delombok.DelombokMode._ +import Delombok.DelombokMode.* import io.appthreat.x2cpg.SourceFiles import com.github.javaparser.{JavaParser, ParserConfiguration} import com.github.javaparser.ParserConfiguration.LanguageLevel @@ -12,7 +12,7 @@ import io.appthreat.javasrc2cpg.{Config, JavaSrc2Cpg} import org.slf4j.LoggerFactory import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.jdk.OptionConverters.RichOptional import java.nio.file.Path import java.nio.file.Paths @@ -23,122 +23,121 @@ import java.nio.charset.StandardCharsets import scala.util.Try import scala.util.Success -class 
SourceParser private (originalInputPath: Path, analysisRoot: Path, typesRoot: Path) { - - private val logger = LoggerFactory.getLogger(this.getClass) - - /** Parse the given file into a JavaParser CompliationUnit that will be used for creating the CPG AST. - * - * @param relativeFilename - * path to the input file relative to the project root. - */ - def parseAnalysisFile(relativeFilename: String): Option[CompilationUnit] = { - val analysisFilename = analysisRoot.resolve(relativeFilename).toString - // Need to store tokens for position information. - fileIfExists(analysisFilename).flatMap(parse(_, storeTokens = true)) - } - - /** Parse the given file into a JavaParser CompliationUnit that will be used for reading type information. These - * should not be used for determining the structure of the AST. - * - * @param relativeFilename - * path to the input file relative to the project root. - */ - def parseTypesFile(relativeFilename: String): Option[CompilationUnit] = { - val typesFilename = typesRoot.resolve(relativeFilename).toString - fileIfExists(typesFilename).flatMap(parse(_, storeTokens = false)) - } - - def fileIfExists(filename: String): Option[File] = { - val file = File(filename) - - Option.when(file.exists)(file) - } - - def getTypesFileLines(relativeFilename: String): Try[Iterable[String]] = { - val typesFilename = typesRoot.resolve(relativeFilename).toString - Try(File(typesFilename).lines(Charset.defaultCharset())) - .orElse(Try(File(typesFilename).lines(StandardCharsets.ISO_8859_1))) - } - - def doesTypesFileExist(relativeFilename: String): Boolean = { - File(typesRoot.resolve(relativeFilename)).isRegularFile - } - - private def parse(file: File, storeTokens: Boolean): Option[CompilationUnit] = { - val javaParserConfig = - new ParserConfiguration() - .setLanguageLevel(LanguageLevel.BLEEDING_EDGE) - .setAttributeComments(false) - .setLexicalPreservationEnabled(true) - .setStoreTokens(storeTokens) - val parseResult = new 
JavaParser(javaParserConfig).parse(file.toJava) - - parseResult.getProblems.asScala.toList match { - case Nil => // Just carry on as usual - case problems => - logger.debug(s"Encountered problems while parsing file ${file.name}:") - problems.foreach { problem => - logger.debug(s"- ${problem.getMessage}") - } - } - - parseResult.getResult.toScala match { - case Some(result) if result.getParsed == Parsedness.PARSED => Some(result) - case _ => - logger.debug(s"Failed to parse file ${file.name}") - None - } - } -} - -object SourceParser { - private val logger = LoggerFactory.getLogger(this.getClass) - - def apply(config: Config, hasLombokDependency: Boolean): SourceParser = { - val canonicalInputPath = File(config.inputPath).canonicalPath - val (analysisDir, typesDir) = - getAnalysisAndTypesDirs(canonicalInputPath, config.delombokJavaHome, config.delombokMode, hasLombokDependency) - - new SourceParser(Path.of(canonicalInputPath), Path.of(analysisDir), Path.of(typesDir)) - } - - def getSourceFilenames(config: Config, sourcesOverride: Option[List[String]] = None): Array[String] = { - val inputPaths = sourcesOverride.getOrElse(config.inputPath :: Nil).toSet - SourceFiles.determine(inputPaths, JavaSrc2Cpg.sourceFileExtensions, config).toArray - } - - /** Implements the logic described in the option description for the "delombok-mode" option: - * - no-delombok: do not run delombok. - * - default: run delombok if a lombok dependency is found and analyse delomboked code. - * - types-only: run delombok, but use it for type information only - * - run-delombok: run delombok and analyse delomboked code - * - * @return - * the tuple (analysisRoot, typesRoot) where analysisRoot is used to locate source files for creating the AST and - * typesRoot is used for locating source files from which to extract type information. 
- */ - private def getAnalysisAndTypesDirs( - originalDir: String, - delombokJavaHome: Option[String], - delombokMode: Option[String], - hasLombokDependency: Boolean - ): (String, String) = { - lazy val delombokDir = Delombok.run(originalDir, delombokJavaHome) - - Delombok.parseDelombokModeOption(delombokMode) match { - case Default if hasLombokDependency => - logger.debug(s"Analysing delomboked code as lombok dependency was found.") - (delombokDir, delombokDir) - - case Default => (originalDir, originalDir) - - case NoDelombok => (originalDir, originalDir) - - case TypesOnly => (originalDir, delombokDir) - - case RunDelombok => (delombokDir, delombokDir) - } - } - -} +class SourceParser private (originalInputPath: Path, analysisRoot: Path, typesRoot: Path): + + private val logger = LoggerFactory.getLogger(this.getClass) + + /** Parse the given file into a JavaParser CompliationUnit that will be used for creating the + * CPG AST. + * + * @param relativeFilename + * path to the input file relative to the project root. + */ + def parseAnalysisFile(relativeFilename: String): Option[CompilationUnit] = + val analysisFilename = analysisRoot.resolve(relativeFilename).toString + // Need to store tokens for position information. + fileIfExists(analysisFilename).flatMap(parse(_, storeTokens = true)) + + /** Parse the given file into a JavaParser CompliationUnit that will be used for reading type + * information. These should not be used for determining the structure of the AST. + * + * @param relativeFilename + * path to the input file relative to the project root. 
+ */ + def parseTypesFile(relativeFilename: String): Option[CompilationUnit] = + val typesFilename = typesRoot.resolve(relativeFilename).toString + fileIfExists(typesFilename).flatMap(parse(_, storeTokens = false)) + + def fileIfExists(filename: String): Option[File] = + val file = File(filename) + + Option.when(file.exists)(file) + + def getTypesFileLines(relativeFilename: String): Try[Iterable[String]] = + val typesFilename = typesRoot.resolve(relativeFilename).toString + Try(File(typesFilename).lines(Charset.defaultCharset())) + .orElse(Try(File(typesFilename).lines(StandardCharsets.ISO_8859_1))) + + def doesTypesFileExist(relativeFilename: String): Boolean = + File(typesRoot.resolve(relativeFilename)).isRegularFile + + private def parse(file: File, storeTokens: Boolean): Option[CompilationUnit] = + val javaParserConfig = + new ParserConfiguration() + .setLanguageLevel(LanguageLevel.BLEEDING_EDGE) + .setAttributeComments(false) + .setLexicalPreservationEnabled(true) + .setStoreTokens(storeTokens) + val parseResult = new JavaParser(javaParserConfig).parse(file.toJava) + + parseResult.getProblems.asScala.toList match + case Nil => // Just carry on as usual + case problems => + logger.debug(s"Encountered problems while parsing file ${file.name}:") + problems.foreach { problem => + logger.debug(s"- ${problem.getMessage}") + } + + parseResult.getResult.toScala match + case Some(result) if result.getParsed == Parsedness.PARSED => Some(result) + case _ => + logger.debug(s"Failed to parse file ${file.name}") + None + end parse +end SourceParser + +object SourceParser: + private val logger = LoggerFactory.getLogger(this.getClass) + + def apply(config: Config, hasLombokDependency: Boolean): SourceParser = + val canonicalInputPath = File(config.inputPath).canonicalPath + val (analysisDir, typesDir) = + getAnalysisAndTypesDirs( + canonicalInputPath, + config.delombokJavaHome, + config.delombokMode, + hasLombokDependency + ) + + new SourceParser(Path.of(canonicalInputPath), 
Path.of(analysisDir), Path.of(typesDir)) + + def getSourceFilenames( + config: Config, + sourcesOverride: Option[List[String]] = None + ): Array[String] = + val inputPaths = sourcesOverride.getOrElse(config.inputPath :: Nil).toSet + SourceFiles.determine(inputPaths, JavaSrc2Cpg.sourceFileExtensions, config).toArray + + /** Implements the logic described in the option description for the "delombok-mode" option: + * - no-delombok: do not run delombok. + * - default: run delombok if a lombok dependency is found and analyse delomboked code. + * - types-only: run delombok, but use it for type information only + * - run-delombok: run delombok and analyse delomboked code + * + * @return + * the tuple (analysisRoot, typesRoot) where analysisRoot is used to locate source files for + * creating the AST and typesRoot is used for locating source files from which to extract + * type information. + */ + private def getAnalysisAndTypesDirs( + originalDir: String, + delombokJavaHome: Option[String], + delombokMode: Option[String], + hasLombokDependency: Boolean + ): (String, String) = + lazy val delombokDir = Delombok.run(originalDir, delombokJavaHome) + + Delombok.parseDelombokModeOption(delombokMode) match + case Default if hasLombokDependency => + logger.debug(s"Analysing delomboked code as lombok dependency was found.") + (delombokDir, delombokDir) + + case Default => (originalDir, originalDir) + + case NoDelombok => (originalDir, originalDir) + + case TypesOnly => (originalDir, delombokDir) + + case RunDelombok => (delombokDir, delombokDir) + end getAnalysisAndTypesDirs +end SourceParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala index 25115695..ee80e2f1 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala +++ 
b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala @@ -2,96 +2,87 @@ package io.appthreat.javasrc2cpg.util import better.files.File -/** Given a `codeDirectory`, the SourceRootFinder attempts to find all the source roots (so that the subdirectories - * match the package structure of the source files). The JavaParserTypeSolver is path-dependent, so without this the - * user would need to ensure that they specify the correct source directory when running javasrc2cpg. +/** Given a `codeDirectory`, the SourceRootFinder attempts to find all the source roots (so that the + * subdirectories match the package structure of the source files). The JavaParserTypeSolver is + * path-dependent, so without this the user would need to ensure that they specify the correct + * source directory when running javasrc2cpg. * - * This works by checking if any subdirectories of the given directory match common java directory structures. In order - * of preference: * Maven's default structure, e.g. src/main/java/io/... and src/test/java/io/... * Top-level src and - * test directories, e.g. src/io/... and test/io/... If neither of these match, then it defaults to using the given - * user input as the source root. + * This works by checking if any subdirectories of the given directory match common java directory + * structures. In order of preference: * Maven's default structure, e.g. src/main/java/io/... and + * src/test/java/io/... * Top-level src and test directories, e.g. src/io/... and test/io/... If + * neither of these match, then it defaults to using the given user input as the source root. 
*/ -object SourceRootFinder { +object SourceRootFinder: - private val excludes: Set[String] = Set("test", ".mvn", ".git") + private val excludes: Set[String] = Set("test", ".mvn", ".git") - private def statePreSrc(currentDir: File): List[File] = { - currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match { - case name if excludes.contains(name) => Nil - case "src" => stateSrc(child) - case "main" => statePostSrc(child) - case "java" => child :: Nil - case _ => statePreSrc(child) - } - } - } + private def statePreSrc(currentDir: File): List[File] = + currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "src" => stateSrc(child) + case "main" => statePostSrc(child) + case "java" => child :: Nil + case _ => statePreSrc(child) + } - private def stateSrc(currentDir: File): List[File] = { - val mainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match { - case name if excludes.contains(name) => Nil - case "main" => statePostSrc(child) - case _ => Nil - } - } + private def stateSrc(currentDir: File): List[File] = + val mainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "main" => statePostSrc(child) + case _ => Nil + } - val nonMainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match { - case name if excludes.contains(name) => Nil - case "main" => Nil - case _ => statePostSrc(child) - } - } + val nonMainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "main" => Nil + case _ => statePostSrc(child) + } - val hasExcludedDir = currentDir.children.filter(_.isDirectory).toList.exists(file => excludes.contains(file.name)) + val hasExcludedDir = 
currentDir.children.filter(_.isDirectory).toList.exists(file => + excludes.contains(file.name) + ) - (mainChildren, nonMainChildren) match { - case (Nil, Nil) => - // probably a src/test directory in a tests-only module - if (hasExcludedDir) Nil else List(currentDir) - case (mainC, Nil) => - // probably follows the common src/main/ structure - mainC - case (Nil, _) => - // the non-main children are probably package roots - List(currentDir) - case (mainC, nonMC) => - // main a package root here? - mainC ++ nonMC - } - } + (mainChildren, nonMainChildren) match + case (Nil, Nil) => + // probably a src/test directory in a tests-only module + if hasExcludedDir then Nil else List(currentDir) + case (mainC, Nil) => + // probably follows the common src/main/ structure + mainC + case (Nil, _) => + // the non-main children are probably package roots + List(currentDir) + case (mainC, nonMC) => + // main a package root here? + mainC ++ nonMC + end stateSrc - private def statePostSrc(currentDir: File): List[File] = { - val javaChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match { - case name if excludes.contains(name) => Nil - case _ => child :: Nil - } - } + private def statePostSrc(currentDir: File): List[File] = + val javaChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case _ => child :: Nil + } - javaChildren match { - case Nil => currentDir :: Nil - case _ => javaChildren - } - } + javaChildren match + case Nil => currentDir :: Nil + case _ => javaChildren - private def listBottomLevelSubdirectories(currentDir: File): List[File] = { - val srcDirs = currentDir.name match { - case name if excludes.contains(name) => Nil - case "src" => stateSrc(currentDir) - case "main" => statePostSrc(currentDir) - case "java" => List(currentDir) - case _ => statePreSrc(currentDir) - } + private def listBottomLevelSubdirectories(currentDir: File): 
List[File] = + val srcDirs = currentDir.name match + case name if excludes.contains(name) => Nil + case "src" => stateSrc(currentDir) + case "main" => statePostSrc(currentDir) + case "java" => List(currentDir) + case _ => statePreSrc(currentDir) - srcDirs match { - case Nil => List(currentDir) - case _ => srcDirs - } - } + srcDirs match + case Nil => List(currentDir) + case _ => srcDirs - def getSourceRoots(codeDir: String): List[String] = { - listBottomLevelSubdirectories(File(codeDir)).map(_.pathAsString) - } -} + def getSourceRoots(codeDir: String): List[String] = + listBottomLevelSubdirectories(File(codeDir)).map(_.pathAsString) +end SourceRootFinder diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala index 4db9f838..3033d8ac 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala @@ -10,54 +10,57 @@ import org.slf4j.LoggerFactory import scala.collection.mutable import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters._ - -object Util { - - private val logger = LoggerFactory.getLogger(this.getClass) - def composeMethodFullName(typeDeclFullName: String, name: String, signature: String): String = { - s"$typeDeclFullName.$name:$signature" - } - - def safeGetAncestors(typeDecl: ResolvedReferenceTypeDeclaration): Seq[ResolvedReferenceType] = { - Try(typeDecl.getAncestors(true)) match { - case Success(ancestors) => ancestors.asScala.filterNot(_ == typeDecl).toSeq - - case Failure(exception) => - logger.debug(s"Failed to get direct parents for typeDecl ${typeDecl.getQualifiedName}", exception) - Seq.empty - } - } - - def getAllParents(typeDecl: ResolvedReferenceTypeDeclaration): mutable.ArrayBuffer[ResolvedReferenceType] = { - val result = 
mutable.ArrayBuffer.empty[ResolvedReferenceType] - - if (!typeDecl.isJavaLangObject) { - safeGetAncestors(typeDecl).filter(_.getQualifiedName != typeDecl.getQualifiedName).foreach { ancestor => - result.append(ancestor) - getAllParents(ancestor, result) - } - } - - result - } - - def composeMethodLikeSignature(returnType: String, parameterTypes: collection.Seq[String]): String = { - s"$returnType(${parameterTypes.mkString(",")})" - } - - def composeUnresolvedSignature(paramCount: Int): String = { - s"${Defines.UnresolvedSignature}($paramCount)" - } - - private def getAllParents(typ: ResolvedReferenceType, result: mutable.ArrayBuffer[ResolvedReferenceType]): Unit = { - if (typ.isJavaLangObject) { - Iterable.empty - } else { - Try(typ.getDirectAncestors).map(_.asScala).getOrElse(Nil).foreach { ancestor => - result.append(ancestor) - getAllParents(ancestor, result) - } - } - } -} +import scala.jdk.CollectionConverters.* + +object Util: + + private val logger = LoggerFactory.getLogger(this.getClass) + def composeMethodFullName(typeDeclFullName: String, name: String, signature: String): String = + s"$typeDeclFullName.$name:$signature" + + def safeGetAncestors(typeDecl: ResolvedReferenceTypeDeclaration): Seq[ResolvedReferenceType] = + Try(typeDecl.getAncestors(true)) match + case Success(ancestors) => ancestors.asScala.filterNot(_ == typeDecl).toSeq + + case Failure(exception) => + logger.debug( + s"Failed to get direct parents for typeDecl ${typeDecl.getQualifiedName}", + exception + ) + Seq.empty + + def getAllParents(typeDecl: ResolvedReferenceTypeDeclaration) + : mutable.ArrayBuffer[ResolvedReferenceType] = + val result = mutable.ArrayBuffer.empty[ResolvedReferenceType] + + if !typeDecl.isJavaLangObject then + safeGetAncestors(typeDecl).filter( + _.getQualifiedName != typeDecl.getQualifiedName + ).foreach { ancestor => + result.append(ancestor) + getAllParents(ancestor, result) + } + + result + + def composeMethodLikeSignature( + returnType: String, + parameterTypes: 
collection.Seq[String] + ): String = + s"$returnType(${parameterTypes.mkString(",")})" + + def composeUnresolvedSignature(paramCount: Int): String = + s"${Defines.UnresolvedSignature}($paramCount)" + + private def getAllParents( + typ: ResolvedReferenceType, + result: mutable.ArrayBuffer[ResolvedReferenceType] + ): Unit = + if typ.isJavaLangObject then + Iterable.empty + else + Try(typ.getDirectAncestors).map(_.asScala).getOrElse(Nil).foreach { ancestor => + result.append(ancestor) + getAllParents(ancestor, result) + } +end Util diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala index ef87c515..2bea35db 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala @@ -1,7 +1,12 @@ package io.appthreat.jimple2cpg import better.files.File -import io.appthreat.jimple2cpg.passes.{AstCreationPass, ConfigFileCreationPass, DeclarationRefPass, SootAstCreationPass} +import io.appthreat.jimple2cpg.passes.{ + AstCreationPass, + ConfigFileCreationPass, + DeclarationRefPass, + SootAstCreationPass +} import io.appthreat.jimple2cpg.util.ProgramHandlingUtil.ClassFile import io.appthreat.jimple2cpg.util.ProgramHandlingUtil.{ClassFile, extractClassesInPackageLayout} import io.appthreat.x2cpg.X2Cpg.withNewEmptyCpg @@ -17,144 +22,136 @@ import scala.jdk.CollectionConverters.{EnumerationHasAsScala, SeqHasAsJava} import scala.language.postfixOps import scala.util.Try -object Jimple2Cpg { - val language = "JAVA" - - def apply(): Jimple2Cpg = new Jimple2Cpg() -} - -class Jimple2Cpg extends X2CpgFrontend[Config] { - - import Jimple2Cpg.* - - private val logger = LoggerFactory.getLogger(classOf[Jimple2Cpg]) - - private def sootLoadApk(input: File, framework: Option[String] = None): Unit = { - 
Options.v().set_process_dir(List(input.canonicalPath).asJava) - framework match { - case Some(value) if value.nonEmpty => - Options.v().set_src_prec(Options.src_prec_apk) - Options.v().set_force_android_jar(value) - case _ => - Options.v().set_src_prec(Options.src_prec_apk_c_j) - } - Options.v().set_process_multiple_dex(true) - // workaround for Soot's bug while parsing large apk. - // see: https://github.com/soot-oss/soot/issues/1256 - Options.v().setPhaseOption("jb", "use-original-names:false") - } - - /** Load all class files from archives or directories recursively - * @param recurse - * Whether to unpack recursively - * @return - * The list of extracted class files whose package path could be extracted, placed on that package path relative to - * [[tmpDir]] - */ - private def loadClassFiles(src: File, tmpDir: File, recurse: Boolean): List[ClassFile] = { - val archiveFileExtensions = Set(".jar", ".war", ".zip") - extractClassesInPackageLayout( - src, - tmpDir, - isClass = e => e.extension.contains(".class"), - isArchive = e => e.extension.exists(archiveFileExtensions.contains), - recurse - ) - } - - /** Extract all class files found, place them in their package layout and load them into soot. - * @param input - * The file/directory to traverse for class files. 
- * @param tmpDir - * The directory to place the class files in their package layout - * @param recurse - * Whether to unpack recursively - */ - private def sootLoad(input: File, tmpDir: File, recurse: Boolean): List[ClassFile] = { - Options.v().set_soot_classpath(tmpDir.canonicalPath) - Options.v().set_prepend_classpath(true) - val classFiles = loadClassFiles(input, tmpDir, recurse) - val fullyQualifiedClassNames = classFiles.flatMap(_.fullyQualifiedClassName) - logger.debug(s"Loading ${classFiles.size} program files") - logger.debug(s"Source files are: ${classFiles.map(_.file.canonicalPath)}") - fullyQualifiedClassNames.foreach { fqcn => - Scene.v().addBasicClass(fqcn) - Scene.v().loadClassAndSupport(fqcn) - } - classFiles - } - - /** Apply the soot passes - * @param tmpDir - * A temporary directory that will be used as the classpath for extracted class files - */ - private def cpgApplyPasses(cpg: Cpg, config: Config, tmpDir: File): Unit = { - val input = File(config.inputPath) - configureSoot(config, tmpDir) - new MetaDataPass(cpg, language, config.inputPath).createAndApply() - - val globalFromAstCreation: () => Global = input.extension match { - case Some(".apk" | ".dex") if input.isRegularFile => - sootLoadApk(input, config.android) - { () => - val astCreator = SootAstCreationPass(cpg, config) - astCreator.createAndApply() - astCreator.global +object Jimple2Cpg: + val language = "JAVA" + + def apply(): Jimple2Cpg = new Jimple2Cpg() + +class Jimple2Cpg extends X2CpgFrontend[Config]: + + import Jimple2Cpg.* + + private val logger = LoggerFactory.getLogger(classOf[Jimple2Cpg]) + + private def sootLoadApk(input: File, framework: Option[String] = None): Unit = + Options.v().set_process_dir(List(input.canonicalPath).asJava) + framework match + case Some(value) if value.nonEmpty => + Options.v().set_src_prec(Options.src_prec_apk) + Options.v().set_force_android_jar(value) + case _ => + Options.v().set_src_prec(Options.src_prec_apk_c_j) + 
Options.v().set_process_multiple_dex(true) + // workaround for Soot's bug while parsing large apk. + // see: https://github.com/soot-oss/soot/issues/1256 + Options.v().setPhaseOption("jb", "use-original-names:false") + + /** Load all class files from archives or directories recursively + * @param recurse + * Whether to unpack recursively + * @return + * The list of extracted class files whose package path could be extracted, placed on that + * package path relative to [[tmpDir]] + */ + private def loadClassFiles(src: File, tmpDir: File, recurse: Boolean): List[ClassFile] = + val archiveFileExtensions = Set(".jar", ".war", ".zip") + extractClassesInPackageLayout( + src, + tmpDir, + isClass = e => e.extension.contains(".class"), + isArchive = e => e.extension.exists(archiveFileExtensions.contains), + recurse + ) + + /** Extract all class files found, place them in their package layout and load them into soot. + * @param input + * The file/directory to traverse for class files. + * @param tmpDir + * The directory to place the class files in their package layout + * @param recurse + * Whether to unpack recursively + */ + private def sootLoad(input: File, tmpDir: File, recurse: Boolean): List[ClassFile] = + Options.v().set_soot_classpath(tmpDir.canonicalPath) + Options.v().set_prepend_classpath(true) + val classFiles = loadClassFiles(input, tmpDir, recurse) + val fullyQualifiedClassNames = classFiles.flatMap(_.fullyQualifiedClassName) + logger.debug(s"Loading ${classFiles.size} program files") + logger.debug(s"Source files are: ${classFiles.map(_.file.canonicalPath)}") + fullyQualifiedClassNames.foreach { fqcn => + Scene.v().addBasicClass(fqcn) + Scene.v().loadClassAndSupport(fqcn) } - case _ => - val classFiles = sootLoad(input, tmpDir, config.recurse) - { () => - val astCreator = AstCreationPass(classFiles, cpg, config) - astCreator.createAndApply() - astCreator.global - } - } - logger.debug("Loading classes to soot") - Scene.v().loadNecessaryClasses() - 
logger.debug(s"Loaded ${Scene.v().getApplicationClasses.size()} classes") - - val global = globalFromAstCreation() - TypeNodePass - .withRegisteredTypes(global.usedTypes.keys().asScala.toList, cpg) - .createAndApply() - DeclarationRefPass(cpg).createAndApply() - new ConfigFileCreationPass(cpg).createAndApply() - } - - override def createCpg(config: Config): Try[Cpg] = - try { - withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => - File.temporaryDirectory("jimple2cpg-").apply { tmpDir => - cpgApplyPasses(cpg, config, tmpDir) - } - } - } finally { - G.reset() - } - - private def configureSoot(config: Config, outDir: File): Unit = { - // set application mode - Options.v().set_app(false) - Options.v().set_whole_program(false) - // keep debugging info - Options.v().set_keep_line_number(true) - Options.v().set_keep_offset(true) - // ignore library code - Options.v().set_no_bodies_for_excluded(true) - Options.v().set_allow_phantom_refs(true) - // keep variable names - Options.v().setPhaseOption("jb.sils", "enabled:false") - Options.v().setPhaseOption("jb", "use-original-names:true") - // output jimple - Options.v().set_output_format(Options.output_format_jimple) - Options.v().set_output_dir(outDir.canonicalPath) - - Options.v().set_dynamic_dir(config.dynamicDirs.asJava) - Options.v().set_dynamic_package(config.dynamicPkgs.asJava) - - if (config.fullResolver) { - // full transitive resolution of all references - Options.v().set_full_resolver(true) - } - } -} + classFiles + + /** Apply the soot passes + * @param tmpDir + * A temporary directory that will be used as the classpath for extracted class files + */ + private def cpgApplyPasses(cpg: Cpg, config: Config, tmpDir: File): Unit = + val input = File(config.inputPath) + configureSoot(config, tmpDir) + new MetaDataPass(cpg, language, config.inputPath).createAndApply() + + val globalFromAstCreation: () => Global = input.extension match + case Some(".apk" | ".dex") if input.isRegularFile => + 
sootLoadApk(input, config.android) + { () => + val astCreator = SootAstCreationPass(cpg, config) + astCreator.createAndApply() + astCreator.global + } + case _ => + val classFiles = sootLoad(input, tmpDir, config.recurse) + { () => + val astCreator = AstCreationPass(classFiles, cpg, config) + astCreator.createAndApply() + astCreator.global + } + logger.debug("Loading classes to soot") + Scene.v().loadNecessaryClasses() + logger.debug(s"Loaded ${Scene.v().getApplicationClasses.size()} classes") + + val global = globalFromAstCreation() + TypeNodePass + .withRegisteredTypes(global.usedTypes.keys().asScala.toList, cpg) + .createAndApply() + DeclarationRefPass(cpg).createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + end cpgApplyPasses + + override def createCpg(config: Config): Try[Cpg] = + try + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + File.temporaryDirectory("jimple2cpg-").apply { tmpDir => + cpgApplyPasses(cpg, config, tmpDir) + } + } + finally + G.reset() + + private def configureSoot(config: Config, outDir: File): Unit = + // set application mode + Options.v().set_app(false) + Options.v().set_whole_program(false) + // keep debugging info + Options.v().set_keep_line_number(true) + Options.v().set_keep_offset(true) + // ignore library code + Options.v().set_no_bodies_for_excluded(true) + Options.v().set_allow_phantom_refs(true) + // keep variable names + Options.v().setPhaseOption("jb.sils", "enabled:false") + Options.v().setPhaseOption("jb", "use-original-names:true") + // output jimple + Options.v().set_output_format(Options.output_format_jimple) + Options.v().set_output_dir(outDir.canonicalPath) + + Options.v().set_dynamic_dir(config.dynamicDirs.asJava) + Options.v().set_dynamic_package(config.dynamicPkgs.asJava) + + if config.fullResolver then + // full transitive resolution of all references + Options.v().set_full_resolver(true) + end configureSoot +end Jimple2Cpg diff --git 
a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala index 5568f61d..88bd9b2e 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala @@ -1,6 +1,6 @@ package io.appthreat.jimple2cpg -import Frontend._ +import Frontend.* import io.appthreat.x2cpg.{X2CpgConfig, X2CpgMain} import scopt.OParser @@ -12,64 +12,57 @@ final case class Config( dynamicPkgs: Seq[String] = Seq.empty, fullResolver: Boolean = false, recurse: Boolean = false -) extends X2CpgConfig[Config] { - def withAndroid(android: String): Config = { - copy(android = Some(android)).withInheritedFields(this) - } - def withDynamicDirs(value: Seq[String]): Config = { - copy(dynamicDirs = value).withInheritedFields(this) - } - def withDynamicPkgs(value: Seq[String]): Config = { - copy(dynamicPkgs = value).withInheritedFields(this) - } - def withFullResolver(value: Boolean): Config = { - copy(fullResolver = value).withInheritedFields(this) - } +) extends X2CpgConfig[Config]: + def withAndroid(android: String): Config = + copy(android = Some(android)).withInheritedFields(this) + def withDynamicDirs(value: Seq[String]): Config = + copy(dynamicDirs = value).withInheritedFields(this) + def withDynamicPkgs(value: Seq[String]): Config = + copy(dynamicPkgs = value).withInheritedFields(this) + def withFullResolver(value: Boolean): Config = + copy(fullResolver = value).withInheritedFields(this) - def withRecurse(value: Boolean): Config = { - copy(recurse = value) - } + def withRecurse(value: Boolean): Config = + copy(recurse = value) -} +private object Frontend: -private object Frontend { + implicit val defaultConfig: Config = Config() - implicit val defaultConfig: Config = Config() - - val cmdLineParser: OParser[Unit, Config] = { - val builder = OParser.builder[Config] - import builder._ - 
OParser.sequence( - programName("jimple2cpg"), - opt[String]("android") - .text("Optional path to android.jar while processing apk file.") - .action((android, config) => config.withAndroid(android)), - opt[Unit]("full-resolver") - .text("enables full transitive resolution of all references found in all classes that are resolved") - .action((_, config) => config.withFullResolver(true)), - opt[Unit]("recurse") - .text("recursively unpack jars") - .action((_, config) => config.withRecurse(true)), - opt[Seq[String]]("dynamic-dirs") - .valueName(",,...") - .text( - "Mark all class files in dirs as classes that may be loaded dynamically. Comma separated values for multiple directories." - ) - .action((dynamicDirs, config) => config.withDynamicDirs(dynamicDirs)), - opt[Seq[String]]("dynamic-pkgs") - .valueName(",,...") - .text( - "Marks all class files belonging to the package pkg or any of its subpackages as classes which the application may load dynamically. Comma separated values for multiple packages." + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("jimple2cpg"), + opt[String]("android") + .text("Optional path to android.jar while processing apk file.") + .action((android, config) => config.withAndroid(android)), + opt[Unit]("full-resolver") + .text( + "enables full transitive resolution of all references found in all classes that are resolved" + ) + .action((_, config) => config.withFullResolver(true)), + opt[Unit]("recurse") + .text("recursively unpack jars") + .action((_, config) => config.withRecurse(true)), + opt[Seq[String]]("dynamic-dirs") + .valueName(",,...") + .text( + "Mark all class files in dirs as classes that may be loaded dynamically. Comma separated values for multiple directories." 
+ ) + .action((dynamicDirs, config) => config.withDynamicDirs(dynamicDirs)), + opt[Seq[String]]("dynamic-pkgs") + .valueName(",,...") + .text( + "Marks all class files belonging to the package pkg or any of its subpackages as classes which the application may load dynamically. Comma separated values for multiple packages." + ) + .action((dynamicPkgs, config) => config.withDynamicPkgs(dynamicPkgs)) ) - .action((dynamicPkgs, config) => config.withDynamicPkgs(dynamicPkgs)) - ) - } -} + end cmdLineParser +end Frontend /** Entry point for command line CPG creator */ -object Main extends X2CpgMain(cmdLineParser, new Jimple2Cpg()) { - def run(config: Config, jimple2Cpg: Jimple2Cpg): Unit = { - jimple2Cpg.run(config) - } -} +object Main extends X2CpgMain(cmdLineParser, new Jimple2Cpg()): + def run(config: Config, jimple2Cpg: Jimple2Cpg): Unit = + jimple2Cpg.run(config) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala index f2035818..5ac2f11f 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala @@ -8,31 +8,31 @@ import io.shiftleft.passes.ConcurrentWriterCpgPass import org.slf4j.LoggerFactory import soot.Scene -/** Creates the AST layer from the given class file and stores all types in the given global parameter. +/** Creates the AST layer from the given class file and stores all types in the given global + * parameter. 
* @param classFiles * List of class files and their fully qualified class names * @param cpg * The CPG to add to */ class AstCreationPass(classFiles: List[ClassFile], cpg: Cpg, config: Config) - extends ConcurrentWriterCpgPass[ClassFile](cpg) { + extends ConcurrentWriterCpgPass[ClassFile](cpg): - val global: Global = new Global() - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + val global: Global = new Global() + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - override def generateParts(): Array[_ <: AnyRef] = classFiles.toArray + override def generateParts(): Array[? <: AnyRef] = classFiles.toArray - override def runOnPart(builder: DiffGraphBuilder, classFile: ClassFile): Unit = { - try { - val sootClass = Scene.v().loadClassAndSupport(classFile.fullyQualifiedClassName.get) - sootClass.setApplicationClass() - val localDiff = AstCreator(classFile.file.canonicalPath, sootClass, global)(config.schemaValidation).createAst() - builder.absorb(localDiff) - } catch { - case e: Exception => - logger.warn(s"Exception on AST creation for ${classFile.file.canonicalPath}", e) - Iterator() - } - } - -} + override def runOnPart(builder: DiffGraphBuilder, classFile: ClassFile): Unit = + try + val sootClass = Scene.v().loadClassAndSupport(classFile.fullyQualifiedClassName.get) + sootClass.setApplicationClass() + val localDiff = AstCreator(classFile.file.canonicalPath, sootClass, global)( + config.schemaValidation + ).createAst() + builder.absorb(localDiff) + catch + case e: Exception => + logger.warn(s"Exception on AST creation for ${classFile.file.canonicalPath}", e) + Iterator() +end AstCreationPass diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala index e380ca4f..f18cf4cb 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala +++ 
b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala @@ -20,1167 +20,1165 @@ import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters.CollectionHasAsScala import scala.util.{Failure, Success, Try} -class AstCreator(filename: String, cls: SootClass, global: Global)(implicit withSchemaValidation: ValidationMode) - extends AstCreatorBase(filename) { - - import AstCreator.* - - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - private val unitToAsts = mutable.HashMap[soot.Unit, Seq[Ast]]() - private val controlTargets = mutable.HashMap[Seq[Ast], soot.Unit]() - // There are many, but the popular ones should do https://en.wikipedia.org/wiki/List_of_JVM_languages - private val JVM_LANGS = HashSet("scala", "clojure", "groovy", "kotlin", "jython", "jruby") - - /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates TYPE nodes for each - * key in the map. - */ - private def registerType(typeName: String): String = { - global.usedTypes.put(typeName, true) - typeName - } - - /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a DiffGraph containing the - * corresponding CPG AST. - */ - def createAst(): DiffGraphBuilder = { - val astRoot = astForCompilationUnit(cls) - storeInDiffGraph(astRoot, diffGraph) - diffGraph - } - - /** Translate compilation unit into AST - */ - private def astForCompilationUnit(cls: SootClass): Ast = { - val ast = astForPackageDeclaration(cls.getPackageName) - val namespaceBlockFullName = - ast.root.collect { case x: NewNamespaceBlock => x.fullName }.getOrElse("none") - ast.withChild(astForTypeDecl(cls.getType, namespaceBlockFullName)) - } - - /** Translate package declaration into AST consisting of a corresponding namespace block. 
- */ - private def astForPackageDeclaration(packageDecl: String): Ast = { - val absolutePath = new java.io.File(filename).toPath.toAbsolutePath.normalize().toString - val name = packageDecl.split("\\.").lastOption.getOrElse("") - val namespaceBlock = NewNamespaceBlock() - .name(name) - .fullName(packageDecl) - Ast(namespaceBlock.filename(absolutePath).order(1)) - } - - /** Creates a list of all inherited classes and implemented interfaces. If there are none then a list with a single - * element 'java.lang.Object' is returned by default. Returns two lists in the form of (List[Super Classes], - * List[Interfaces]). - */ - private def inheritedAndImplementedClasses(clazz: SootClass): (List[String], List[String]) = { - val implementsTypeFullName = clazz.getInterfaces.asScala.map { (i: SootClass) => - registerType(i.getType.toQuotedString) - }.toList - val inheritsFromTypeFullName = - if (clazz.hasSuperclass && clazz.getSuperclass.getType.toQuotedString != "java.lang.Object") { - List(registerType(clazz.getSuperclass.getType.toQuotedString)) - } else if (implementsTypeFullName.isEmpty) { - List(registerType("java.lang.Object")) - } else List() - - (inheritsFromTypeFullName, implementsTypeFullName) - } - - /** Creates the AST root for type declarations and acts as the entry point for method generation. 
- */ - private def astForTypeDecl(typ: RefType, namespaceBlockFullName: String): Ast = { - val fullName = registerType(typ.toQuotedString) - val shortName = typ.getSootClass.getShortJavaStyleName - val clz = typ.getSootClass - val code = new mutable.StringBuilder() - - if (clz.isPublic) code.append("public ") - else if (clz.isPrivate) code.append("private ") - if (clz.isStatic) code.append("static ") - if (clz.isFinal) code.append("final ") - if (clz.isInterface) code.append("interface ") - else if (clz.isAbstract) code.append("abstract ") - if (clz.isEnum) code.append("enum ") - if (!clz.isInterface) code.append(s"class $shortName") - else code.append(shortName) - - val modifiers = astsForModifiers(clz) - val (inherited, implemented) = inheritedAndImplementedClasses(typ.getSootClass) - - if (inherited.nonEmpty) code.append(s" extends ${inherited.mkString(", ")}") - if (implemented.nonEmpty) code.append(s" implements ${implemented.mkString(", ")}") - - val typeDecl = NewTypeDecl() - .name(shortName) - .fullName(fullName) - .order(1) // Jimple always has 1 class per file - .filename(filename) - .code(code.toString()) - .inheritsFromTypeFullName(inherited ++ implemented) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(namespaceBlockFullName) - val methodAsts = withOrder(typ.getSootClass.getMethods.asScala.toList.sortWith((x, y) => x.getName > y.getName)) { - (m, order) => - astForMethod(m, typ, order) - } - - val memberAsts = typ.getSootClass.getFields.asScala - .filter(_.isDeclared) - .zipWithIndex - .map { case (v, i) => - astForField(v, i + methodAsts.size + 1) - } - .toList - - Ast(typeDecl) - .withChildren(astsForHostTags(clz)) - .withChildren(memberAsts) - .withChildren(methodAsts) - .withChildren(modifiers) - } - - private def astForField(field: SootField, order: Int): Ast = { - val typeFullName = registerType(field.getType.toQuotedString) - val name = field.getName - val code = if (field.getDeclaration.contains("enum")) name else 
s"$typeFullName $name" - val annotations = field.getTags.asScala - .collect { case x: VisibilityAnnotationTag => x } - .flatMap(_.getAnnotations.asScala) - - Ast( - NewMember() - .name(name) - .lineNumber(line(field)) - .columnNumber(column(field)) - .typeFullName(typeFullName) - .order(order) - .code(code) - ).withChildren(withOrder(annotations) { (a, aOrder) => astsForAnnotations(a, aOrder, field) }) - } - - private def astForMethod(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int): Ast = { - val methodNode = createMethodNode(methodDeclaration, typeDecl, childNum) - try { - if (!methodDeclaration.isConcrete) { - // Soot is not able to parse origin parameter names of abstract methods - // https://github.com/soot-oss/soot/issues/1517 - val locals = methodDeclaration.getParameterTypes.asScala.zipWithIndex - .map { case (typ, index) => new JimpleLocal(s"param${index + 1}", typ) } - val parameterAsts = - Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ - withOrder(locals) { (p, order) => astForParameter(p, order, methodDeclaration, Map()) } - Ast(methodNode) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(parameterAsts) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(Ast(NewBlock())) - .withChild(astForMethodReturn(methodDeclaration)) - } else { - val lastOrder = 2 + methodDeclaration.getParameterCount - // Map params to their annotations - val mTags = methodDeclaration.getTags.asScala - val paramAnnos = - mTags.collect { case x: VisibilityParameterAnnotationTag => x }.flatMap(_.getVisibilityAnnotations.asScala) - val paramNames = mTags.collect { case x: ParamNamesTag => x }.flatMap(_.getNames.asScala) - val parameterAnnotations = paramNames.zip(paramAnnos).filter(_._2 != null).toMap - val methodBody = Try(methodDeclaration.getActiveBody) match { - case Failure(_) => methodDeclaration.retrieveActiveBody() - case Success(body) => body +class AstCreator(filename: String, cls: SootClass, global: 
Global)(implicit + withSchemaValidation: ValidationMode +) extends AstCreatorBase(filename): + + import AstCreator.* + + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + private val unitToAsts = mutable.HashMap[soot.Unit, Seq[Ast]]() + private val controlTargets = mutable.HashMap[Seq[Ast], soot.Unit]() + // There are many, but the popular ones should do https://en.wikipedia.org/wiki/List_of_JVM_languages + private val JVM_LANGS = HashSet("scala", "clojure", "groovy", "kotlin", "jython", "jruby") + + /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates + * TYPE nodes for each key in the map. + */ + private def registerType(typeName: String): String = + global.usedTypes.put(typeName, true) + typeName + + /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a + * DiffGraph containing the corresponding CPG AST. + */ + def createAst(): DiffGraphBuilder = + val astRoot = astForCompilationUnit(cls) + storeInDiffGraph(astRoot, diffGraph) + diffGraph + + /** Translate compilation unit into AST + */ + private def astForCompilationUnit(cls: SootClass): Ast = + val ast = astForPackageDeclaration(cls.getPackageName) + val namespaceBlockFullName = + ast.root.collect { case x: NewNamespaceBlock => x.fullName }.getOrElse("none") + ast.withChild(astForTypeDecl(cls.getType, namespaceBlockFullName)) + + /** Translate package declaration into AST consisting of a corresponding namespace block. + */ + private def astForPackageDeclaration(packageDecl: String): Ast = + val absolutePath = new java.io.File(filename).toPath.toAbsolutePath.normalize().toString + val name = packageDecl.split("\\.").lastOption.getOrElse("") + val namespaceBlock = NewNamespaceBlock() + .name(name) + .fullName(packageDecl) + Ast(namespaceBlock.filename(absolutePath).order(1)) + + /** Creates a list of all inherited classes and implemented interfaces. 
If there are none then a + * list with a single element 'java.lang.Object' is returned by default. Returns two lists in + * the form of (List[Super Classes], List[Interfaces]). + */ + private def inheritedAndImplementedClasses(clazz: SootClass): (List[String], List[String]) = + val implementsTypeFullName = clazz.getInterfaces.asScala.map { (i: SootClass) => + registerType(i.getType.toQuotedString) + }.toList + val inheritsFromTypeFullName = + if clazz.hasSuperclass && clazz.getSuperclass.getType.toQuotedString != "java.lang.Object" + then + List(registerType(clazz.getSuperclass.getType.toQuotedString)) + else if implementsTypeFullName.isEmpty then + List(registerType("java.lang.Object")) + else List() + + (inheritsFromTypeFullName, implementsTypeFullName) + + /** Creates the AST root for type declarations and acts as the entry point for method + * generation. + */ + private def astForTypeDecl(typ: RefType, namespaceBlockFullName: String): Ast = + val fullName = registerType(typ.toQuotedString) + val shortName = typ.getSootClass.getShortJavaStyleName + val clz = typ.getSootClass + val code = new mutable.StringBuilder() + + if clz.isPublic then code.append("public ") + else if clz.isPrivate then code.append("private ") + if clz.isStatic then code.append("static ") + if clz.isFinal then code.append("final ") + if clz.isInterface then code.append("interface ") + else if clz.isAbstract then code.append("abstract ") + if clz.isEnum then code.append("enum ") + if !clz.isInterface then code.append(s"class $shortName") + else code.append(shortName) + + val modifiers = astsForModifiers(clz) + val (inherited, implemented) = inheritedAndImplementedClasses(typ.getSootClass) + + if inherited.nonEmpty then code.append(s" extends ${inherited.mkString(", ")}") + if implemented.nonEmpty then code.append(s" implements ${implemented.mkString(", ")}") + + val typeDecl = NewTypeDecl() + .name(shortName) + .fullName(fullName) + .order(1) // Jimple always has 1 class per file + 
.filename(filename) + .code(code.toString()) + .inheritsFromTypeFullName(inherited ++ implemented) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(namespaceBlockFullName) + val methodAsts = withOrder(typ.getSootClass.getMethods.asScala.toList.sortWith((x, y) => + x.getName > y.getName + )) { + (m, order) => + astForMethod(m, typ, order) } - val parameterAsts = - Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ withOrder(methodBody.getParameterLocals) { - (p, order) => astForParameter(p, order, methodDeclaration, parameterAnnotations) - } + + val memberAsts = typ.getSootClass.getFields.asScala + .filter(_.isDeclared) + .zipWithIndex + .map { case (v, i) => + astForField(v, i + methodAsts.size + 1) + } + .toList + + Ast(typeDecl) + .withChildren(astsForHostTags(clz)) + .withChildren(memberAsts) + .withChildren(methodAsts) + .withChildren(modifiers) + end astForTypeDecl + + private def astForField(field: SootField, order: Int): Ast = + val typeFullName = registerType(field.getType.toQuotedString) + val name = field.getName + val code = if field.getDeclaration.contains("enum") then name else s"$typeFullName $name" + val annotations = field.getTags.asScala + .collect { case x: VisibilityAnnotationTag => x } + .flatMap(_.getAnnotations.asScala) + Ast( - methodNode - .lineNumberEnd(methodBody.toString.split('\n').filterNot(_.isBlank).length) - .code(methodBody.toString) + NewMember() + .name(name) + .lineNumber(line(field)) + .columnNumber(column(field)) + .typeFullName(typeFullName) + .order(order) + .code(code) + ).withChildren(withOrder(annotations) { (a, aOrder) => + astsForAnnotations(a, aOrder, field) + }) + end astForField + + private def astForMethod(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int): Ast = + val methodNode = createMethodNode(methodDeclaration, typeDecl, childNum) + try + if !methodDeclaration.isConcrete then + // Soot is not able to parse origin parameter names of abstract methods + // 
https://github.com/soot-oss/soot/issues/1517 + val locals = methodDeclaration.getParameterTypes.asScala.zipWithIndex + .map { case (typ, index) => new JimpleLocal(s"param${index + 1}", typ) } + val parameterAsts = + Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ + withOrder(locals) { (p, order) => + astForParameter(p, order, methodDeclaration, Map()) + } + Ast(methodNode) + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(parameterAsts) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(Ast(NewBlock())) + .withChild(astForMethodReturn(methodDeclaration)) + else + val lastOrder = 2 + methodDeclaration.getParameterCount + // Map params to their annotations + val mTags = methodDeclaration.getTags.asScala + val paramAnnos = + mTags.collect { case x: VisibilityParameterAnnotationTag => x }.flatMap( + _.getVisibilityAnnotations.asScala + ) + val paramNames = + mTags.collect { case x: ParamNamesTag => x }.flatMap(_.getNames.asScala) + val parameterAnnotations = paramNames.zip(paramAnnos).filter(_._2 != null).toMap + val methodBody = Try(methodDeclaration.getActiveBody) match + case Failure(_) => methodDeclaration.retrieveActiveBody() + case Success(body) => body + val parameterAsts = + Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ withOrder( + methodBody.getParameterLocals + ) { + (p, order) => + astForParameter(p, order, methodDeclaration, parameterAnnotations) + } + Ast( + methodNode + .lineNumberEnd(methodBody.toString.split('\n').filterNot(_.isBlank).length) + .code(methodBody.toString) + ) + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(parameterAsts) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(astForMethodBody(methodBody, lastOrder)) + .withChild(astForMethodReturn(methodDeclaration)) + catch + case e: RuntimeException => + // Use a few heuristics to determine if this is not built with the JDK + val nonJavaLibs = + 
cls.getInterfaces.asScala.map(_.getPackageName).filter(JVM_LANGS.contains).toSet + if nonJavaLibs.nonEmpty || cls.getMethods.asScala.exists(_.getName.endsWith("$")) + then + val errMsg = + "The bytecode for this method suggests it is built with a non-Java JVM language. " + + "Soot requires including the specific language's SDK in the analysis to create the method body for " + + s"'${methodNode.fullName}' correctly." + logger.warn( + if nonJavaLibs.nonEmpty then + s"$errMsg. Language(s) detected: ${nonJavaLibs.mkString(",")}." + else errMsg + ) + else + logger.warn( + s"Unexpected runtime exception while parsing method body! Will stub the method '${methodNode.fullName}''", + e + ) + Ast(methodNode) + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(astForMethodReturn(methodDeclaration)) + finally + // Join all targets with CFG edges - this seems to work from what is seen on DotFiles + controlTargets.foreach({ case (asts, units) => + asts.headOption match + case Some(value) => + diffGraph.addEdge( + value.root.get, + unitToAsts(units).last.root.get, + EdgeTypes.CFG + ) + case None => + }) + // Clear these maps + controlTargets.clear() + unitToAsts.clear() + end try + end astForMethod + + private def getEvaluationStrategy(typ: soot.Type): String = + typ match + case _: PrimType => EvaluationStrategies.BY_VALUE + case _: VoidType => EvaluationStrategies.BY_VALUE + case _: NullType => EvaluationStrategies.BY_VALUE + case _: RefLikeType => EvaluationStrategies.BY_REFERENCE + case _ => EvaluationStrategies.BY_SHARING + + private def astForParameter( + parameter: soot.Local, + childNum: Int, + methodDeclaration: SootMethod, + parameterAnnotations: Map[String, VisibilityAnnotationTag] + ): Ast = + val typeFullName = registerType(parameter.getType.toQuotedString) + + val parameterNode = Ast( + NewMethodParameterIn() + .name(parameter.getName) + .code(s"$typeFullName ${parameter.getName}") + 
.typeFullName(typeFullName) + .order(childNum) + .index(childNum) + .lineNumber(line(methodDeclaration)) + .columnNumber(column(methodDeclaration)) + .evaluationStrategy(getEvaluationStrategy(parameter.getType)) ) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(parameterAsts) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(astForMethodBody(methodBody, lastOrder)) - .withChild(astForMethodReturn(methodDeclaration)) - } - } catch { - case e: RuntimeException => - // Use a few heuristics to determine if this is not built with the JDK - val nonJavaLibs = cls.getInterfaces.asScala.map(_.getPackageName).filter(JVM_LANGS.contains).toSet - if (nonJavaLibs.nonEmpty || cls.getMethods.asScala.exists(_.getName.endsWith("$"))) { - val errMsg = "The bytecode for this method suggests it is built with a non-Java JVM language. " + - "Soot requires including the specific language's SDK in the analysis to create the method body for " + - s"'${methodNode.fullName}' correctly." - logger.warn( - if (nonJavaLibs.nonEmpty) s"$errMsg. Language(s) detected: ${nonJavaLibs.mkString(",")}." - else errMsg - ) - } else { - logger.warn( - s"Unexpected runtime exception while parsing method body! 
Will stub the method '${methodNode.fullName}''", - e - ) + + parameterAnnotations.get(parameter.getName) match + case Some(annoRoot) => + parameterNode.withChildren(withOrder(annoRoot.getAnnotations.asScala) { + (a, order) => + astsForAnnotations(a, order, methodDeclaration) + }) + case None => parameterNode + end astForParameter + + private def astsForHostTags(host: AbstractHost): Seq[Ast] = + host.getTags.asScala + .collect { case x: VisibilityAnnotationTag => x } + .flatMap { x => + withOrder(x.getAnnotations.asScala) { (a, order) => + astsForAnnotations(a, order, host) + } + } + .toSeq + + private def astsForAnnotations(annotation: AnnotationTag, order: Int, host: AbstractHost): Ast = + val annoType = registerType(annotation.getType.parseAsJavaType) + val name = annoType.split('.').last + val elementNodes = withOrder(annotation.getElems.asScala) { case (a, order) => + astForAnnotationElement(a, order, host) } - Ast(methodNode) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(astForMethodReturn(methodDeclaration)) - } finally { - // Join all targets with CFG edges - this seems to work from what is seen on DotFiles - controlTargets.foreach({ case (asts, units) => - asts.headOption match { - case Some(value) => - diffGraph.addEdge(value.root.get, unitToAsts(units).last.root.get, EdgeTypes.CFG) - case None => + val annotationNode = NewAnnotation() + .name(name) + .code( + s"@$name(${elementNodes.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")})" + ) + .fullName(annoType) + .order(order) + Ast(annotationNode) + .withChildren(elementNodes) + + private def astForAnnotationElement( + annoElement: AnnotationElem, + order: Int, + parent: AbstractHost + ): Ast = + def getLiteralElementNameAndCode(annoElement: AnnotationElem): (String, String) = + annoElement match + case x: AnnotationClassElem => + val desc = registerType(x.getDesc.parseAsJavaType) + (desc, desc) + case x: 
AnnotationBooleanElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationDoubleElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationEnumElem => (x.getConstantName, x.getConstantName) + case x: AnnotationFloatElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationIntElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationLongElem => (x.getValue.toString, x.getValue.toString) + case _ => ("", "") + val lineNo = line(parent) + val columnNo = column(parent) + val codeBuilder = new mutable.StringBuilder() + val astChildren = ListBuffer.empty[Ast] + if annoElement.getName != null then + astChildren.append( + Ast( + NewAnnotationParameter() + .code(annoElement.getName) + .lineNumber(lineNo) + .columnNumber(columnNo) + .order(1) + ) + ) + codeBuilder.append(s"${annoElement.getName} = ") + astChildren.append(annoElement match + case x: AnnotationAnnotationElem => + val rhsAst = astsForAnnotations(x.getValue, astChildren.size + 1, parent) + codeBuilder.append( + s"${rhsAst.root.flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}" + ) + rhsAst + case x: AnnotationArrayElem => + val (rhsAst, code) = astForAnnotationArrayElement(x, astChildren.size + 1, parent) + codeBuilder.append(code) + rhsAst + case x => + val (name, code) = x match + case y: AnnotationStringElem => (y.getValue, s"\"${y.getValue}\"") + case _ => getLiteralElementNameAndCode(x) + val rhsOrder = if annoElement.getName == null then order else astChildren.size + 1 + codeBuilder.append(code) + Ast(NewAnnotationLiteral().name(name).code(code).order(rhsOrder).argumentIndex( + rhsOrder + )) + ) + + if astChildren.size == 1 then + astChildren.head + else + val paramAssign = NewAnnotationParameterAssign() + .code(codeBuilder.toString) + .lineNumber(lineNo) + .columnNumber(columnNo) + .order(order) + + Ast(paramAssign) + .withChildren(astChildren) + end astForAnnotationElement + + private def astForAnnotationArrayElement( + x: 
AnnotationArrayElem, + order: Int, + parent: AbstractHost + ): (Ast, String) = + val elems = withOrder(x.getValues.asScala) { (elem, order) => + astForAnnotationElement(elem, order, parent) } - }) - // Clear these maps - controlTargets.clear() - unitToAsts.clear() - } - } - - private def getEvaluationStrategy(typ: soot.Type): String = - typ match { - case _: PrimType => EvaluationStrategies.BY_VALUE - case _: VoidType => EvaluationStrategies.BY_VALUE - case _: NullType => EvaluationStrategies.BY_VALUE - case _: RefLikeType => EvaluationStrategies.BY_REFERENCE - case _ => EvaluationStrategies.BY_SHARING - } - - private def astForParameter( - parameter: soot.Local, - childNum: Int, - methodDeclaration: SootMethod, - parameterAnnotations: Map[String, VisibilityAnnotationTag] - ): Ast = { - val typeFullName = registerType(parameter.getType.toQuotedString) - - val parameterNode = Ast( - NewMethodParameterIn() - .name(parameter.getName) - .code(s"$typeFullName ${parameter.getName}") - .typeFullName(typeFullName) - .order(childNum) - .index(childNum) - .lineNumber(line(methodDeclaration)) - .columnNumber(column(methodDeclaration)) - .evaluationStrategy(getEvaluationStrategy(parameter.getType)) - ) - - parameterAnnotations.get(parameter.getName) match { - case Some(annoRoot) => - parameterNode.withChildren(withOrder(annoRoot.getAnnotations.asScala) { (a, order) => - astsForAnnotations(a, order, methodDeclaration) - }) - case None => parameterNode - } - } - - private def astsForHostTags(host: AbstractHost): Seq[Ast] = { - host.getTags.asScala - .collect { case x: VisibilityAnnotationTag => x } - .flatMap { x => - withOrder(x.getAnnotations.asScala) { (a, order) => astsForAnnotations(a, order, host) } - } - .toSeq - } - - private def astsForAnnotations(annotation: AnnotationTag, order: Int, host: AbstractHost): Ast = { - val annoType = registerType(annotation.getType.parseAsJavaType) - val name = annoType.split('.').last - val elementNodes = 
withOrder(annotation.getElems.asScala) { case (a, order) => - astForAnnotationElement(a, order, host) - } - val annotationNode = NewAnnotation() - .name(name) - .code(s"@$name(${elementNodes.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")})") - .fullName(annoType) - .order(order) - Ast(annotationNode) - .withChildren(elementNodes) - } - - private def astForAnnotationElement(annoElement: AnnotationElem, order: Int, parent: AbstractHost): Ast = { - def getLiteralElementNameAndCode(annoElement: AnnotationElem): (String, String) = annoElement match { - case x: AnnotationClassElem => - val desc = registerType(x.getDesc.parseAsJavaType) - (desc, desc) - case x: AnnotationBooleanElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationDoubleElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationEnumElem => (x.getConstantName, x.getConstantName) - case x: AnnotationFloatElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationIntElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationLongElem => (x.getValue.toString, x.getValue.toString) - case _ => ("", "") - } - val lineNo = line(parent) - val columnNo = column(parent) - val codeBuilder = new mutable.StringBuilder() - val astChildren = ListBuffer.empty[Ast] - if (annoElement.getName != null) { - astChildren.append( + val code = + s"{${elems.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}}" + val array = NewArrayInitializer().code(code).order(order).argumentIndex(order) + (Ast(array).withChildren(elems), code) + + private def astForMethodBody(body: Body, order: Int): Ast = + val block = NewBlock().order(order).lineNumber(line(body)).columnNumber(column(body)) + val jimpleParams = body.getParameterLocals.asScala.toList + // Don't let parameters also become locals (avoiding duplication) + val jimpleLocals = body.getLocals.asScala.filterNot(l => + jimpleParams.contains(l) || l.getName == "this" + ).toList + 
val locals = withOrder(jimpleLocals) { case (l, order) => + val name = l.getName + val typeFullName = registerType(l.getType.toQuotedString) + val code = s"$typeFullName $name" + Ast(NewLocal().name(name).code(code).typeFullName(typeFullName).order(order)) + } + Ast(block) + .withChildren(locals) + .withChildren(withOrder(body.getUnits.asScala.filterNot(isIgnoredUnit)) { (x, order) => + astsForStatement(x, order + locals.size) + }.flatten) + + private def isIgnoredUnit(unit: soot.Unit): Boolean = + unit match + case _: IdentityStmt => true + case _: NopStmt => true + case _ => false + + private def astsForStatement(statement: soot.Unit, order: Int): Seq[Ast] = + val stmt = statement match + case x: AssignStmt => astsForDefinition(x, order) + case x: InvokeStmt => astsForExpression(x.getInvokeExpr, order, statement) + case x: ReturnStmt => astsForReturnNode(x, order) + case x: ReturnVoidStmt => astsForReturnVoidNode(x, order) + case x: IfStmt => astsForIfStmt(x, order) + case x: GotoStmt => astsForGotoStmt(x, order) + case x: LookupSwitchStmt => astsForLookupSwitchStmt(x, order) + case x: TableSwitchStmt => astsForTableSwitchStmt(x, order) + case x: ThrowStmt => astsForThrowStmt(x, order) + case x: MonitorStmt => astsForMonitorStmt(x, order) + case _: IdentityStmt => Seq() // Identity statements redefine parameters as locals + case _: NopStmt => Seq() // Ignore NOP statements + case x => + logger.warn(s"Unhandled soot.Unit type ${x.getClass}") + Seq(astForUnknownStmt(x, None, order)) + unitToAsts.put(statement, stmt) + stmt + end astsForStatement + + private def astForBinOpExpr(binOp: BinopExpr, order: Int, parentUnit: soot.Unit): Ast = + // https://javadoc.io/static/org.soot-oss/soot/4.3.0/soot/jimple/BinopExpr.html + val operatorName = binOp match + case _: AddExpr => Operators.addition + case _: SubExpr => Operators.subtraction + case _: MulExpr => Operators.multiplication + case _: DivExpr => Operators.division + case _: RemExpr => Operators.modulo + case _: 
GeExpr => Operators.greaterEqualsThan + case _: GtExpr => Operators.greaterThan + case _: LeExpr => Operators.lessEqualsThan + case _: LtExpr => Operators.lessThan + case _: ShlExpr => Operators.shiftLeft + case _: ShrExpr => Operators.logicalShiftRight + case _: UshrExpr => Operators.arithmeticShiftRight + case _: CmpExpr => Operators.compare + case _: CmpgExpr => Operators.compare + case _: CmplExpr => Operators.compare + case _: AndExpr => Operators.and + case _: OrExpr => Operators.or + case _: XorExpr => Operators.xor + case _: EqExpr => Operators.equals + case _: NeExpr => Operators.notEquals + case _ => + logger.warn( + s"Unhandled binary operator ${binOp.getSymbol} (${binOp.getClass}). This is unexpected behaviour." + ) + ".unknown" + + val callNode = NewCall() + .name(operatorName) + .methodFullName(operatorName) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(binOp.toString) + .argumentIndex(order) + .order(order) + + val args = + astsForValue(binOp.getOp1, 1, parentUnit) ++ astsForValue(binOp.getOp2, 2, parentUnit) + callAst(callNode, args) + end astForBinOpExpr + + private def astsForExpression(expr: Expr, order: Int, parentUnit: soot.Unit): Seq[Ast] = + expr match + case x: BinopExpr => Seq(astForBinOpExpr(x, order, parentUnit)) + case x: InvokeExpr => Seq(astForInvokeExpr(x, order, parentUnit)) + case x: AnyNewExpr => Seq(astForNewExpr(x, order, parentUnit)) + case x: CastExpr => Seq(astForUnaryExpr(Operators.cast, x, x.getOp, order, parentUnit)) + case x: InstanceOfExpr => + Seq(astForUnaryExpr(Operators.instanceOf, x, x.getOp, order, parentUnit)) + case x: LengthExpr => + Seq(astForUnaryExpr(Operators.lengthOf, x, x.getOp, order, parentUnit)) + case x: NegExpr => Seq(astForUnaryExpr(Operators.minus, x, x.getOp, order, parentUnit)) + case x => + logger.warn(s"Unhandled soot.Expr type ${x.getClass}") + Seq() + + private def astsForValue(value: soot.Value, order: Int, parentUnit: soot.Unit): Seq[Ast] = + value match + case x: Expr => 
astsForExpression(x, order, parentUnit) + case x: soot.Local => Seq(astForLocal(x, order, parentUnit)) + case x: CaughtExceptionRef => Seq(astForCaughtExceptionRef(x, order, parentUnit)) + case x: Constant => Seq(astForConstantExpr(x, order)) + case x: FieldRef => Seq(astForFieldRef(x, order, parentUnit)) + case x: ThisRef => Seq(createThisNode(x)) + case x: ParameterRef => Seq(createParameterNode(x, order)) + case x: IdentityRef => Seq(astForIdentityRef(x, order, parentUnit)) + case x: ArrayRef => Seq(astForArrayRef(x, order, parentUnit)) + case x => + logger.warn(s"Unhandled soot.Value type ${x.getClass}") + Seq() + + private def astForArrayRef(arrRef: ArrayRef, order: Int, parentUnit: soot.Unit): Ast = + val indexAccess = NewCall() + .name(Operators.indexAccess) + .methodFullName(Operators.indexAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(arrRef.toString()) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .typeFullName(registerType(arrRef.getType.toQuotedString)) + + val astChildren = astsForValue(arrRef.getBase, 1, parentUnit) ++ astsForValue( + arrRef.getIndex, + 2, + parentUnit + ) + Ast(indexAccess) + .withChildren(astChildren) + .withArgEdges(indexAccess, astChildren.flatMap(_.root)) + end astForArrayRef + + private def astForLocal(local: soot.Local, order: Int, parentUnit: soot.Unit): Ast = + val name = local.getName + val typeFullName = registerType(local.getType.toQuotedString) Ast( - NewAnnotationParameter() - .code(annoElement.getName) - .lineNumber(lineNo) - .columnNumber(columnNo) - .order(1) + NewIdentifier() + .name(name) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .order(order) + .argumentIndex(order) + .code(name) + .typeFullName(typeFullName) ) - ) - codeBuilder.append(s"${annoElement.getName} = ") - } - astChildren.append(annoElement match { - case x: AnnotationAnnotationElem => - val rhsAst = astsForAnnotations(x.getValue, astChildren.size 
+ 1, parent) - codeBuilder.append(s"${rhsAst.root.flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}") - rhsAst - case x: AnnotationArrayElem => - val (rhsAst, code) = astForAnnotationArrayElement(x, astChildren.size + 1, parent) - codeBuilder.append(code) - rhsAst - case x => - val (name, code) = x match { - case y: AnnotationStringElem => (y.getValue, s"\"${y.getValue}\"") - case _ => getLiteralElementNameAndCode(x) - } - val rhsOrder = if (annoElement.getName == null) order else astChildren.size + 1 - codeBuilder.append(code) - Ast(NewAnnotationLiteral().name(name).code(code).order(rhsOrder).argumentIndex(rhsOrder)) - }) - - if (astChildren.size == 1) { - astChildren.head - } else { - val paramAssign = NewAnnotationParameterAssign() - .code(codeBuilder.toString) - .lineNumber(lineNo) - .columnNumber(columnNo) - .order(order) - - Ast(paramAssign) - .withChildren(astChildren) - } - } - - private def astForAnnotationArrayElement(x: AnnotationArrayElem, order: Int, parent: AbstractHost): (Ast, String) = { - val elems = withOrder(x.getValues.asScala) { (elem, order) => astForAnnotationElement(elem, order, parent) } - val code = s"{${elems.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}}" - val array = NewArrayInitializer().code(code).order(order).argumentIndex(order) - (Ast(array).withChildren(elems), code) - } - - private def astForMethodBody(body: Body, order: Int): Ast = { - val block = NewBlock().order(order).lineNumber(line(body)).columnNumber(column(body)) - val jimpleParams = body.getParameterLocals.asScala.toList - // Don't let parameters also become locals (avoiding duplication) - val jimpleLocals = body.getLocals.asScala.filterNot(l => jimpleParams.contains(l) || l.getName == "this").toList - val locals = withOrder(jimpleLocals) { case (l, order) => - val name = l.getName - val typeFullName = registerType(l.getType.toQuotedString) - val code = s"$typeFullName $name" - 
Ast(NewLocal().name(name).code(code).typeFullName(typeFullName).order(order)) - } - Ast(block) - .withChildren(locals) - .withChildren(withOrder(body.getUnits.asScala.filterNot(isIgnoredUnit)) { (x, order) => - astsForStatement(x, order + locals.size) - }.flatten) - } - - private def isIgnoredUnit(unit: soot.Unit): Boolean = { - unit match { - case _: IdentityStmt => true - case _: NopStmt => true - case _ => false - } - } - - private def astsForStatement(statement: soot.Unit, order: Int): Seq[Ast] = { - val stmt = statement match { - case x: AssignStmt => astsForDefinition(x, order) - case x: InvokeStmt => astsForExpression(x.getInvokeExpr, order, statement) - case x: ReturnStmt => astsForReturnNode(x, order) - case x: ReturnVoidStmt => astsForReturnVoidNode(x, order) - case x: IfStmt => astsForIfStmt(x, order) - case x: GotoStmt => astsForGotoStmt(x, order) - case x: LookupSwitchStmt => astsForLookupSwitchStmt(x, order) - case x: TableSwitchStmt => astsForTableSwitchStmt(x, order) - case x: ThrowStmt => astsForThrowStmt(x, order) - case x: MonitorStmt => astsForMonitorStmt(x, order) - case _: IdentityStmt => Seq() // Identity statements redefine parameters as locals - case _: NopStmt => Seq() // Ignore NOP statements - case x => - logger.warn(s"Unhandled soot.Unit type ${x.getClass}") - Seq(astForUnknownStmt(x, None, order)) - } - unitToAsts.put(statement, stmt) - stmt - } - - private def astForBinOpExpr(binOp: BinopExpr, order: Int, parentUnit: soot.Unit): Ast = { - // https://javadoc.io/static/org.soot-oss/soot/4.3.0/soot/jimple/BinopExpr.html - val operatorName = binOp match { - case _: AddExpr => Operators.addition - case _: SubExpr => Operators.subtraction - case _: MulExpr => Operators.multiplication - case _: DivExpr => Operators.division - case _: RemExpr => Operators.modulo - case _: GeExpr => Operators.greaterEqualsThan - case _: GtExpr => Operators.greaterThan - case _: LeExpr => Operators.lessEqualsThan - case _: LtExpr => Operators.lessThan - case _: 
ShlExpr => Operators.shiftLeft - case _: ShrExpr => Operators.logicalShiftRight - case _: UshrExpr => Operators.arithmeticShiftRight - case _: CmpExpr => Operators.compare - case _: CmpgExpr => Operators.compare - case _: CmplExpr => Operators.compare - case _: AndExpr => Operators.and - case _: OrExpr => Operators.or - case _: XorExpr => Operators.xor - case _: EqExpr => Operators.equals - case _: NeExpr => Operators.notEquals - case _ => - logger.warn(s"Unhandled binary operator ${binOp.getSymbol} (${binOp.getClass}). This is unexpected behaviour.") - ".unknown" - } - - val callNode = NewCall() - .name(operatorName) - .methodFullName(operatorName) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(binOp.toString) - .argumentIndex(order) - .order(order) - - val args = - astsForValue(binOp.getOp1, 1, parentUnit) ++ astsForValue(binOp.getOp2, 2, parentUnit) - callAst(callNode, args) - } - - private def astsForExpression(expr: Expr, order: Int, parentUnit: soot.Unit): Seq[Ast] = { - expr match { - case x: BinopExpr => Seq(astForBinOpExpr(x, order, parentUnit)) - case x: InvokeExpr => Seq(astForInvokeExpr(x, order, parentUnit)) - case x: AnyNewExpr => Seq(astForNewExpr(x, order, parentUnit)) - case x: CastExpr => Seq(astForUnaryExpr(Operators.cast, x, x.getOp, order, parentUnit)) - case x: InstanceOfExpr => - Seq(astForUnaryExpr(Operators.instanceOf, x, x.getOp, order, parentUnit)) - case x: LengthExpr => - Seq(astForUnaryExpr(Operators.lengthOf, x, x.getOp, order, parentUnit)) - case x: NegExpr => Seq(astForUnaryExpr(Operators.minus, x, x.getOp, order, parentUnit)) - case x => - logger.warn(s"Unhandled soot.Expr type ${x.getClass}") - Seq() - } - } - - private def astsForValue(value: soot.Value, order: Int, parentUnit: soot.Unit): Seq[Ast] = { - value match { - case x: Expr => astsForExpression(x, order, parentUnit) - case x: soot.Local => Seq(astForLocal(x, order, parentUnit)) - case x: CaughtExceptionRef => Seq(astForCaughtExceptionRef(x, order, parentUnit)) - 
case x: Constant => Seq(astForConstantExpr(x, order)) - case x: FieldRef => Seq(astForFieldRef(x, order, parentUnit)) - case x: ThisRef => Seq(createThisNode(x)) - case x: ParameterRef => Seq(createParameterNode(x, order)) - case x: IdentityRef => Seq(astForIdentityRef(x, order, parentUnit)) - case x: ArrayRef => Seq(astForArrayRef(x, order, parentUnit)) - case x => - logger.warn(s"Unhandled soot.Value type ${x.getClass}") - Seq() - } - } - - private def astForArrayRef(arrRef: ArrayRef, order: Int, parentUnit: soot.Unit): Ast = { - val indexAccess = NewCall() - .name(Operators.indexAccess) - .methodFullName(Operators.indexAccess) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(arrRef.toString()) - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .typeFullName(registerType(arrRef.getType.toQuotedString)) - - val astChildren = astsForValue(arrRef.getBase, 1, parentUnit) ++ astsForValue(arrRef.getIndex, 2, parentUnit) - Ast(indexAccess) - .withChildren(astChildren) - .withArgEdges(indexAccess, astChildren.flatMap(_.root)) - } - - private def astForLocal(local: soot.Local, order: Int, parentUnit: soot.Unit): Ast = { - val name = local.getName - val typeFullName = registerType(local.getType.toQuotedString) - Ast( - NewIdentifier() - .name(name) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .order(order) - .argumentIndex(order) - .code(name) - .typeFullName(typeFullName) - ) - } - - private def astForIdentityRef(x: IdentityRef, order: Int, parentUnit: soot.Unit): Ast = { - Ast( - NewIdentifier() - .code(x.toString()) - .name(x.toString()) - .order(order) - .argumentIndex(order) - .typeFullName(registerType(x.getType.toQuotedString)) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - ) - } - - private def astForInvokeExpr(invokeExpr: InvokeExpr, order: Int, parentUnit: soot.Unit): Ast = { - val callee = invokeExpr.getMethodRef - val dispatchType = invokeExpr match 
{ - case _ if callee.isConstructor => DispatchTypes.STATIC_DISPATCH - case _: DynamicInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH - case _: InstanceInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH - case _ => DispatchTypes.STATIC_DISPATCH - } - - val signature = - s"${registerType(callee.getReturnType.toQuotedString)}(${(for (i <- 0 until callee.getParameterTypes.size()) - yield registerType(callee.getParameterType(i).toQuotedString)).mkString(",")})" - val thisAsts = invokeExpr match { - case expr: InstanceInvokeExpr => astsForValue(expr.getBase, 0, parentUnit) - case _ => Seq(createThisNode(callee, NewIdentifier())) - } - - val methodName = - if (callee.isConstructor) - registerType(callee.getDeclaringClass.getType.getClassName) - else - callee.getName - - val calleeType = registerType(callee.getDeclaringClass.getType.toQuotedString) - val callType = - if (callee.isConstructor) "void" - else calleeType - - val code = invokeExpr match { - case expr: InstanceInvokeExpr => - s"${expr.getBase}.$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" - case _ => s"$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" - } - - val callNode = NewCall() - .name(callee.getName) - .code(code) - .dispatchType(dispatchType) - .order(order) - .argumentIndex(order) - .methodFullName(s"$calleeType.${callee.getName}:$signature") - .signature(signature) - .typeFullName(callType) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - - val argAsts = withOrder(invokeExpr match { - case x: DynamicInvokeExpr => x.getArgs.asScala ++ x.getBootstrapArgs.asScala - case x => x.getArgs.asScala - }) { case (arg, order) => - astsForValue(arg, order, parentUnit) - }.flatten - - val callAst = Ast(callNode) - .withChildren(thisAsts) - .withChildren(argAsts) - .withArgEdges(callNode, thisAsts.flatMap(_.root)) - .withArgEdges(callNode, argAsts.flatMap(_.root)) - - thisAsts.flatMap(_.root).headOption match { - case Some(thisAst) => callAst.withReceiverEdge(callNode, thisAst) - 
case None => callAst - } - } - - private def astForNewExpr(x: AnyNewExpr, order: Int, parentUnit: soot.Unit): Ast = { - x match { - case u: NewArrayExpr => - astForArrayCreateExpr(x, List(u.getSize), order, parentUnit) - case u: NewMultiArrayExpr => - astForArrayCreateExpr(x, u.getSizes.asScala, order, parentUnit) - case _ => - val parentType = registerType(x.getType.toQuotedString) + + private def astForIdentityRef(x: IdentityRef, order: Int, parentUnit: soot.Unit): Ast = Ast( - NewCall() + NewIdentifier() + .code(x.toString()) + .name(x.toString()) + .order(order) + .argumentIndex(order) + .typeFullName(registerType(x.getType.toQuotedString)) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + ) + + private def astForInvokeExpr(invokeExpr: InvokeExpr, order: Int, parentUnit: soot.Unit): Ast = + val callee = invokeExpr.getMethodRef + val dispatchType = invokeExpr match + case _ if callee.isConstructor => DispatchTypes.STATIC_DISPATCH + case _: DynamicInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH + case _: InstanceInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH + case _ => DispatchTypes.STATIC_DISPATCH + + val signature = + s"${registerType(callee.getReturnType.toQuotedString)}(${(for (i <- 0 until callee.getParameterTypes.size()) + yield registerType(callee.getParameterType(i).toQuotedString)).mkString(",")})" + val thisAsts = invokeExpr match + case expr: InstanceInvokeExpr => astsForValue(expr.getBase, 0, parentUnit) + case _ => Seq(createThisNode(callee, NewIdentifier())) + + val methodName = + if callee.isConstructor then + registerType(callee.getDeclaringClass.getType.getClassName) + else + callee.getName + + val calleeType = registerType(callee.getDeclaringClass.getType.toQuotedString) + val callType = + if callee.isConstructor then "void" + else calleeType + + val code = invokeExpr match + case expr: InstanceInvokeExpr => + s"${expr.getBase}.$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" + case _ => 
s"$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" + + val callNode = NewCall() + .name(callee.getName) + .code(code) + .dispatchType(dispatchType) + .order(order) + .argumentIndex(order) + .methodFullName(s"$calleeType.${callee.getName}:$signature") + .signature(signature) + .typeFullName(callType) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + + val argAsts = withOrder(invokeExpr match + case x: DynamicInvokeExpr => x.getArgs.asScala ++ x.getBootstrapArgs.asScala + case x => x.getArgs.asScala + ) { case (arg, order) => + astsForValue(arg, order, parentUnit) + }.flatten + + val callAst = Ast(callNode) + .withChildren(thisAsts) + .withChildren(argAsts) + .withArgEdges(callNode, thisAsts.flatMap(_.root)) + .withArgEdges(callNode, argAsts.flatMap(_.root)) + + thisAsts.flatMap(_.root).headOption match + case Some(thisAst) => callAst.withReceiverEdge(callNode, thisAst) + case None => callAst + end astForInvokeExpr + + private def astForNewExpr(x: AnyNewExpr, order: Int, parentUnit: soot.Unit): Ast = + x match + case u: NewArrayExpr => + astForArrayCreateExpr(x, List(u.getSize), order, parentUnit) + case u: NewMultiArrayExpr => + astForArrayCreateExpr(x, u.getSizes.asScala, order, parentUnit) + case _ => + val parentType = registerType(x.getType.toQuotedString) + Ast( + NewCall() + .name(Operators.alloc) + .methodFullName(Operators.alloc) + .typeFullName(parentType) + .code(s"new ${x.getType}") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + ) + + private def astForArrayCreateExpr( + arrayInitExpr: Expr, + sizes: Iterable[Value], + order: Int, + parentUnit: soot.Unit + ): Ast = + // Jimple does not have Operators.arrayInitializer + // to enforce 3 address code form + val arrayBaseType = registerType(arrayInitExpr.getType.toQuotedString) + val code = + s"new ${arrayBaseType.substring(0, arrayBaseType.indexOf('['))}${sizes.map(s => 
s"[$s]").mkString}" + val callBlock = NewCall() .name(Operators.alloc) .methodFullName(Operators.alloc) - .typeFullName(parentType) - .code(s"new ${x.getType}") + .code(code) .dispatchType(DispatchTypes.STATIC_DISPATCH) .order(order) + .typeFullName(arrayBaseType) .argumentIndex(order) .lineNumber(line(parentUnit)) .columnNumber(column(parentUnit)) - ) - } - } - - private def astForArrayCreateExpr( - arrayInitExpr: Expr, - sizes: Iterable[Value], - order: Int, - parentUnit: soot.Unit - ): Ast = { - // Jimple does not have Operators.arrayInitializer - // to enforce 3 address code form - val arrayBaseType = registerType(arrayInitExpr.getType.toQuotedString) - val code = s"new ${arrayBaseType.substring(0, arrayBaseType.indexOf('['))}${sizes.map(s => s"[$s]").mkString}" - val callBlock = NewCall() - .name(Operators.alloc) - .methodFullName(Operators.alloc) - .code(code) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .typeFullName(arrayBaseType) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - val valueAsts = withOrder(sizes) { (s, o) => - astsForValue(s, o, parentUnit) - }.flatten - Ast(callBlock) - .withChildren(valueAsts) - .withArgEdges(callBlock, valueAsts.flatMap(_.root)) - } - - private def astForUnaryExpr( - methodName: String, - unaryExpr: Expr, - op: Value, - order: Int, - parentUnit: soot.Unit - ): Ast = { - val callBlock = NewCall() - .name(methodName) - .methodFullName(methodName) - .code(unaryExpr.toString()) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .typeFullName(registerType(unaryExpr.getType.toQuotedString)) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - - def astForTypeRef(t: String, order: Int) = { - Seq( - Ast( - NewTypeRef() - .code(if (t.contains('.')) t.substring(t.lastIndexOf('.') + 1, t.length) else t) + val valueAsts = withOrder(sizes) { (s, o) => + astsForValue(s, o, parentUnit) + }.flatten + Ast(callBlock) + 
.withChildren(valueAsts) + .withArgEdges(callBlock, valueAsts.flatMap(_.root)) + end astForArrayCreateExpr + + private def astForUnaryExpr( + methodName: String, + unaryExpr: Expr, + op: Value, + order: Int, + parentUnit: soot.Unit + ): Ast = + val callBlock = NewCall() + .name(methodName) + .methodFullName(methodName) + .code(unaryExpr.toString()) + .dispatchType(DispatchTypes.STATIC_DISPATCH) .order(order) + .typeFullName(registerType(unaryExpr.getType.toQuotedString)) .argumentIndex(order) .lineNumber(line(parentUnit)) .columnNumber(column(parentUnit)) - .typeFullName(t) - ) - ) - } - - val valueAsts = unaryExpr match { - case instanceOfExpr: InstanceOfExpr => - val t = registerType(instanceOfExpr.getCheckType.toQuotedString) - astsForValue(op, 1, parentUnit) ++ astForTypeRef(t, 2) - case castExpr: CastExpr => - val t = registerType(castExpr.getCastType.toQuotedString) - astForTypeRef(t, 1) ++ astsForValue(op, 2, parentUnit) - case _ => astsForValue(op, 1, parentUnit) - } - - Ast(callBlock) - .withChildren(valueAsts) - .withArgEdges(callBlock, valueAsts.flatMap(_.root)) - } - - private def createThisNode(method: ThisRef): Ast = { - Ast( - NewIdentifier() - .name("this") - .code("this") - .typeFullName(registerType(method.getType.toQuotedString)) - .dynamicTypeHintFullName(Seq(registerType(method.getType.toQuotedString))) - .order(0) - .argumentIndex(0) - ) - } - - private def createThisNode(method: SootMethod, builder: NewNode): Ast = createThisNode(method.makeRef(), builder) - - private def createThisNode(method: SootMethodRef, builder: NewNode): Ast = { - if (!method.isStatic || method.isConstructor) { - val parentType = registerType(Try(method.getDeclaringClass.getType.toQuotedString).getOrElse("ANY")) - Ast(builder match { - case x: NewIdentifier => - x.name("this") - .code("this") - .typeFullName(parentType) - .order(0) - .argumentIndex(0) - .dynamicTypeHintFullName(Seq(parentType)) - case _: NewMethodParameterIn => - NodeBuilders.newThisParameterNode( - 
typeFullName = parentType, - dynamicTypeHintFullName = Seq(parentType), - line = line(Try(method.tryResolve()).getOrElse(null)) - ) - case x => x - }) - } else { - Ast() - } - } - - private def createParameterNode(parameterRef: ParameterRef, order: Int): Ast = { - val name = s"@parameter${parameterRef.getIndex}" - Ast( - NewIdentifier() - .name(name) - .code(name) - .typeFullName(registerType(parameterRef.getType.toQuotedString)) - .order(order) - .argumentIndex(order) - ) - } - - /** Creates the AST for assignment statements keeping in mind Jimple is a 3-address code language. - */ - private def astsForDefinition(assignStmt: DefinitionStmt, order: Int): Seq[Ast] = { - val initializer = assignStmt.getRightOp - val leftOp = assignStmt.getLeftOp - - val identifier = leftOp match { - case x: soot.Local => Seq(astForLocal(x, 1, assignStmt)) - case x: FieldRef => Seq(astForFieldRef(x, 1, assignStmt)) - case x => astsForValue(x, 1, assignStmt) - } - val lhsCode = identifier.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString - - val initAsts = astsForValue(initializer, 2, assignStmt) - val rhsCode = initAsts - .flatMap(_.root) - .map(_.properties.getOrElse(PropertyNames.CODE, "")) - .mkString(", ") - - val assignment = NewCall() - .name(Operators.assignment) - .methodFullName(Operators.assignment) - .code(s"$lhsCode = $rhsCode") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .argumentIndex(order) - .typeFullName(registerType(assignStmt.getLeftOp.getType.toQuotedString)) - val initializerAst = Seq(callAst(assignment, identifier ++ initAsts)) - initializerAst.toList - } - - private def astsForIfStmt(ifStmt: IfStmt, order: Int): Seq[Ast] = { - // bytecode/jimple ASTs are flat so there will not be nested bodies - val condition = astsForValue(ifStmt.getCondition, order, ifStmt) - controlTargets.put(condition, ifStmt.getTarget) - condition - } - - private def astsForGotoStmt(gotoStmt: GotoStmt, order: Int): Seq[Ast] = { - // bytecode/jimple 
ASTs are flat so there will not be nested bodies - val gotoAst = Seq( - Ast( - NewUnknown() - .code(s"goto ${line(gotoStmt.getTarget).getOrElse(gotoStmt.getTarget.toString())}") - .order(order) - .argumentIndex(order) - .lineNumber(line(gotoStmt)) - .columnNumber(column(gotoStmt)) - ) - ) - controlTargets.put(gotoAst, gotoStmt.getTarget) - gotoAst - } - - private def astForSwitchWithDefaultAndCondition(switchStmt: SwitchStmt, order: Int): Ast = { - val jimple = switchStmt.toString() - val totalTgts = switchStmt.getTargets.size() - val switch = NewControlStructure() - .controlStructureType(ControlStructureTypes.SWITCH) - .code(jimple.substring(0, jimple.indexOf("{") - 1)) - .lineNumber(line(switchStmt)) - .columnNumber(column(switchStmt)) - .order(order) - .argumentIndex(order) - - val conditionalAst = astsForValue(switchStmt.getKey, totalTgts + 1, switchStmt) - val defaultAst = Seq( - Ast( - NewJumpTarget() - .name("default") - .code("default:") - .order(totalTgts + 2) - .argumentIndex(totalTgts + 2) - .lineNumber(line(switchStmt.getDefaultTarget)) - .columnNumber(column(switchStmt.getDefaultTarget)) - ) - ) - Ast(switch) - .withConditionEdge(switch, conditionalAst.flatMap(_.root).head) - .withChildren(conditionalAst ++ defaultAst) - } - - private def astsForLookupSwitchStmt(lookupSwitchStmt: LookupSwitchStmt, order: Int): Seq[Ast] = { - val totalTgts = lookupSwitchStmt.getTargets.size() - val switchAst = astForSwitchWithDefaultAndCondition(lookupSwitchStmt, order) - - val tgts = for { - i <- 0 until totalTgts - if lookupSwitchStmt.getTarget(i) != lookupSwitchStmt.getDefaultTarget - } yield (lookupSwitchStmt.getLookupValue(i), lookupSwitchStmt.getTarget(i)) - val tgtAsts = tgts.map { case (lookup, target) => - Ast( - NewJumpTarget() - .name(s"case $lookup") - .code(s"case $lookup:") - .argumentIndex(lookup) - .order(lookup) - .lineNumber(line(target)) - .columnNumber(column(target)) - ) - } - - Seq( - switchAst - .withChildren(tgtAsts) - ) - } - - private def 
astsForTableSwitchStmt(tableSwitchStmt: SwitchStmt, order: Int): Seq[Ast] = { - val switchAst = astForSwitchWithDefaultAndCondition(tableSwitchStmt, order) - val tgtAsts = tableSwitchStmt.getTargets.asScala - .filter(x => tableSwitchStmt.getDefaultTarget != x) - .zipWithIndex - .map({ case (tgt, i) => + + def astForTypeRef(t: String, order: Int) = + Seq( + Ast( + NewTypeRef() + .code(if t.contains('.') then t.substring(t.lastIndexOf('.') + 1, t.length) + else t) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .typeFullName(t) + ) + ) + + val valueAsts = unaryExpr match + case instanceOfExpr: InstanceOfExpr => + val t = registerType(instanceOfExpr.getCheckType.toQuotedString) + astsForValue(op, 1, parentUnit) ++ astForTypeRef(t, 2) + case castExpr: CastExpr => + val t = registerType(castExpr.getCastType.toQuotedString) + astForTypeRef(t, 1) ++ astsForValue(op, 2, parentUnit) + case _ => astsForValue(op, 1, parentUnit) + + Ast(callBlock) + .withChildren(valueAsts) + .withArgEdges(callBlock, valueAsts.flatMap(_.root)) + end astForUnaryExpr + + private def createThisNode(method: ThisRef): Ast = Ast( - NewJumpTarget() - .name(s"case $i") - .code(s"case $i:") - .argumentIndex(i) - .order(i) - .lineNumber(line(tgt)) - .columnNumber(column(tgt)) + NewIdentifier() + .name("this") + .code("this") + .typeFullName(registerType(method.getType.toQuotedString)) + .dynamicTypeHintFullName(Seq(registerType(method.getType.toQuotedString))) + .order(0) + .argumentIndex(0) ) - }) - .toSeq - - Seq( - switchAst - .withChildren(tgtAsts) - ) - } - - private def astsForThrowStmt(throwStmt: ThrowStmt, order: Int): Seq[Ast] = { - val opAst = astsForValue(throwStmt.getOp, 1, throwStmt) - val throwNode = NewCall() - .name(".throw") - .methodFullName(".throw") - .lineNumber(line(throwStmt)) - .columnNumber(column(throwStmt)) - .code(s"throw new ${throwStmt.getOp.getType}()") - .order(order) - .argumentIndex(order) - 
.dispatchType(DispatchTypes.STATIC_DISPATCH) - Seq( - Ast(throwNode) - .withChildren(opAst) - ) - } - - private def astsForMonitorStmt(monitorStmt: MonitorStmt, order: Int): Seq[Ast] = { - val opAst = astsForValue(monitorStmt.getOp, 1, monitorStmt) - val typeString = opAst.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString - val code = monitorStmt match { - case _: EnterMonitorStmt => s"entermonitor $typeString" - case _: ExitMonitorStmt => s"exitmonitor $typeString" - case _ => s"monitor $typeString" - } - Seq( - Ast( - NewUnknown() - .order(order) - .argumentIndex(order) - .code(code) - .lineNumber(line(monitorStmt)) - .columnNumber(column(monitorStmt)) - ).withChildren(opAst) - ) - } - - private def astForUnknownStmt(stmt: Unit, maybeOp: Option[Value], order: Int): Ast = { - val opAst = maybeOp match { - case Some(op) => astsForValue(op, 1, stmt) - case None => Seq() - } - val unknown = NewUnknown() - .order(order) - .code(stmt.toString()) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .typeFullName(registerType("void")) - Ast(unknown) - .withChildren(opAst) - } - - private def astsForReturnNode(returnStmt: ReturnStmt, order: Int): Seq[Ast] = { - val astChildren = astsForValue(returnStmt.getOp, 1, returnStmt) - val returnNode = NewReturn() - .argumentIndex(order) - .order(order) - .code(s"return ${astChildren.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString(" ")};") - .lineNumber(line(returnStmt)) - .columnNumber(column(returnStmt)) - - Seq( - Ast(returnNode) - .withChildren(astChildren) - .withArgEdges(returnNode, astChildren.flatMap(_.root)) - ) - } - - private def astsForReturnVoidNode(returnVoidStmt: ReturnVoidStmt, order: Int): Seq[Ast] = { - Seq( - Ast( - NewReturn() - .argumentIndex(order) - .order(order) - .code(s"return;") - .lineNumber(line(returnVoidStmt)) - .columnNumber(column(returnVoidStmt)) - ) - ) - } - - private def astForFieldRef(fieldRef: FieldRef, order: Int, parentUnit: soot.Unit): Ast = { - val 
leftOpString = fieldRef match { - case x: StaticFieldRef => x.getFieldRef.declaringClass().toString - case x: InstanceFieldRef => x.getBase.toString() - case _ => fieldRef.getFieldRef.declaringClass().toString - } - val leftOpType = fieldRef match { - case x: StaticFieldRef => x.getFieldRef.declaringClass().getType - case x: InstanceFieldRef => x.getBase.getType - case _ => fieldRef.getFieldRef.declaringClass().getType - } - - val fieldAccessBlock = NewCall() - .name(Operators.fieldAccess) - .code(s"$leftOpString.${fieldRef.getFieldRef.name()}") - .typeFullName(registerType(fieldRef.getType.toQuotedString)) - .methodFullName(Operators.fieldAccess) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - - val argAsts = Seq( - NewIdentifier() - .order(1) - .argumentIndex(1) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .name(leftOpString) - .code(leftOpString) - .typeFullName(registerType(leftOpType.toQuotedString)), - NewFieldIdentifier() - .order(2) - .argumentIndex(2) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .canonicalName(fieldRef.getFieldRef.name()) - .code(fieldRef.getFieldRef.name()) - ).map(Ast(_)) - - Ast(fieldAccessBlock) - .withChildren(argAsts) - .withArgEdges(fieldAccessBlock, argAsts.flatMap(_.root)) - } - - private def astForCaughtExceptionRef(caughtException: CaughtExceptionRef, order: Int, parentUnit: soot.Unit): Ast = { - Ast( - NewIdentifier() - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .name(caughtException.toString()) - .code(caughtException.toString()) - .typeFullName(registerType(caughtException.getType.toQuotedString)) - ) - } - - private def astForConstantExpr(constant: Constant, order: Int): Ast = { - constant match { - case x: ClassConstant => + + private def createThisNode(method: SootMethod, builder: NewNode): Ast = + 
createThisNode(method.makeRef(), builder) + + private def createThisNode(method: SootMethodRef, builder: NewNode): Ast = + if !method.isStatic || method.isConstructor then + val parentType = + registerType(Try(method.getDeclaringClass.getType.toQuotedString).getOrElse("ANY")) + Ast(builder match + case x: NewIdentifier => + x.name("this") + .code("this") + .typeFullName(parentType) + .order(0) + .argumentIndex(0) + .dynamicTypeHintFullName(Seq(parentType)) + case _: NewMethodParameterIn => + NodeBuilders.newThisParameterNode( + typeFullName = parentType, + dynamicTypeHintFullName = Seq(parentType), + line = line(Try(method.tryResolve()).getOrElse(null)) + ) + case x => x + ) + else + Ast() + + private def createParameterNode(parameterRef: ParameterRef, order: Int): Ast = + val name = s"@parameter${parameterRef.getIndex}" Ast( - NewLiteral() + NewIdentifier() + .name(name) + .code(name) + .typeFullName(registerType(parameterRef.getType.toQuotedString)) + .order(order) + .argumentIndex(order) + ) + + /** Creates the AST for assignment statements keeping in mind Jimple is a 3-address code + * language. 
+ */ + private def astsForDefinition(assignStmt: DefinitionStmt, order: Int): Seq[Ast] = + val initializer = assignStmt.getRightOp + val leftOp = assignStmt.getLeftOp + + val identifier = leftOp match + case x: soot.Local => Seq(astForLocal(x, 1, assignStmt)) + case x: FieldRef => Seq(astForFieldRef(x, 1, assignStmt)) + case x => astsForValue(x, 1, assignStmt) + val lhsCode = + identifier.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString + + val initAsts = astsForValue(initializer, 2, assignStmt) + val rhsCode = initAsts + .flatMap(_.root) + .map(_.properties.getOrElse(PropertyNames.CODE, "")) + .mkString(", ") + + val assignment = NewCall() + .name(Operators.assignment) + .methodFullName(Operators.assignment) + .code(s"$lhsCode = $rhsCode") + .dispatchType(DispatchTypes.STATIC_DISPATCH) .order(order) .argumentIndex(order) - .code(s"${x.value.parseAsJavaType}.class") - .typeFullName(registerType(x.getType.toQuotedString)) + .typeFullName(registerType(assignStmt.getLeftOp.getType.toQuotedString)) + val initializerAst = Seq(callAst(assignment, identifier ++ initAsts)) + initializerAst.toList + end astsForDefinition + + private def astsForIfStmt(ifStmt: IfStmt, order: Int): Seq[Ast] = + // bytecode/jimple ASTs are flat so there will not be nested bodies + val condition = astsForValue(ifStmt.getCondition, order, ifStmt) + controlTargets.put(condition, ifStmt.getTarget) + condition + + private def astsForGotoStmt(gotoStmt: GotoStmt, order: Int): Seq[Ast] = + // bytecode/jimple ASTs are flat so there will not be nested bodies + val gotoAst = Seq( + Ast( + NewUnknown() + .code(s"goto ${line(gotoStmt.getTarget).getOrElse(gotoStmt.getTarget.toString())}") + .order(order) + .argumentIndex(order) + .lineNumber(line(gotoStmt)) + .columnNumber(column(gotoStmt)) + ) ) - case _: NullConstant => - Ast( - NewLiteral() + controlTargets.put(gotoAst, gotoStmt.getTarget) + gotoAst + + private def astForSwitchWithDefaultAndCondition(switchStmt: SwitchStmt, order: 
Int): Ast = + val jimple = switchStmt.toString() + val totalTgts = switchStmt.getTargets.size() + val switch = NewControlStructure() + .controlStructureType(ControlStructureTypes.SWITCH) + .code(jimple.substring(0, jimple.indexOf("{") - 1)) + .lineNumber(line(switchStmt)) + .columnNumber(column(switchStmt)) .order(order) .argumentIndex(order) - .code("null") - .typeFullName(registerType("null")) + + val conditionalAst = astsForValue(switchStmt.getKey, totalTgts + 1, switchStmt) + val defaultAst = Seq( + Ast( + NewJumpTarget() + .name("default") + .code("default:") + .order(totalTgts + 2) + .argumentIndex(totalTgts + 2) + .lineNumber(line(switchStmt.getDefaultTarget)) + .columnNumber(column(switchStmt.getDefaultTarget)) + ) ) - case _ => - Ast( - NewLiteral() + Ast(switch) + .withConditionEdge(switch, conditionalAst.flatMap(_.root).head) + .withChildren(conditionalAst ++ defaultAst) + end astForSwitchWithDefaultAndCondition + + private def astsForLookupSwitchStmt(lookupSwitchStmt: LookupSwitchStmt, order: Int): Seq[Ast] = + val totalTgts = lookupSwitchStmt.getTargets.size() + val switchAst = astForSwitchWithDefaultAndCondition(lookupSwitchStmt, order) + + val tgts = + for + i <- 0 until totalTgts + if lookupSwitchStmt.getTarget(i) != lookupSwitchStmt.getDefaultTarget + yield (lookupSwitchStmt.getLookupValue(i), lookupSwitchStmt.getTarget(i)) + val tgtAsts = tgts.map { case (lookup, target) => + Ast( + NewJumpTarget() + .name(s"case $lookup") + .code(s"case $lookup:") + .argumentIndex(lookup) + .order(lookup) + .lineNumber(line(target)) + .columnNumber(column(target)) + ) + } + + Seq( + switchAst + .withChildren(tgtAsts) + ) + end astsForLookupSwitchStmt + + private def astsForTableSwitchStmt(tableSwitchStmt: SwitchStmt, order: Int): Seq[Ast] = + val switchAst = astForSwitchWithDefaultAndCondition(tableSwitchStmt, order) + val tgtAsts = tableSwitchStmt.getTargets.asScala + .filter(x => tableSwitchStmt.getDefaultTarget != x) + .zipWithIndex + .map({ case (tgt, i) => + 
Ast( + NewJumpTarget() + .name(s"case $i") + .code(s"case $i:") + .argumentIndex(i) + .order(i) + .lineNumber(line(tgt)) + .columnNumber(column(tgt)) + ) + }) + .toSeq + + Seq( + switchAst + .withChildren(tgtAsts) + ) + end astsForTableSwitchStmt + + private def astsForThrowStmt(throwStmt: ThrowStmt, order: Int): Seq[Ast] = + val opAst = astsForValue(throwStmt.getOp, 1, throwStmt) + val throwNode = NewCall() + .name(".throw") + .methodFullName(".throw") + .lineNumber(line(throwStmt)) + .columnNumber(column(throwStmt)) + .code(s"throw new ${throwStmt.getOp.getType}()") + .order(order) + .argumentIndex(order) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + Seq( + Ast(throwNode) + .withChildren(opAst) + ) + + private def astsForMonitorStmt(monitorStmt: MonitorStmt, order: Int): Seq[Ast] = + val opAst = astsForValue(monitorStmt.getOp, 1, monitorStmt) + val typeString = opAst.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString + val code = monitorStmt match + case _: EnterMonitorStmt => s"entermonitor $typeString" + case _: ExitMonitorStmt => s"exitmonitor $typeString" + case _ => s"monitor $typeString" + Seq( + Ast( + NewUnknown() + .order(order) + .argumentIndex(order) + .code(code) + .lineNumber(line(monitorStmt)) + .columnNumber(column(monitorStmt)) + ).withChildren(opAst) + ) + + private def astForUnknownStmt(stmt: Unit, maybeOp: Option[Value], order: Int): Ast = + val opAst = maybeOp match + case Some(op) => astsForValue(op, 1, stmt) + case None => Seq() + val unknown = NewUnknown() .order(order) + .code(stmt.toString()) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .typeFullName(registerType("void")) + Ast(unknown) + .withChildren(opAst) + + private def astsForReturnNode(returnStmt: ReturnStmt, order: Int): Seq[Ast] = + val astChildren = astsForValue(returnStmt.getOp, 1, returnStmt) + val returnNode = NewReturn() .argumentIndex(order) - .code(constant.toString) - .typeFullName(registerType(constant.getType.toQuotedString)) + .order(order) 
+ .code( + s"return ${astChildren.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString(" ")};" + ) + .lineNumber(line(returnStmt)) + .columnNumber(column(returnStmt)) + + Seq( + Ast(returnNode) + .withChildren(astChildren) + .withArgEdges(returnNode, astChildren.flatMap(_.root)) ) - } - } - - private def callAst(rootNode: NewNode, args: Seq[Ast]): Ast = { - Ast(rootNode) - .withChildren(args) - .withArgEdges(rootNode, args.flatMap(_.root)) - } - - private def astsForModifiers(methodDeclaration: SootMethod): Seq[Ast] = { - Seq( - if (methodDeclaration.isStatic) Some(ModifierTypes.STATIC) else None, - if (methodDeclaration.isPublic) Some(ModifierTypes.PUBLIC) else None, - if (methodDeclaration.isProtected) Some(ModifierTypes.PROTECTED) else None, - if (methodDeclaration.isPrivate) Some(ModifierTypes.PRIVATE) else None, - if (methodDeclaration.isAbstract) Some(ModifierTypes.ABSTRACT) else None, - if (methodDeclaration.isConstructor) Some(ModifierTypes.CONSTRUCTOR) else None, - if (!methodDeclaration.isFinal && !methodDeclaration.isStatic && methodDeclaration.isPublic) - Some(ModifierTypes.VIRTUAL) - else None, - if (methodDeclaration.isSynchronized) Some("SYNCHRONIZED") else None - ).flatten.map { modifier => - Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) - } - } - - private def astsForModifiers(classDeclaration: SootClass): Seq[Ast] = { - Seq( - if (classDeclaration.isStatic) Some(ModifierTypes.STATIC) else None, - if (classDeclaration.isPublic) Some(ModifierTypes.PUBLIC) else None, - if (classDeclaration.isProtected) Some(ModifierTypes.PROTECTED) else None, - if (classDeclaration.isPrivate) Some(ModifierTypes.PRIVATE) else None, - if (classDeclaration.isAbstract) Some(ModifierTypes.ABSTRACT) else None, - if (classDeclaration.isInterface) Some("INTERFACE") else None, - if (!classDeclaration.isFinal && !classDeclaration.isStatic && classDeclaration.isPublic) - Some(ModifierTypes.VIRTUAL) - else None, - if 
(classDeclaration.isSynchronized) Some("SYNCHRONIZED") else None - ).flatten.map { modifier => - Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) - } - } - - private def astForMethodReturn(methodDeclaration: SootMethod): Ast = { - val typeFullName = registerType(methodDeclaration.getReturnType.toQuotedString) - val methodReturnNode = NodeBuilders - .newMethodReturnNode(typeFullName, None, line(methodDeclaration), None) - .order(methodDeclaration.getParameterCount + 2) - Ast(methodReturnNode) - } - - private def createMethodNode(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int) = { - val fullName = methodFullName(typeDecl, methodDeclaration) - val methodDeclType = registerType(methodDeclaration.getReturnType.toQuotedString) - val code = if (!methodDeclaration.isConstructor) { - s"$methodDeclType ${methodDeclaration.getName}${paramListSignature(methodDeclaration, withParams = true)}" - } else { - s"${typeDecl.getClassName}${paramListSignature(methodDeclaration, withParams = true)}" - } - NewMethod() - .name(methodDeclaration.getName) - .fullName(fullName) - .code(code) - .signature(methodDeclType + paramListSignature(methodDeclaration)) - .isExternal(false) - .order(childNum) - .filename(filename) - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(typeDecl.toQuotedString) - .lineNumber(line(methodDeclaration)) - .columnNumber(column(methodDeclaration)) - } - - private def methodFullName(typeDecl: RefType, methodDeclaration: SootMethod): String = { - val typeName = registerType(typeDecl.toQuotedString) - val returnType = registerType(methodDeclaration.getReturnType.toQuotedString) - val methodName = methodDeclaration.getName - s"$typeName.$methodName:$returnType${paramListSignature(methodDeclaration)}" - } - - private def paramListSignature(methodDeclaration: SootMethod, withParams: Boolean = false) = { - val paramTypes = methodDeclaration.getParameterTypes.asScala.map(x => registerType(x.toQuotedString)) - - val paramNames = 
- if (!methodDeclaration.isPhantom && Try(methodDeclaration.retrieveActiveBody()).isSuccess) - methodDeclaration.retrieveActiveBody().getParameterLocals.asScala.map(_.getName) - else - paramTypes.zipWithIndex.map(x => { s"param${x._2 + 1}" }) - if (!withParams) { - "(" + paramTypes.mkString(",") + ")" - } else { - "(" + paramTypes.zip(paramNames).map(x => s"${x._1} ${x._2}").mkString(", ") + ")" - } - } -} - -object AstCreator { - def line(node: Host): Option[Integer] = { - if (node == null) None - else if (node.getJavaSourceStartLineNumber == -1) None - else Option(node.getJavaSourceStartLineNumber) - } - - def column(node: Host): Option[Integer] = { - if (node == null) None - else if (node.getJavaSourceStartColumnNumber == -1) None - else Option(node.getJavaSourceStartColumnNumber) - } - - def withOrder[T <: Any, X](nodeList: java.util.List[T])(f: (T, Int) => X): Seq[X] = { - nodeList.asScala.zipWithIndex.map { case (x, i) => - f(x, i + 1) - }.toSeq - } - - def withOrder[T <: Any, X](nodeList: Iterable[T])(f: (T, Int) => X): Seq[X] = { - nodeList.zipWithIndex.map { case (x, i) => - f(x, i + 1) - }.toSeq - } - -} + + private def astsForReturnVoidNode(returnVoidStmt: ReturnVoidStmt, order: Int): Seq[Ast] = + Seq( + Ast( + NewReturn() + .argumentIndex(order) + .order(order) + .code(s"return;") + .lineNumber(line(returnVoidStmt)) + .columnNumber(column(returnVoidStmt)) + ) + ) + + private def astForFieldRef(fieldRef: FieldRef, order: Int, parentUnit: soot.Unit): Ast = + val leftOpString = fieldRef match + case x: StaticFieldRef => x.getFieldRef.declaringClass().toString + case x: InstanceFieldRef => x.getBase.toString() + case _ => fieldRef.getFieldRef.declaringClass().toString + val leftOpType = fieldRef match + case x: StaticFieldRef => x.getFieldRef.declaringClass().getType + case x: InstanceFieldRef => x.getBase.getType + case _ => fieldRef.getFieldRef.declaringClass().getType + + val fieldAccessBlock = NewCall() + .name(Operators.fieldAccess) + 
.code(s"$leftOpString.${fieldRef.getFieldRef.name()}") + .typeFullName(registerType(fieldRef.getType.toQuotedString)) + .methodFullName(Operators.fieldAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + + val argAsts = Seq( + NewIdentifier() + .order(1) + .argumentIndex(1) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .name(leftOpString) + .code(leftOpString) + .typeFullName(registerType(leftOpType.toQuotedString)), + NewFieldIdentifier() + .order(2) + .argumentIndex(2) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .canonicalName(fieldRef.getFieldRef.name()) + .code(fieldRef.getFieldRef.name()) + ).map(Ast(_)) + + Ast(fieldAccessBlock) + .withChildren(argAsts) + .withArgEdges(fieldAccessBlock, argAsts.flatMap(_.root)) + end astForFieldRef + + private def astForCaughtExceptionRef( + caughtException: CaughtExceptionRef, + order: Int, + parentUnit: soot.Unit + ): Ast = + Ast( + NewIdentifier() + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .name(caughtException.toString()) + .code(caughtException.toString()) + .typeFullName(registerType(caughtException.getType.toQuotedString)) + ) + + private def astForConstantExpr(constant: Constant, order: Int): Ast = + constant match + case x: ClassConstant => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code(s"${x.value.parseAsJavaType}.class") + .typeFullName(registerType(x.getType.toQuotedString)) + ) + case _: NullConstant => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code("null") + .typeFullName(registerType("null")) + ) + case _ => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code(constant.toString) + .typeFullName(registerType(constant.getType.toQuotedString)) + ) + + private def callAst(rootNode: NewNode, args: Seq[Ast]): Ast = + Ast(rootNode) 
+ .withChildren(args) + .withArgEdges(rootNode, args.flatMap(_.root)) + + private def astsForModifiers(methodDeclaration: SootMethod): Seq[Ast] = + Seq( + if methodDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, + if methodDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, + if methodDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, + if methodDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, + if methodDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, + if methodDeclaration.isConstructor then Some(ModifierTypes.CONSTRUCTOR) else None, + if !methodDeclaration.isFinal && !methodDeclaration.isStatic && methodDeclaration.isPublic + then + Some(ModifierTypes.VIRTUAL) + else None, + if methodDeclaration.isSynchronized then Some("SYNCHRONIZED") else None + ).flatten.map { modifier => + Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) + } + + private def astsForModifiers(classDeclaration: SootClass): Seq[Ast] = + Seq( + if classDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, + if classDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, + if classDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, + if classDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, + if classDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, + if classDeclaration.isInterface then Some("INTERFACE") else None, + if !classDeclaration.isFinal && !classDeclaration.isStatic && classDeclaration.isPublic + then + Some(ModifierTypes.VIRTUAL) + else None, + if classDeclaration.isSynchronized then Some("SYNCHRONIZED") else None + ).flatten.map { modifier => + Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) + } + + private def astForMethodReturn(methodDeclaration: SootMethod): Ast = + val typeFullName = registerType(methodDeclaration.getReturnType.toQuotedString) + val methodReturnNode = NodeBuilders + 
.newMethodReturnNode(typeFullName, None, line(methodDeclaration), None) + .order(methodDeclaration.getParameterCount + 2) + Ast(methodReturnNode) + + private def createMethodNode(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int) = + val fullName = methodFullName(typeDecl, methodDeclaration) + val methodDeclType = registerType(methodDeclaration.getReturnType.toQuotedString) + val code = if !methodDeclaration.isConstructor then + s"$methodDeclType ${methodDeclaration.getName}${paramListSignature(methodDeclaration, withParams = true)}" + else + s"${typeDecl.getClassName}${paramListSignature(methodDeclaration, withParams = true)}" + NewMethod() + .name(methodDeclaration.getName) + .fullName(fullName) + .code(code) + .signature(methodDeclType + paramListSignature(methodDeclaration)) + .isExternal(false) + .order(childNum) + .filename(filename) + .astParentType(NodeTypes.TYPE_DECL) + .astParentFullName(typeDecl.toQuotedString) + .lineNumber(line(methodDeclaration)) + .columnNumber(column(methodDeclaration)) + end createMethodNode + + private def methodFullName(typeDecl: RefType, methodDeclaration: SootMethod): String = + val typeName = registerType(typeDecl.toQuotedString) + val returnType = registerType(methodDeclaration.getReturnType.toQuotedString) + val methodName = methodDeclaration.getName + s"$typeName.$methodName:$returnType${paramListSignature(methodDeclaration)}" + + private def paramListSignature(methodDeclaration: SootMethod, withParams: Boolean = false) = + val paramTypes = + methodDeclaration.getParameterTypes.asScala.map(x => registerType(x.toQuotedString)) + + val paramNames = + if !methodDeclaration.isPhantom && Try(methodDeclaration.retrieveActiveBody()).isSuccess + then + methodDeclaration.retrieveActiveBody().getParameterLocals.asScala.map(_.getName) + else + paramTypes.zipWithIndex.map(x => s"param${x._2 + 1}") + if !withParams then + "(" + paramTypes.mkString(",") + ")" + else + "(" + paramTypes.zip(paramNames).map(x => s"${x._1} 
${x._2}").mkString(", ") + ")" +end AstCreator + +object AstCreator: + def line(node: Host): Option[Integer] = + if node == null then None + else if node.getJavaSourceStartLineNumber == -1 then None + else Option(node.getJavaSourceStartLineNumber) + + def column(node: Host): Option[Integer] = + if node == null then None + else if node.getJavaSourceStartColumnNumber == -1 then None + else Option(node.getJavaSourceStartColumnNumber) + + def withOrder[T <: Any, X](nodeList: java.util.List[T])(f: (T, Int) => X): Seq[X] = + nodeList.asScala.zipWithIndex.map { case (x, i) => + f(x, i + 1) + }.toSeq + + def withOrder[T <: Any, X](nodeList: Iterable[T])(f: (T, Int) => X): Seq[X] = + nodeList.zipWithIndex.map { case (x, i) => + f(x, i + 1) + }.toSeq +end AstCreator /** String extensions for strings describing JVM operators. */ -implicit class JvmStringOpts(s: String) { - - /** Parses the string as a ASM Java type descriptor and returns a fully qualified type. Also converts symbols such as - * I to int. - * @return - */ - def parseAsJavaType: String = Type.getType(s).getClassName.replaceAll("/", ".") +implicit class JvmStringOpts(s: String): -} + /** Parses the string as a ASM Java type descriptor and returns a fully qualified type. Also + * converts symbols such as I to int. 
+ * @return + */ + def parseAsJavaType: String = Type.getType(s).getClassName.replaceAll("/", ".") diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala index b16291e3..d2dccc14 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala @@ -4,35 +4,34 @@ import better.files.File import io.appthreat.x2cpg.passes.frontend.XConfigFileCreationPass import io.shiftleft.codepropertygraph.Cpg -class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg) { +class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - extensionFilter(".properties"), - // Velocity files, see https://velocity.apache.org - extensionFilter(".vm"), - // For Terraform secrets - extensionFilter(".tf"), - extensionFilter(".tfvars"), - // PLAY - pathEndFilter("routes"), - pathEndFilter("application.conf"), - // SERVLET - pathEndFilter("web.xml"), - // JSF - pathEndFilter("faces-config.xml"), - // STRUTS - pathEndFilter("struts.xml"), - // DIRECT WEB REMOTING - pathEndFilter("dwr.xml"), - // BUILD SYSTEM - pathEndFilter("build.gradle"), - pathEndFilter("build.gradle.kts"), - // ANDROID - pathEndFilter("AndroidManifest.xml"), - // Bom - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json") - ) - -} + override val configFileFilters: List[File => Boolean] = List( + extensionFilter(".properties"), + // Velocity files, see https://velocity.apache.org + extensionFilter(".vm"), + // For Terraform secrets + extensionFilter(".tf"), + extensionFilter(".tfvars"), + // PLAY + pathEndFilter("routes"), + pathEndFilter("application.conf"), + // SERVLET 
+ pathEndFilter("web.xml"), + // JSF + pathEndFilter("faces-config.xml"), + // STRUTS + pathEndFilter("struts.xml"), + // DIRECT WEB REMOTING + pathEndFilter("dwr.xml"), + // BUILD SYSTEM + pathEndFilter("build.gradle"), + pathEndFilter("build.gradle.kts"), + // ANDROID + pathEndFilter("AndroidManifest.xml"), + // Bom + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json") + ) +end ConfigFileCreationPass diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala index 97422b8e..1cf94125 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala @@ -6,17 +6,17 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Declaration, Method} import io.shiftleft.passes.ConcurrentWriterCpgPass import io.shiftleft.semanticcpg.language.* -/** Links declarations to their identifier nodes. Due to the flat AST of bytecode, we don't need to account for varying - * scope. +/** Links declarations to their identifier nodes. Due to the flat AST of bytecode, we don't need to + * account for varying scope. 
*/ -class DeclarationRefPass(atom: Cpg) extends ConcurrentWriterCpgPass[Method](atom) { +class DeclarationRefPass(atom: Cpg) extends ConcurrentWriterCpgPass[Method](atom): - override def generateParts(): Array[Method] = atom.method.toArray + override def generateParts(): Array[Method] = atom.method.toArray - override def runOnPart(builder: DiffGraphBuilder, part: Method): Unit = { - val identifiers = part.ast.isIdentifier.toList - val declarations = (part.parameter ++ part.block.astChildren.isLocal).collectAll[Declaration].l - declarations.foreach(d => identifiers.nameExact(d.name).foreach(builder.addEdge(_, d, EdgeTypes.REF))) - } - -} + override def runOnPart(builder: DiffGraphBuilder, part: Method): Unit = + val identifiers = part.ast.isIdentifier.toList + val declarations = + (part.parameter ++ part.block.astChildren.isLocal).collectAll[Declaration].l + declarations.foreach(d => + identifiers.nameExact(d.name).foreach(builder.addEdge(_, d, EdgeTypes.REF)) + ) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala index 9c24c20b..e89c971d 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala @@ -7,25 +7,23 @@ import io.shiftleft.passes.ConcurrentWriterCpgPass import org.slf4j.LoggerFactory import soot.{Scene, SootClass, SourceLocator} -/** Creates the AST layer from the given class file and stores all types in the given global parameter. +/** Creates the AST layer from the given class file and stores all types in the given global + * parameter. 
*/ -class SootAstCreationPass(cpg: Cpg, config: Config) extends ConcurrentWriterCpgPass[SootClass](cpg) { +class SootAstCreationPass(cpg: Cpg, config: Config) extends ConcurrentWriterCpgPass[SootClass](cpg): - val global: Global = new Global() - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + val global: Global = new Global() + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - override def generateParts(): Array[_ <: AnyRef] = Scene.v().getApplicationClasses.toArray() + override def generateParts(): Array[? <: AnyRef] = Scene.v().getApplicationClasses.toArray() - override def runOnPart(builder: DiffGraphBuilder, part: SootClass): Unit = { - val jimpleFile = SourceLocator.v().getSourceForClass(part.getName) - try { - // sootClass.setApplicationClass() - val localDiff = new AstCreator(jimpleFile, part, global)(config.schemaValidation).createAst() - builder.absorb(localDiff) - } catch { - case e: Exception => - logger.warn(s"Cannot parse: $part", e) - } - } - -} + override def runOnPart(builder: DiffGraphBuilder, part: SootClass): Unit = + val jimpleFile = SourceLocator.v().getSourceForClass(part.getName) + try + // sootClass.setApplicationClass() + val localDiff = + new AstCreator(jimpleFile, part, global)(config.schemaValidation).createAst() + builder.absorb(localDiff) + catch + case e: Exception => + logger.warn(s"Cannot parse: $part", e) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala index 9fb9294a..0553887c 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala @@ -11,191 +11,193 @@ import scala.util.{Failure, Left, Success, Try} /** Responsible for handling JAR unpacking and handling the 
temporary build directory. */ -object ProgramHandlingUtil { +object ProgramHandlingUtil: - private val logger = LoggerFactory.getLogger(ProgramHandlingUtil.getClass) + private val logger = LoggerFactory.getLogger(ProgramHandlingUtil.getClass) - /** Common properties of a File and ZipEntry, used to determine whether a file in a directory or an entry in an - * archive is worth emitting/extracting - */ - sealed class Entry(entry: Either[File, ZipEntry]) { - - def this(file: File) = this(Left(file)) - def this(entry: ZipEntry) = this(Right(entry)) - private def file: File = entry.fold(identity, e => File(e.getName)) - def name: String = file.name - def extension: Option[String] = file.extension - def isDirectory: Boolean = entry.fold(_.isDirectory, _.isDirectory) - def maybeRegularFile(): Boolean = entry.fold(_.isRegularFile, !_.isDirectory) - - /** Determines whether a zip entry is potentially malicious. + /** Common properties of a File and ZipEntry, used to determine whether a file in a directory or + * an entry in an archive is worth emitting/extracting + */ + sealed class Entry(entry: Either[File, ZipEntry]): + + def this(file: File) = this(Left(file)) + def this(entry: ZipEntry) = this(Right(entry)) + private def file: File = entry.fold(identity, e => File(e.getName)) + def name: String = file.name + def extension: Option[String] = file.extension + def isDirectory: Boolean = entry.fold(_.isDirectory, _.isDirectory) + def maybeRegularFile(): Boolean = entry.fold(_.isRegularFile, !_.isDirectory) + + /** Determines whether a zip entry is potentially malicious. + * @return + * whether the entry is a ZipEntry and uses '..' in it's components + */ + // Note that we consider either type of path separator as although the spec say that only + // unix separators are to be used, zip files in the wild may vary. 
+ def isZipSlip: Boolean = entry.fold(_ => false, _.getName.split("[/\\\\]").contains("..")) + + /** Process files that may lead to more files to process or to emit a resulting value of [[A]] + * + * @param src + * The file/directory to traverse + * @param emitOrUnpack + * A function that takes a file and either emits a value or returns more files to traverse + * @tparam A + * The type of emitted values * @return - * whether the entry is a ZipEntry and uses '..' in it's components + * The emitted values */ - // Note that we consider either type of path separator as although the spec say that only - // unix separators are to be used, zip files in the wild may vary. - def isZipSlip: Boolean = entry.fold(_ => false, _.getName.split("[/\\\\]").contains("..")) - } - - /** Process files that may lead to more files to process or to emit a resulting value of [[A]] - * - * @param src - * The file/directory to traverse - * @param emitOrUnpack - * A function that takes a file and either emits a value or returns more files to traverse - * @tparam A - * The type of emitted values - * @return - * The emitted values - */ - private def unfoldArchives[A](src: File, emitOrUnpack: File => Either[A, List[File]]): IterableOnce[A] = { - // TODO: add recursion depth limit - emitOrUnpack(src) match { - case Left(a) => Seq(a) - case Right(disposeFiles) => disposeFiles.flatMap(x => unfoldArchives(x, emitOrUnpack)) - } - } - - /** Find
.class
files, including those inside archives. - * - * @param src - * The file/directory to search. - * @param tmpDir - * A temporary directory for extracted archives - * @param isArchive - * Whether an entry is an archive to extract - * @param isClass - * Whether an entry is a class file - * @return - * The list of class files found, which may either be in [[src]] or in an extracted archive under [[tmpDir]] - */ - private def extractClassesToTmp( - src: File, - tmpDir: File, - isArchive: Entry => Boolean, - isClass: Entry => Boolean, - recurse: Boolean - ): IterableOnce[ClassFile] = { - - def shouldExtract(e: Entry) = !e.isZipSlip && e.maybeRegularFile() && (isArchive(e) || isClass(e)) - unfoldArchives( - src, - { - case f if isClass(Entry(f)) => - Left(ClassFile(f)) - case f if f.isDirectory() => - val files = f.listRecursively.filterNot(_.isDirectory).toList - Right(files) - case f if isArchive(Entry(f)) && (recurse || f == src) => - val xTmp = File.newTemporaryDirectory("extract-archive-", parent = Some(tmpDir)) - val unzipDirs = Try(f.unzipTo(xTmp, e => shouldExtract(Entry(e)))) match { - case Success(dir) => List(dir) - case Failure(e) => - logger.warn(s"Failed to extract archive", e) - List.empty - } - Right(unzipDirs) - case _ => - Right(List.empty) - } - ) - } - - object ClassFile { - private def getPackagePathFromByteCode(is: InputStream): Option[String] = { - val cr = new ClassReader(is) - sealed class ClassNameVisitor extends ClassVisitor(Opcodes.ASM9) { - var path: Option[String] = None - override def visit( - version: Int, - access: Int, - name: String, - signature: String, - superName: String, - interfaces: Array[String] - ): Unit = { - path = Some(name) - } - } - val rootVisitor = new ClassNameVisitor() - cr.accept(rootVisitor, SKIP_CODE) - rootVisitor.path - } - - /** Attempt to retrieve the package path from JVM bytecode. 
+ private def unfoldArchives[A]( + src: File, + emitOrUnpack: File => Either[A, List[File]] + ): IterableOnce[A] = + // TODO: add recursion depth limit + emitOrUnpack(src) match + case Left(a) => Seq(a) + case Right(disposeFiles) => disposeFiles.flatMap(x => unfoldArchives(x, emitOrUnpack)) + + /** Find
.class
files, including those inside archives. * - * @param file - * The class file + * @param src + * The file/directory to search. + * @param tmpDir + * A temporary directory for extracted archives + * @param isArchive + * Whether an entry is an archive to extract + * @param isClass + * Whether an entry is a class file * @return - * The package path if successfully retrieved + * The list of class files found, which may either be in [[src]] or in an extracted archive + * under [[tmpDir]] */ - private def getPackagePathFromByteCode(file: File): Option[String] = - Try(file.fileInputStream.apply(getPackagePathFromByteCode)) - .recover { case e: Throwable => - logger.debug(s"Error reading class file ${file.canonicalPath}", e) - None - } - .getOrElse(None) - } - sealed class ClassFile(val file: File, val packagePath: Option[String]) { - def this(file: File) = this(file, ClassFile.getPackagePathFromByteCode(file)) - - private val components: Option[Array[String]] = packagePath.map(_.split("/")) - - val fullyQualifiedClassName: Option[String] = components.map(_.mkString(".")) - - /** Copy the class file to its package path relative to [[destDir]]. This will overwrite a class file at the - * destination if it exists. 
+ private def extractClassesToTmp( + src: File, + tmpDir: File, + isArchive: Entry => Boolean, + isClass: Entry => Boolean, + recurse: Boolean + ): IterableOnce[ClassFile] = + + def shouldExtract(e: Entry) = + !e.isZipSlip && e.maybeRegularFile() && (isArchive(e) || isClass(e)) + unfoldArchives( + src, + { + case f if isClass(Entry(f)) => + Left(ClassFile(f)) + case f if f.isDirectory() => + val files = f.listRecursively.filterNot(_.isDirectory).toList + Right(files) + case f if isArchive(Entry(f)) && (recurse || f == src) => + val xTmp = File.newTemporaryDirectory("extract-archive-", parent = Some(tmpDir)) + val unzipDirs = Try(f.unzipTo(xTmp, e => shouldExtract(Entry(e)))) match + case Success(dir) => List(dir) + case Failure(e) => + logger.warn(s"Failed to extract archive", e) + List.empty + Right(unzipDirs) + case _ => + Right(List.empty) + } + ) + end extractClassesToTmp + + object ClassFile: + private def getPackagePathFromByteCode(is: InputStream): Option[String] = + val cr = new ClassReader(is) + sealed class ClassNameVisitor extends ClassVisitor(Opcodes.ASM9): + var path: Option[String] = None + override def visit( + version: Int, + access: Int, + name: String, + signature: String, + superName: String, + interfaces: Array[String] + ): Unit = + path = Some(name) + val rootVisitor = new ClassNameVisitor() + cr.accept(rootVisitor, SKIP_CODE) + rootVisitor.path + + /** Attempt to retrieve the package path from JVM bytecode. 
+ * + * @param file + * The class file + * @return + * The package path if successfully retrieved + */ + private def getPackagePathFromByteCode(file: File): Option[String] = + Try(file.fileInputStream.apply(getPackagePathFromByteCode)) + .recover { case e: Throwable => + logger.debug(s"Error reading class file ${file.canonicalPath}", e) + None + } + .getOrElse(None) + end ClassFile + sealed class ClassFile(val file: File, val packagePath: Option[String]): + def this(file: File) = this(file, ClassFile.getPackagePathFromByteCode(file)) + + private val components: Option[Array[String]] = packagePath.map(_.split("/")) + + val fullyQualifiedClassName: Option[String] = components.map(_.mkString(".")) + + /** Copy the class file to its package path relative to [[destDir]]. This will overwrite a + * class file at the destination if it exists. + * @param destDir + * The directory in which to place the class file + * @return + * The class file at the destination if the package path could be retrieved from the its + * bytecode + */ + def copyToPackageLayoutIn(destDir: File): Option[ClassFile] = + packagePath + .map { path => + val destClass = destDir / s"$path.class" + if destClass.exists() then + logger.warn(s"Overwriting class file: ${destClass.path.toAbsolutePath}") + destClass.parent.createDirectories() + ClassFile( + file.copyTo(destClass)(File.CopyOptions(overwrite = true)), + packagePath + ) + } + .orElse { + logger.warn( + s"Missing package path for ${file.canonicalPath}. Failed to copy to ${destDir.canonicalPath}" + ) + None + } + end ClassFile + + /** Find
.class
files, including those inside archives and copy them to their package + * path location relative to [[destDir]] + * + * @param src + * The file/directory to search. * @param destDir - * The directory in which to place the class file + * The directory in which to place the class files + * @param isArchive + * Whether an entry is an archive to extract + * @param isClass + * Whether an entry is a class file + * @param recurse + * Whether to unpack recursively * @return - * The class file at the destination if the package path could be retrieved from the its bytecode + * The copied class files in destDir */ - def copyToPackageLayoutIn(destDir: File): Option[ClassFile] = - packagePath - .map { path => - val destClass = destDir / s"$path.class" - if (destClass.exists()) { - logger.warn(s"Overwriting class file: ${destClass.path.toAbsolutePath}") - } - destClass.parent.createDirectories() - ClassFile(file.copyTo(destClass)(File.CopyOptions(overwrite = true)), packagePath) - } - .orElse { - logger.warn(s"Missing package path for ${file.canonicalPath}. Failed to copy to ${destDir.canonicalPath}") - None - } - } - - /** Find
.class
files, including those inside archives and copy them to their package path location - * relative to [[destDir]] - * - * @param src - * The file/directory to search. - * @param destDir - * The directory in which to place the class files - * @param isArchive - * Whether an entry is an archive to extract - * @param isClass - * Whether an entry is a class file - * @param recurse - * Whether to unpack recursively - * @return - * The copied class files in destDir - */ - def extractClassesInPackageLayout( - src: File, - destDir: File, - isClass: Entry => Boolean, - isArchive: Entry => Boolean, - recurse: Boolean - ): List[ClassFile] = - File - .temporaryDirectory("extract-classes-") - .apply(tmpDir => - extractClassesToTmp(src, tmpDir, isArchive, isClass, recurse: Boolean).iterator - .flatMap(_.copyToPackageLayoutIn(destDir)) - .toList - ) - -} + def extractClassesInPackageLayout( + src: File, + destDir: File, + isClass: Entry => Boolean, + isArchive: Entry => Boolean, + recurse: Boolean + ): List[ClassFile] = + File + .temporaryDirectory("extract-classes-") + .apply(tmpDir => + extractClassesToTmp(src, tmpDir, isArchive, isClass, recurse: Boolean).iterator + .flatMap(_.copyToPackageLayoutIn(destDir)) + .toList + ) +end ProgramHandlingUtil diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala index 37518972..1fcf26c3 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala @@ -5,19 +5,19 @@ import io.appthreat.jssrc2cpg.utils.AstGenRunner import io.appthreat.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import JsSrc2Cpg.postProcessingPasses import io.appthreat.jssrc2cpg.passes.{ - AstCreationPass, - BuiltinTypesPass, - ConfigPass, - ConstClosurePass, - DependenciesPass, - ImportResolverPass, - 
ImportsPass, - JavaScriptInheritanceNamePass, - JavaScriptTypeHintCallLinker, - JavaScriptTypeRecoveryPass, - JsMetaDataPass, - PrivateKeyFilePass, - TypeNodePass + AstCreationPass, + BuiltinTypesPass, + ConfigPass, + ConstClosurePass, + DependenciesPass, + ImportResolverPass, + ImportsPass, + JavaScriptInheritanceNamePass, + JavaScriptTypeHintCallLinker, + JavaScriptTypeRecoveryPass, + JsMetaDataPass, + PrivateKeyFilePass, + TypeNodePass } import io.appthreat.jssrc2cpg.passes.* import io.appthreat.x2cpg.X2Cpg.withNewEmptyCpg @@ -30,57 +30,54 @@ import io.shiftleft.semanticcpg.layers.LayerCreatorContext import scala.util.Try -class JsSrc2Cpg extends X2CpgFrontend[Config] { +class JsSrc2Cpg extends X2CpgFrontend[Config]: - private val report: Report = new Report() + private val report: Report = new Report() - def createCpg(config: Config): Try[Cpg] = { - withNewEmptyCpg(config.outputPath, config) { (cpg, config) => - File.usingTemporaryDirectory("jssrc2cpgOut") { tmpDir => - val astGenResult = new AstGenRunner(config).execute(tmpDir) - val hash = HashUtil.sha256(astGenResult.parsedFiles.map { case (_, file) => File(file).path }) + def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config) { (cpg, config) => + File.usingTemporaryDirectory("jssrc2cpgOut") { tmpDir => + val astGenResult = new AstGenRunner(config).execute(tmpDir) + val hash = HashUtil.sha256(astGenResult.parsedFiles.map { case (_, file) => + File(file).path + }) - val astCreationPass = new AstCreationPass(cpg, astGenResult, config, report)(config.schemaValidation) - astCreationPass.createAndApply() + val astCreationPass = + new AstCreationPass(cpg, astGenResult, config, report)(config.schemaValidation) + astCreationPass.createAndApply() - new TypeNodePass(astCreationPass.allUsedTypes(), cpg).createAndApply() - new JsMetaDataPass(cpg, hash, config.inputPath).createAndApply() - new BuiltinTypesPass(cpg).createAndApply() - new DependenciesPass(cpg, config).createAndApply() - new 
ConfigPass(cpg, config, report).createAndApply() - new PrivateKeyFilePass(cpg, config, report).createAndApply() - new ImportsPass(cpg).createAndApply() + new TypeNodePass(astCreationPass.allUsedTypes(), cpg).createAndApply() + new JsMetaDataPass(cpg, hash, config.inputPath).createAndApply() + new BuiltinTypesPass(cpg).createAndApply() + new DependenciesPass(cpg, config).createAndApply() + new ConfigPass(cpg, config, report).createAndApply() + new PrivateKeyFilePass(cpg, config, report).createAndApply() + new ImportsPass(cpg).createAndApply() - report.print() - } - } - } + report.print() + } + } - // This method is intended for internal use only and may be removed at any time. - def createCpgWithAllOverlays(config: Config): Try[Cpg] = { - val maybeCpg = createCpgWithOverlays(config) - maybeCpg.map { cpg => - new OssDataFlow(new OssDataFlowOptions()).run(new LayerCreatorContext(cpg)) - postProcessingPasses(cpg, Option(config)).foreach(_.createAndApply()) - cpg - } - } + // This method is intended for internal use only and may be removed at any time. 
+ def createCpgWithAllOverlays(config: Config): Try[Cpg] = + val maybeCpg = createCpgWithOverlays(config) + maybeCpg.map { cpg => + new OssDataFlow(new OssDataFlowOptions()).run(new LayerCreatorContext(cpg)) + postProcessingPasses(cpg, Option(config)).foreach(_.createAndApply()) + cpg + } +end JsSrc2Cpg -} - -object JsSrc2Cpg { +object JsSrc2Cpg: - def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = { - val typeRecoveryConfig = config - .map(c => XTypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) - .getOrElse(XTypeRecoveryConfig()) - List( - new JavaScriptInheritanceNamePass(cpg), - new ConstClosurePass(cpg), - new ImportResolverPass(cpg), - new JavaScriptTypeRecoveryPass(cpg, typeRecoveryConfig), - new JavaScriptTypeHintCallLinker(cpg) - ) - } - -} + def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = + val typeRecoveryConfig = config + .map(c => XTypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) + .getOrElse(XTypeRecoveryConfig()) + List( + new JavaScriptInheritanceNamePass(cpg), + new ConstClosurePass(cpg), + new ImportResolverPass(cpg), + new JavaScriptTypeRecoveryPass(cpg, typeRecoveryConfig), + new JavaScriptTypeHintCallLinker(cpg) + ) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala index b3d6006d..09ac2d94 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala @@ -8,41 +8,32 @@ import scopt.OParser import java.nio.file.Paths -final case class Config(tsTypes: Boolean = true) extends X2CpgConfig[Config] with TypeRecoveryParserConfig[Config] { - - def withTsTypes(value: Boolean): Config = { - copy(tsTypes = value).withInheritedFields(this) - } - -} - -object Frontend { - implicit val defaultConfig: Config = 
Config() - - val cmdLineParser: OParser[Unit, Config] = { - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName("jssrc2cpg"), - opt[Unit]("no-tsTypes") - .hidden() - .action((_, c) => c.withTsTypes(false)) - .text("disable generation of types via Typescript"), - XTypeRecovery.parserOptions - ) - } - -} - -object Main extends X2CpgMain(cmdLineParser, new JsSrc2Cpg()) { - - def run(config: Config, jssrc2cpg: JsSrc2Cpg): Unit = { - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - if (Environment.pathExists(absPath)) { - jssrc2cpg.run(config.withInputPath(absPath)) - } else { - System.exit(1) - } - } - -} +final case class Config(tsTypes: Boolean = true) extends X2CpgConfig[Config] + with TypeRecoveryParserConfig[Config]: + + def withTsTypes(value: Boolean): Config = + copy(tsTypes = value).withInheritedFields(this) + +object Frontend: + implicit val defaultConfig: Config = Config() + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("jssrc2cpg"), + opt[Unit]("no-tsTypes") + .hidden() + .action((_, c) => c.withTsTypes(false)) + .text("disable generation of types via Typescript"), + XTypeRecovery.parserOptions + ) + +object Main extends X2CpgMain(cmdLineParser, new JsSrc2Cpg()): + + def run(config: Config, jssrc2cpg: JsSrc2Cpg): Unit = + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + if Environment.pathExists(absPath) then + jssrc2cpg.run(config.withInputPath(absPath)) + else + System.exit(1) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala index 025df696..aaa73ac2 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala +++ 
b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala @@ -8,7 +8,12 @@ import io.appthreat.jssrc2cpg.parser.BabelAst.* import io.appthreat.jssrc2cpg.parser.BabelJsonParser.ParseResult import io.appthreat.x2cpg.datastructures.Stack.* import io.appthreat.x2cpg.utils.NodeBuilders.newMethodReturnNode -import io.appthreat.x2cpg.{Ast, AstCreatorBase, ValidationMode, AstNodeBuilder as X2CpgAstNodeBuilder} +import io.appthreat.x2cpg.{ + Ast, + AstCreatorBase, + ValidationMode, + AstNodeBuilder as X2CpgAstNodeBuilder +} import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes} import io.shiftleft.codepropertygraph.generated.nodes.NewBlock import io.shiftleft.codepropertygraph.generated.nodes.NewFile @@ -39,218 +44,234 @@ class AstCreator( with AstNodeBuilder with TypeHelper with AstCreatorHelper - with X2CpgAstNodeBuilder[BabelNodeInfo, AstCreator] { + with X2CpgAstNodeBuilder[BabelNodeInfo, AstCreator]: - protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) + protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) - protected val scope = new Scope() + protected val scope = new Scope() - // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST - // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. - // To achieve this we need this extra stack. 
- protected val methodAstParentStack = new Stack[NewNode]() - protected val typeRefIdStack = new Stack[NewTypeRef] - protected val dynamicInstanceTypeStack = new Stack[String] - protected val localAstParentStack = new Stack[NewBlock]() - protected val rootTypeDecl = new Stack[NewTypeDecl]() - protected val typeFullNameToPostfix = mutable.HashMap.empty[String, Int] - protected val functionNodeToNameAndFullName = mutable.HashMap.empty[BabelNodeInfo, (String, String)] - protected val usedVariableNames = mutable.HashMap.empty[String, Int] - protected val seenAliasTypes = mutable.HashSet.empty[NewTypeDecl] - protected val functionFullNames = mutable.HashSet.empty[String] + // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST + // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. + // To achieve this we need this extra stack. + protected val methodAstParentStack = new Stack[NewNode]() + protected val typeRefIdStack = new Stack[NewTypeRef] + protected val dynamicInstanceTypeStack = new Stack[String] + protected val localAstParentStack = new Stack[NewBlock]() + protected val rootTypeDecl = new Stack[NewTypeDecl]() + protected val typeFullNameToPostfix = mutable.HashMap.empty[String, Int] + protected val functionNodeToNameAndFullName = + mutable.HashMap.empty[BabelNodeInfo, (String, String)] + protected val usedVariableNames = mutable.HashMap.empty[String, Int] + protected val seenAliasTypes = mutable.HashSet.empty[NewTypeDecl] + protected val functionFullNames = mutable.HashSet.empty[String] - // we track line and column numbers manually because astgen / @babel-parser sometimes - // fails to deliver them at all - strange, but this even happens with its latest version - protected val (positionToLineNumberMapping, positionToFirstPositionInLineMapping) = - positionLookupTables(parserResult.fileContent) + // we track line and column numbers manually because astgen / 
@babel-parser sometimes + // fails to deliver them at all - strange, but this even happens with its latest version + protected val (positionToLineNumberMapping, positionToFirstPositionInLineMapping) = + positionLookupTables(parserResult.fileContent) - override def createAst(): DiffGraphBuilder = { - val fileNode = NewFile().name(parserResult.filename).order(1) - val namespaceBlock = globalNamespaceBlock() - methodAstParentStack.push(namespaceBlock) - val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(createProgramMethod())) - Ast.storeInDiffGraph(ast, diffGraph) - createVariableReferenceLinks() - diffGraph - } + override def createAst(): DiffGraphBuilder = + val fileNode = NewFile().name(parserResult.filename).order(1) + val namespaceBlock = globalNamespaceBlock() + methodAstParentStack.push(namespaceBlock) + val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(createProgramMethod())) + Ast.storeInDiffGraph(ast, diffGraph) + createVariableReferenceLinks() + diffGraph - private def createProgramMethod(): Ast = { - val path = parserResult.filename - val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) - val lineNumber = astNodeInfo.lineNumber - val columnNumber = astNodeInfo.columnNumber - val lineNumberEnd = astNodeInfo.lineNumberEnd - val columnNumberEnd = astNodeInfo.columnNumberEnd - val name = ":program" - val fullName = s"$path:$name" + private def createProgramMethod(): Ast = + val path = parserResult.filename + val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) + val lineNumber = astNodeInfo.lineNumber + val columnNumber = astNodeInfo.columnNumber + val lineNumberEnd = astNodeInfo.lineNumberEnd + val columnNumberEnd = astNodeInfo.columnNumberEnd + val name = ":program" + val fullName = s"$path:$name" - val programMethod = - NewMethod() - .order(1) - .name(name) - .code(name) - .fullName(fullName) - .filename(path) - .lineNumber(lineNumber) - .lineNumberEnd(lineNumberEnd) - .columnNumber(columnNumber) - 
.columnNumberEnd(columnNumberEnd) - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(fullName) + val programMethod = + NewMethod() + .order(1) + .name(name) + .code(name) + .fullName(fullName) + .filename(path) + .lineNumber(lineNumber) + .lineNumberEnd(lineNumberEnd) + .columnNumber(columnNumber) + .columnNumberEnd(columnNumberEnd) + .astParentType(NodeTypes.TYPE_DECL) + .astParentFullName(fullName) - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst(astNodeInfo, programMethod, methodAstParentStack.head, name, fullName, path) - rootTypeDecl.push(functionTypeAndTypeDeclAst.nodes.head.asInstanceOf[NewTypeDecl]) + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + astNodeInfo, + programMethod, + methodAstParentStack.head, + name, + fullName, + path + ) + rootTypeDecl.push(functionTypeAndTypeDeclAst.nodes.head.asInstanceOf[NewTypeDecl]) - methodAstParentStack.push(programMethod) + methodAstParentStack.push(programMethod) - val blockNode = NewBlock().typeFullName(Defines.Any) + val blockNode = NewBlock().typeFullName(Defines.Any) - scope.pushNewMethodScope(fullName, name, blockNode, None) - localAstParentStack.push(blockNode) + scope.pushNewMethodScope(fullName, name, blockNode, None) + localAstParentStack.push(blockNode) - val thisParam = - parameterInNode(astNodeInfo, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisParam, MethodScope) + val thisParam = + parameterInNode(astNodeInfo, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisParam, MethodScope) - val methodChildren = astsForFile(astNodeInfo) - setArgumentIndices(methodChildren) + val methodChildren = astsForFile(astNodeInfo) + setArgumentIndices(methodChildren) - val methodReturn = newMethodReturnNode(Defines.Any, line = None, column = None) + val methodReturn = 
newMethodReturnNode(Defines.Any, line = None, column = None) - localAstParentStack.pop() - scope.popScope() - methodAstParentStack.pop() + localAstParentStack.pop() + scope.popScope() + methodAstParentStack.pop() - functionTypeAndTypeDeclAst.withChild( - methodAst(programMethod, List(Ast(thisParam)), blockAst(blockNode, methodChildren), methodReturn) - ) - } + functionTypeAndTypeDeclAst.withChild( + methodAst( + programMethod, + List(Ast(thisParam)), + blockAst(blockNode, methodChildren), + methodReturn + ) + ) + end createProgramMethod - protected def astForNode(json: Value): Ast = { - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match { - case ClassDeclaration => astForClass(nodeInfo, shouldCreateAssignmentCall = true) - case DeclareClass => astForClass(nodeInfo, shouldCreateAssignmentCall = true) - case ClassExpression => astForClass(nodeInfo) - case TSInterfaceDeclaration => astForInterface(nodeInfo) - case TSModuleDeclaration => astForModule(nodeInfo) - case TSExportAssignment => astForExportAssignment(nodeInfo) - case ExportNamedDeclaration => astForExportNamedDeclaration(nodeInfo) - case ExportDefaultDeclaration => astForExportDefaultDeclaration(nodeInfo) - case ExportAllDeclaration => astForExportAllDeclaration(nodeInfo) - case ImportDeclaration => astForImportDeclaration(nodeInfo) - case FunctionDeclaration => astForFunctionDeclaration(nodeInfo) - case TSDeclareFunction => astForTSDeclareFunction(nodeInfo) - case VariableDeclaration => astForVariableDeclaration(nodeInfo) - case ArrowFunctionExpression => astForFunctionDeclaration(nodeInfo) - case FunctionExpression => astForFunctionDeclaration(nodeInfo) - case TSEnumDeclaration => astForEnum(nodeInfo) - case DeclareTypeAlias => astForTypeAlias(nodeInfo) - case TypeAlias => astForTypeAlias(nodeInfo) - case TypeCastExpression => astForCastExpression(nodeInfo) - case TSTypeAssertion => astForCastExpression(nodeInfo) - case TSTypeCastExpression => astForCastExpression(nodeInfo) - case 
TSTypeAliasDeclaration => astForTypeAlias(nodeInfo) - case NewExpression => astForNewExpression(nodeInfo) - case ThisExpression => astForThisExpression(nodeInfo) - case MemberExpression => astForMemberExpression(nodeInfo) - case OptionalMemberExpression => astForMemberExpression(nodeInfo) - case MetaProperty => astForMetaProperty(nodeInfo) - case CallExpression => astForCallExpression(nodeInfo) - case OptionalCallExpression => astForCallExpression(nodeInfo) - case SequenceExpression => astForSequenceExpression(nodeInfo) - case AssignmentExpression => astForAssignmentExpression(nodeInfo) - case AssignmentPattern => astForAssignmentExpression(nodeInfo) - case BinaryExpression => astForBinaryExpression(nodeInfo) - case LogicalExpression => astForLogicalExpression(nodeInfo) - case TSAsExpression => astForCastExpression(nodeInfo) - case UpdateExpression => astForUpdateExpression(nodeInfo) - case UnaryExpression => astForUnaryExpression(nodeInfo) - case ArrayExpression => astForArrayExpression(nodeInfo) - case AwaitExpression => astForAwaitExpression(nodeInfo) - case ConditionalExpression => astForConditionalExpression(nodeInfo) - case TaggedTemplateExpression => astForTemplateExpression(nodeInfo) - case ObjectExpression => astForObjectExpression(nodeInfo) - case TSNonNullExpression => astForTSNonNullExpression(nodeInfo) - case YieldExpression => astForReturnStatement(nodeInfo) - case ExpressionStatement => astForExpressionStatement(nodeInfo) - case IfStatement => astForIfStatement(nodeInfo) - case BlockStatement => astForBlockStatement(nodeInfo) - case ReturnStatement => astForReturnStatement(nodeInfo) - case TryStatement => astForTryStatement(nodeInfo) - case ForStatement => astForForStatement(nodeInfo) - case WhileStatement => astForWhileStatement(nodeInfo) - case DoWhileStatement => astForDoWhileStatement(nodeInfo) - case SwitchStatement => astForSwitchStatement(nodeInfo) - case BreakStatement => astForBreakStatement(nodeInfo) - case ContinueStatement => 
astForContinueStatement(nodeInfo) - case LabeledStatement => astForLabeledStatement(nodeInfo) - case ThrowStatement => astForThrowStatement(nodeInfo) - case ForInStatement => astForInOfStatement(nodeInfo) - case ForOfStatement => astForInOfStatement(nodeInfo) - case ObjectPattern => astForObjectExpression(nodeInfo) - case ArrayPattern => astForArrayExpression(nodeInfo) - case Identifier => astForIdentifier(nodeInfo) - case PrivateName => astForPrivateName(nodeInfo) - case Super => astForSuperKeyword(nodeInfo) - case Import => astForImportKeyword(nodeInfo) - case TSImportEqualsDeclaration => astForTSImportEqualsDeclaration(nodeInfo) - case StringLiteral => astForStringLiteral(nodeInfo) - case NumericLiteral => astForNumericLiteral(nodeInfo) - case NumberLiteral => astForNumberLiteral(nodeInfo) - case DecimalLiteral => astForDecimalLiteral(nodeInfo) - case NullLiteral => astForNullLiteral(nodeInfo) - case BooleanLiteral => astForBooleanLiteral(nodeInfo) - case RegExpLiteral => astForRegExpLiteral(nodeInfo) - case RegexLiteral => astForRegexLiteral(nodeInfo) - case BigIntLiteral => astForBigIntLiteral(nodeInfo) - case TemplateLiteral => astForTemplateLiteral(nodeInfo) - case TemplateElement => astForTemplateElement(nodeInfo) - case SpreadElement => astForSpreadOrRestElement(nodeInfo) - case TSSatisfiesExpression => astForTSSatisfiesExpression(nodeInfo) - case JSXElement => astForJsxElement(nodeInfo) - case JSXOpeningElement => astForJsxOpeningElement(nodeInfo) - case JSXClosingElement => astForJsxClosingElement(nodeInfo) - case JSXText => astForJsxText(nodeInfo) - case JSXExpressionContainer => astForJsxExprContainer(nodeInfo) - case JSXSpreadChild => astForJsxExprContainer(nodeInfo) - case JSXSpreadAttribute => astForJsxSpreadAttribute(nodeInfo) - case JSXFragment => astForJsxFragment(nodeInfo) - case JSXAttribute => astForJsxAttribute(nodeInfo) - case WithStatement => astForWithStatement(nodeInfo) - case EmptyStatement => Ast() - case DebuggerStatement => Ast() - 
case _ => notHandledYet(nodeInfo) - } - } + protected def astForNode(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case ClassDeclaration => astForClass(nodeInfo, shouldCreateAssignmentCall = true) + case DeclareClass => astForClass(nodeInfo, shouldCreateAssignmentCall = true) + case ClassExpression => astForClass(nodeInfo) + case TSInterfaceDeclaration => astForInterface(nodeInfo) + case TSModuleDeclaration => astForModule(nodeInfo) + case TSExportAssignment => astForExportAssignment(nodeInfo) + case ExportNamedDeclaration => astForExportNamedDeclaration(nodeInfo) + case ExportDefaultDeclaration => astForExportDefaultDeclaration(nodeInfo) + case ExportAllDeclaration => astForExportAllDeclaration(nodeInfo) + case ImportDeclaration => astForImportDeclaration(nodeInfo) + case FunctionDeclaration => astForFunctionDeclaration(nodeInfo) + case TSDeclareFunction => astForTSDeclareFunction(nodeInfo) + case VariableDeclaration => astForVariableDeclaration(nodeInfo) + case ArrowFunctionExpression => astForFunctionDeclaration(nodeInfo) + case FunctionExpression => astForFunctionDeclaration(nodeInfo) + case TSEnumDeclaration => astForEnum(nodeInfo) + case DeclareTypeAlias => astForTypeAlias(nodeInfo) + case TypeAlias => astForTypeAlias(nodeInfo) + case TypeCastExpression => astForCastExpression(nodeInfo) + case TSTypeAssertion => astForCastExpression(nodeInfo) + case TSTypeCastExpression => astForCastExpression(nodeInfo) + case TSTypeAliasDeclaration => astForTypeAlias(nodeInfo) + case NewExpression => astForNewExpression(nodeInfo) + case ThisExpression => astForThisExpression(nodeInfo) + case MemberExpression => astForMemberExpression(nodeInfo) + case OptionalMemberExpression => astForMemberExpression(nodeInfo) + case MetaProperty => astForMetaProperty(nodeInfo) + case CallExpression => astForCallExpression(nodeInfo) + case OptionalCallExpression => astForCallExpression(nodeInfo) + case SequenceExpression => 
astForSequenceExpression(nodeInfo) + case AssignmentExpression => astForAssignmentExpression(nodeInfo) + case AssignmentPattern => astForAssignmentExpression(nodeInfo) + case BinaryExpression => astForBinaryExpression(nodeInfo) + case LogicalExpression => astForLogicalExpression(nodeInfo) + case TSAsExpression => astForCastExpression(nodeInfo) + case UpdateExpression => astForUpdateExpression(nodeInfo) + case UnaryExpression => astForUnaryExpression(nodeInfo) + case ArrayExpression => astForArrayExpression(nodeInfo) + case AwaitExpression => astForAwaitExpression(nodeInfo) + case ConditionalExpression => astForConditionalExpression(nodeInfo) + case TaggedTemplateExpression => astForTemplateExpression(nodeInfo) + case ObjectExpression => astForObjectExpression(nodeInfo) + case TSNonNullExpression => astForTSNonNullExpression(nodeInfo) + case YieldExpression => astForReturnStatement(nodeInfo) + case ExpressionStatement => astForExpressionStatement(nodeInfo) + case IfStatement => astForIfStatement(nodeInfo) + case BlockStatement => astForBlockStatement(nodeInfo) + case ReturnStatement => astForReturnStatement(nodeInfo) + case TryStatement => astForTryStatement(nodeInfo) + case ForStatement => astForForStatement(nodeInfo) + case WhileStatement => astForWhileStatement(nodeInfo) + case DoWhileStatement => astForDoWhileStatement(nodeInfo) + case SwitchStatement => astForSwitchStatement(nodeInfo) + case BreakStatement => astForBreakStatement(nodeInfo) + case ContinueStatement => astForContinueStatement(nodeInfo) + case LabeledStatement => astForLabeledStatement(nodeInfo) + case ThrowStatement => astForThrowStatement(nodeInfo) + case ForInStatement => astForInOfStatement(nodeInfo) + case ForOfStatement => astForInOfStatement(nodeInfo) + case ObjectPattern => astForObjectExpression(nodeInfo) + case ArrayPattern => astForArrayExpression(nodeInfo) + case Identifier => astForIdentifier(nodeInfo) + case PrivateName => astForPrivateName(nodeInfo) + case Super => 
astForSuperKeyword(nodeInfo) + case Import => astForImportKeyword(nodeInfo) + case TSImportEqualsDeclaration => astForTSImportEqualsDeclaration(nodeInfo) + case StringLiteral => astForStringLiteral(nodeInfo) + case NumericLiteral => astForNumericLiteral(nodeInfo) + case NumberLiteral => astForNumberLiteral(nodeInfo) + case DecimalLiteral => astForDecimalLiteral(nodeInfo) + case NullLiteral => astForNullLiteral(nodeInfo) + case BooleanLiteral => astForBooleanLiteral(nodeInfo) + case RegExpLiteral => astForRegExpLiteral(nodeInfo) + case RegexLiteral => astForRegexLiteral(nodeInfo) + case BigIntLiteral => astForBigIntLiteral(nodeInfo) + case TemplateLiteral => astForTemplateLiteral(nodeInfo) + case TemplateElement => astForTemplateElement(nodeInfo) + case SpreadElement => astForSpreadOrRestElement(nodeInfo) + case TSSatisfiesExpression => astForTSSatisfiesExpression(nodeInfo) + case JSXElement => astForJsxElement(nodeInfo) + case JSXOpeningElement => astForJsxOpeningElement(nodeInfo) + case JSXClosingElement => astForJsxClosingElement(nodeInfo) + case JSXText => astForJsxText(nodeInfo) + case JSXExpressionContainer => astForJsxExprContainer(nodeInfo) + case JSXSpreadChild => astForJsxExprContainer(nodeInfo) + case JSXSpreadAttribute => astForJsxSpreadAttribute(nodeInfo) + case JSXFragment => astForJsxFragment(nodeInfo) + case JSXAttribute => astForJsxAttribute(nodeInfo) + case WithStatement => astForWithStatement(nodeInfo) + case EmptyStatement => Ast() + case DebuggerStatement => Ast() + case _ => notHandledYet(nodeInfo) + end match + end astForNode - protected def astForNodeWithFunctionReference(json: Value): Ast = { - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match { - case _: FunctionLike => astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true) - case _ => astForNode(json) - } - } + protected def astForNodeWithFunctionReference(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case _: 
FunctionLike => + astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true) + case _ => astForNode(json) - protected def astForNodeWithFunctionReferenceAndCall(json: Value): Ast = { - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match { - case _: FunctionLike => - astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true, shouldCreateAssignmentCall = true) - case _ => astForNode(json) - } - } + protected def astForNodeWithFunctionReferenceAndCall(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case _: FunctionLike => + astForFunctionDeclaration( + nodeInfo, + shouldCreateFunctionReference = true, + shouldCreateAssignmentCall = true + ) + case _ => astForNode(json) - protected def astForNodes(jsons: List[Value]): List[Ast] = jsons.map(astForNodeWithFunctionReference) + protected def astForNodes(jsons: List[Value]): List[Ast] = + jsons.map(astForNodeWithFunctionReference) - private def astsForFile(file: BabelNodeInfo): List[Ast] = astsForProgram(createBabelNodeInfo(file.json("program"))) + private def astsForFile(file: BabelNodeInfo): List[Ast] = + astsForProgram(createBabelNodeInfo(file.json("program"))) - private def astsForProgram(program: BabelNodeInfo): List[Ast] = createBlockStatementAsts(program.json("body")) + private def astsForProgram(program: BabelNodeInfo): List[Ast] = + createBlockStatementAsts(program.json("body")) - protected def line(node: BabelNodeInfo): Option[Integer] = node.lineNumber - protected def column(node: BabelNodeInfo): Option[Integer] = node.columnNumber - protected def lineEnd(node: BabelNodeInfo): Option[Integer] = node.lineNumberEnd - protected def columnEnd(node: BabelNodeInfo): Option[Integer] = node.columnNumberEnd -} + protected def line(node: BabelNodeInfo): Option[Integer] = node.lineNumber + protected def column(node: BabelNodeInfo): Option[Integer] = node.columnNumber + protected def lineEnd(node: BabelNodeInfo): Option[Integer] = 
node.lineNumberEnd + protected def columnEnd(node: BabelNodeInfo): Option[Integer] = node.columnNumberEnd +end AstCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala index 23b0d806..d7ae83ae 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala @@ -1,14 +1,14 @@ package io.appthreat.jssrc2cpg.astcreation import io.appthreat.jssrc2cpg.datastructures.{ - BlockScopeElement, - MethodScope, - MethodScopeElement, - ResolvedReference, - Scope, - ScopeElement, - ScopeElementIterator, - ScopeType + BlockScopeElement, + MethodScope, + MethodScopeElement, + ResolvedReference, + Scope, + ScopeElement, + ScopeElementIterator, + ScopeType } import io.appthreat.jssrc2cpg.parser.BabelNodeInfo import io.appthreat.jssrc2cpg.passes.Defines @@ -31,301 +31,299 @@ import scala.jdk.CollectionConverters.EnumerationHasAsScala import scala.util.Success import scala.util.Try -trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - // maximum length of code fields in number of characters - private val MaxCodeLength: Int = 1000 - private val MinCodeLength: Int = 50 - - protected def createBabelNodeInfo(json: Value): BabelNodeInfo = { - val c = shortenCode(code(json)) - val ln = line(json) - val cn = column(json) - val lnEnd = lineEnd(json) - val cnEnd = columnEnd(json) - val node = nodeType(json) - BabelNodeInfo(node, json, c, ln, cn, lnEnd, cnEnd) - } - - protected def notHandledYet(node: BabelNodeInfo): Ast = { - val text = - s"""Node type '${node.node}' not handled yet! 
+trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + // maximum length of code fields in number of characters + private val MaxCodeLength: Int = 1000 + private val MinCodeLength: Int = 50 + + protected def createBabelNodeInfo(json: Value): BabelNodeInfo = + val c = shortenCode(code(json)) + val ln = line(json) + val cn = column(json) + val lnEnd = lineEnd(json) + val cnEnd = columnEnd(json) + val node = nodeType(json) + BabelNodeInfo(node, json, c, ln, cn, lnEnd, cnEnd) + + protected def notHandledYet(node: BabelNodeInfo): Ast = + val text = + s"""Node type '${node.node}' not handled yet! | Code: '${shortenCode(node.code, length = 50)}' | File: '${parserResult.fullPath}' | Line: ${node.lineNumber.getOrElse(-1)} | Column: ${node.columnNumber.getOrElse(-1)} | """.stripMargin - logger.debug(text) - Ast(unknownNode(node, node.code)) - } - - protected def registerType(typeName: String, typeFullName: String): Unit = { - if (usedTypes.containsKey((typeName, typeName)) && typeName != typeFullName) { - usedTypes.put((typeName, typeFullName), true) - usedTypes.remove((typeName, typeName)) - } else if (!usedTypes.keys().asScala.exists { case (tpn, _) => tpn == typeName }) { - usedTypes.putIfAbsent((typeName, typeFullName), true) + logger.debug(text) + Ast(unknownNode(node, node.code)) + + protected def registerType(typeName: String, typeFullName: String): Unit = + if usedTypes.containsKey((typeName, typeName)) && typeName != typeFullName then + usedTypes.put((typeName, typeFullName), true) + usedTypes.remove((typeName, typeName)) + else if !usedTypes.keys().asScala.exists { case (tpn, _) => tpn == typeName } then + usedTypes.putIfAbsent((typeName, typeFullName), true) + + private def nodeType(node: Value): BabelNode = fromString(node("type").str) + + protected def codeForNodes(nodes: Seq[NewNode]): Option[String] = nodes.collectFirst { + case id: NewIdentifier => id.name.replace("...", "") + case clazz: NewTypeRef => 
clazz.code.stripPrefix("class ") } - } - - private def nodeType(node: Value): BabelNode = fromString(node("type").str) - - protected def codeForNodes(nodes: Seq[NewNode]): Option[String] = nodes.collectFirst { - case id: NewIdentifier => id.name.replace("...", "") - case clazz: NewTypeRef => clazz.code.stripPrefix("class ") - } - - protected def nameForBabelNodeInfo(nodeInfo: BabelNodeInfo, defaultName: Option[String]): String = { - defaultName - .orElse(codeForBabelNodeInfo(nodeInfo).headOption) - .getOrElse { - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localNode = newLocalNode(tmpName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - tmpName - } - } - - protected def generateUnusedVariableName( - usedVariableNames: mutable.HashMap[String, Int], - variableName: String - ): String = { - val counter = usedVariableNames.get(variableName).map(_ + 1).getOrElse(0) - val currentVariableName = s"${variableName}_$counter" - usedVariableNames.put(variableName, counter) - currentVariableName - } - - protected def code(node: Value): String = { - val startIndex = start(node).getOrElse(0) - val endIndex = Math.min(end(node).getOrElse(0), parserResult.fileContent.length) - parserResult.fileContent.substring(startIndex, endIndex).trim - } - - private def shortenCode(code: String, length: Int = MaxCodeLength): String = - StringUtils.abbreviate(code, math.max(MinCodeLength, length)) - - protected def hasKey(node: Value, key: String): Boolean = Try(node(key)).isSuccess - - protected def safeStr(node: Value, key: String): Option[String] = - if (hasKey(node, key)) Try(node(key).str).toOption else None - - protected def safeBool(node: Value, key: String): Option[Boolean] = - if (hasKey(node, key)) Try(node(key).bool).toOption else None - - protected def safeObj(node: Value, key: String): Option[upickle.core.LinkedHashMap[String, Value]] = Try( - node(key).obj - ) match { - case Success(value) if 
value.nonEmpty => Option(value) - case _ => None - } - - private def start(node: Value): Option[Int] = Try(node("start").num.toInt).toOption - - private def end(node: Value): Option[Int] = Try(node("end").num.toInt).toOption - - protected def pos(node: Value): Option[Int] = Try(node("start").num.toInt).toOption - - protected def line(node: Value): Option[Integer] = start(node).map(getLineOfSource) - - protected def lineEnd(node: Value): Option[Integer] = end(node).map(getLineOfSource) - - protected def column(node: Value): Option[Integer] = start(node).map(getColumnOfSource) - - protected def columnEnd(node: Value): Option[Integer] = end(node).map(getColumnOfSource) - - // Returns the line number for a given position in the source. - private def getLineOfSource(position: Int): Int = { - val (_, lineNumber) = positionToLineNumberMapping.minAfter(position).get - lineNumber - } - - // Returns the column number for a given position in the source. - private def getColumnOfSource(position: Int): Int = { - val (_, firstPositionInLine) = positionToFirstPositionInLineMapping.minAfter(position).get - position - firstPositionInLine - } - - protected def positionLookupTables(source: String): (SortedMap[Int, Int], SortedMap[Int, Int]) = { - val positionToLineNumber, positionToFirstPositionInLine = mutable.TreeMap.empty[Int, Int] - val data = source.toCharArray - var lineNumber = 1 - var firstPositionInLine = 0 - var position = 0 - while (position < data.length) { - val isNewLine = data(position) == '\n' - if (isNewLine) { + + protected def nameForBabelNodeInfo( + nodeInfo: BabelNodeInfo, + defaultName: Option[String] + ): String = + defaultName + .orElse(codeForBabelNodeInfo(nodeInfo).headOption) + .getOrElse { + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localNode = newLocalNode(tmpName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + tmpName + } + + protected def generateUnusedVariableName( + 
usedVariableNames: mutable.HashMap[String, Int], + variableName: String + ): String = + val counter = usedVariableNames.get(variableName).map(_ + 1).getOrElse(0) + val currentVariableName = s"${variableName}_$counter" + usedVariableNames.put(variableName, counter) + currentVariableName + + protected def code(node: Value): String = + val startIndex = start(node).getOrElse(0) + val endIndex = Math.min(end(node).getOrElse(0), parserResult.fileContent.length) + parserResult.fileContent.substring(startIndex, endIndex).trim + + private def shortenCode(code: String, length: Int = MaxCodeLength): String = + StringUtils.abbreviate(code, math.max(MinCodeLength, length)) + + protected def hasKey(node: Value, key: String): Boolean = Try(node(key)).isSuccess + + protected def safeStr(node: Value, key: String): Option[String] = + if hasKey(node, key) then Try(node(key).str).toOption else None + + protected def safeBool(node: Value, key: String): Option[Boolean] = + if hasKey(node, key) then Try(node(key).bool).toOption else None + + protected def safeObj( + node: Value, + key: String + ): Option[upickle.core.LinkedHashMap[String, Value]] = Try( + node(key).obj + ) match + case Success(value) if value.nonEmpty => Option(value) + case _ => None + + private def start(node: Value): Option[Int] = Try(node("start").num.toInt).toOption + + private def end(node: Value): Option[Int] = Try(node("end").num.toInt).toOption + + protected def pos(node: Value): Option[Int] = Try(node("start").num.toInt).toOption + + protected def line(node: Value): Option[Integer] = start(node).map(getLineOfSource) + + protected def lineEnd(node: Value): Option[Integer] = end(node).map(getLineOfSource) + + protected def column(node: Value): Option[Integer] = start(node).map(getColumnOfSource) + + protected def columnEnd(node: Value): Option[Integer] = end(node).map(getColumnOfSource) + + // Returns the line number for a given position in the source. 
+ private def getLineOfSource(position: Int): Int = + val (_, lineNumber) = positionToLineNumberMapping.minAfter(position).get + lineNumber + + // Returns the column number for a given position in the source. + private def getColumnOfSource(position: Int): Int = + val (_, firstPositionInLine) = positionToFirstPositionInLineMapping.minAfter(position).get + position - firstPositionInLine + + protected def positionLookupTables(source: String): (SortedMap[Int, Int], SortedMap[Int, Int]) = + val positionToLineNumber, positionToFirstPositionInLine = mutable.TreeMap.empty[Int, Int] + val data = source.toCharArray + var lineNumber = 1 + var firstPositionInLine = 0 + var position = 0 + while position < data.length do + val isNewLine = data(position) == '\n' + if isNewLine then + positionToLineNumber.put(position, lineNumber) + lineNumber += 1 + positionToFirstPositionInLine.put(position, firstPositionInLine) + firstPositionInLine = position + 1 + position += 1 positionToLineNumber.put(position, lineNumber) - lineNumber += 1 positionToFirstPositionInLine.put(position, firstPositionInLine) - firstPositionInLine = position + 1 - } - position += 1 - } - positionToLineNumber.put(position, lineNumber) - positionToFirstPositionInLine.put(position, firstPositionInLine) - - // for empty line at the end of each JS/TS file generated by BabelJsonParser: - positionToLineNumber.put(position + 1, lineNumber + 1) - positionToFirstPositionInLine.put(position + 1, 0) - (positionToLineNumber, positionToFirstPositionInLine) - } - - private def computeScopePath(stack: Option[ScopeElement]): String = - new ScopeElementIterator(stack) - .to(Seq) - .reverse - .collect { case methodScopeElement: MethodScopeElement => methodScopeElement.name } - .mkString(":") - - private def isMethodOrGetSet(func: BabelNodeInfo): Boolean = { - if (hasKey(func.json, "kind") && !func.json("kind").isNull) { - val t = func.json("kind").str - t == "method" || t == "get" || t == "set" - } else false - } - - private def 
calcMethodName(func: BabelNodeInfo): String = func.node match { - case TSCallSignatureDeclaration => "anonymous" - case TSConstructSignatureDeclaration => io.appthreat.x2cpg.Defines.ConstructorMethodName - case _ if isMethodOrGetSet(func) => - if (hasKey(func.json("key"), "name")) func.json("key")("name").str - else code(func.json("key")) - case _ if safeStr(func.json, "kind").contains("constructor") => io.appthreat.x2cpg.Defines.ConstructorMethodName - case _ if func.json("id").isNull => "anonymous" - case _ => func.json("id")("name").str - } - - protected def calcMethodNameAndFullName(func: BabelNodeInfo): (String, String) = { - // functionNode.getName is not necessarily unique and thus the full name calculated based on the scope - // is not necessarily unique. Specifically we have this problem with lambda functions which are defined - // in the same scope. - functionNodeToNameAndFullName.get(func) match { - case Some(nameAndFullName) => nameAndFullName - case None => - val intendedName = calcMethodName(func) - val fullNamePrefix = s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" - var name = intendedName - var fullName = "" - var isUnique = false - var i = 1 - while (!isUnique) { - fullName = s"$fullNamePrefix$name" - if (functionFullNames.contains(fullName)) { - name = s"$intendedName$i" - i += 1 - } else { - isUnique = true - } - } - functionFullNames.add(fullName) - functionNodeToNameAndFullName(func) = (name, fullName) - (name, fullName) - } - } - - protected def stripQuotes(str: String): String = str - .stripPrefix("\"") - .stripSuffix("\"") - .stripPrefix("'") - .stripSuffix("'") - .stripPrefix("`") - .stripSuffix("`") - - /** In JS it is possible to create anonymous classes. We have to handle this here. 
- */ - private def calcTypeName(classNode: BabelNodeInfo): String = - if (hasKey(classNode.json, "id") && !classNode.json("id").isNull) code(classNode.json("id")) - else "_anon_cdecl" - - protected def calcTypeNameAndFullName( - classNode: BabelNodeInfo, - preCalculatedName: Option[String] = None - ): (String, String) = { - val name = preCalculatedName.getOrElse(calcTypeName(classNode)) - val fullNamePrefix = s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" - val intendedFullName = s"$fullNamePrefix$name" - val postfix = typeFullNameToPostfix.getOrElse(intendedFullName, 0) - val resultingFullName = - if (postfix == 0) intendedFullName - else s"$intendedFullName$postfix" - typeFullNameToPostfix.put(intendedFullName, postfix + 1) - (name, resultingFullName) - } - - protected def createVariableReferenceLinks(): Unit = { - val resolvedReferenceIt = scope.resolve(createMethodLocalForUnresolvedReference) - val capturedLocals = mutable.HashMap.empty[String, NewNode] - - resolvedReferenceIt.foreach { case ResolvedReference(variableNodeId, origin) => - var currentScope = origin.stack - var currentReference = origin.referenceNode - var nextReference: NewNode = null - - var done = false - while (!done) { - val localOrCapturedLocalNodeOption = - if (currentScope.get.nameToVariableNode.contains(origin.variableName)) { - done = true - Option(variableNodeId) - } else { - currentScope.flatMap { - case methodScope: MethodScopeElement - if methodScope.scopeNode.isInstanceOf[NewTypeDecl] || methodScope.scopeNode - .isInstanceOf[NewNamespaceBlock] => - currentScope = Option(Scope.getEnclosingMethodScopeElement(currentScope)) - None - case methodScope: MethodScopeElement => - // We have reached a MethodScope and still did not find a local variable to link to. - // For all non local references the CPG format does not allow us to link - // directly. 
Instead we need to create a fake local variable in method - // scope and link to this local which itself carries the information - // that it is a captured variable. This needs to be done for each - // method scope until we reach the originating scope. - val closureBindingIdProperty = s"${methodScope.methodFullName}:${origin.variableName}" - capturedLocals.updateWith(closureBindingIdProperty) { - case None => - val methodScopeNode = methodScope.scopeNode - val localNode = - newLocalNode(origin.variableName, Defines.Any, Option(closureBindingIdProperty)).order(0) - diffGraph.addEdge(methodScopeNode, localNode, EdgeTypes.AST) - val closureBindingNode = newClosureBindingNode( - closureBindingIdProperty, - origin.variableName, - EvaluationStrategies.BY_REFERENCE - ) - methodScope.capturingRefId.foreach(ref => - diffGraph.addEdge(ref, closureBindingNode, EdgeTypes.CAPTURE) - ) - nextReference = closureBindingNode - Option(localNode) - case someLocalNode => - // When there is already a LOCAL representing the capturing, we do not - // need to process the surrounding scope element as this has already - // been processed. 
- done = true - someLocalNode - } - case _: BlockScopeElement => None - } - } - localOrCapturedLocalNodeOption.foreach { localOrCapturedLocalNode => - diffGraph.addEdge(currentReference, localOrCapturedLocalNode, EdgeTypes.REF) - currentReference = nextReference + // for empty line at the end of each JS/TS file generated by BabelJsonParser: + positionToLineNumber.put(position + 1, lineNumber + 1) + positionToFirstPositionInLine.put(position + 1, 0) + (positionToLineNumber, positionToFirstPositionInLine) + end positionLookupTables + + private def computeScopePath(stack: Option[ScopeElement]): String = + new ScopeElementIterator(stack) + .to(Seq) + .reverse + .collect { case methodScopeElement: MethodScopeElement => methodScopeElement.name } + .mkString(":") + + private def isMethodOrGetSet(func: BabelNodeInfo): Boolean = + if hasKey(func.json, "kind") && !func.json("kind").isNull then + val t = func.json("kind").str + t == "method" || t == "get" || t == "set" + else false + + private def calcMethodName(func: BabelNodeInfo): String = func.node match + case TSCallSignatureDeclaration => "anonymous" + case TSConstructSignatureDeclaration => io.appthreat.x2cpg.Defines.ConstructorMethodName + case _ if isMethodOrGetSet(func) => + if hasKey(func.json("key"), "name") then func.json("key")("name").str + else code(func.json("key")) + case _ if safeStr(func.json, "kind").contains("constructor") => + io.appthreat.x2cpg.Defines.ConstructorMethodName + case _ if func.json("id").isNull => "anonymous" + case _ => func.json("id")("name").str + + protected def calcMethodNameAndFullName(func: BabelNodeInfo): (String, String) = + // functionNode.getName is not necessarily unique and thus the full name calculated based on the scope + // is not necessarily unique. Specifically we have this problem with lambda functions which are defined + // in the same scope. 
+ functionNodeToNameAndFullName.get(func) match + case Some(nameAndFullName) => nameAndFullName + case None => + val intendedName = calcMethodName(func) + val fullNamePrefix = + s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" + var name = intendedName + var fullName = "" + var isUnique = false + var i = 1 + while !isUnique do + fullName = s"$fullNamePrefix$name" + if functionFullNames.contains(fullName) then + name = s"$intendedName$i" + i += 1 + else + isUnique = true + functionFullNames.add(fullName) + functionNodeToNameAndFullName(func) = (name, fullName) + (name, fullName) + + protected def stripQuotes(str: String): String = str + .stripPrefix("\"") + .stripSuffix("\"") + .stripPrefix("'") + .stripSuffix("'") + .stripPrefix("`") + .stripSuffix("`") + + /** In JS it is possible to create anonymous classes. We have to handle this here. + */ + private def calcTypeName(classNode: BabelNodeInfo): String = + if hasKey(classNode.json, "id") && !classNode.json("id").isNull then + code(classNode.json("id")) + else "_anon_cdecl" + + protected def calcTypeNameAndFullName( + classNode: BabelNodeInfo, + preCalculatedName: Option[String] = None + ): (String, String) = + val name = preCalculatedName.getOrElse(calcTypeName(classNode)) + val fullNamePrefix = s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" + val intendedFullName = s"$fullNamePrefix$name" + val postfix = typeFullNameToPostfix.getOrElse(intendedFullName, 0) + val resultingFullName = + if postfix == 0 then intendedFullName + else s"$intendedFullName$postfix" + typeFullNameToPostfix.put(intendedFullName, postfix + 1) + (name, resultingFullName) + + protected def createVariableReferenceLinks(): Unit = + val resolvedReferenceIt = scope.resolve(createMethodLocalForUnresolvedReference) + val capturedLocals = mutable.HashMap.empty[String, NewNode] + + resolvedReferenceIt.foreach { case ResolvedReference(variableNodeId, origin) => + var currentScope = origin.stack + var 
currentReference = origin.referenceNode + var nextReference: NewNode = null + + var done = false + while !done do + val localOrCapturedLocalNodeOption = + if currentScope.get.nameToVariableNode.contains(origin.variableName) then + done = true + Option(variableNodeId) + else + currentScope.flatMap { + case methodScope: MethodScopeElement + if methodScope.scopeNode.isInstanceOf[NewTypeDecl] || methodScope.scopeNode + .isInstanceOf[NewNamespaceBlock] => + currentScope = + Option(Scope.getEnclosingMethodScopeElement(currentScope)) + None + case methodScope: MethodScopeElement => + // We have reached a MethodScope and still did not find a local variable to link to. + // For all non local references the CPG format does not allow us to link + // directly. Instead we need to create a fake local variable in method + // scope and link to this local which itself carries the information + // that it is a captured variable. This needs to be done for each + // method scope until we reach the originating scope. + val closureBindingIdProperty = + s"${methodScope.methodFullName}:${origin.variableName}" + capturedLocals.updateWith(closureBindingIdProperty) { + case None => + val methodScopeNode = methodScope.scopeNode + val localNode = + newLocalNode( + origin.variableName, + Defines.Any, + Option(closureBindingIdProperty) + ).order(0) + diffGraph.addEdge(methodScopeNode, localNode, EdgeTypes.AST) + val closureBindingNode = newClosureBindingNode( + closureBindingIdProperty, + origin.variableName, + EvaluationStrategies.BY_REFERENCE + ) + methodScope.capturingRefId.foreach(ref => + diffGraph.addEdge( + ref, + closureBindingNode, + EdgeTypes.CAPTURE + ) + ) + nextReference = closureBindingNode + Option(localNode) + case someLocalNode => + // When there is already a LOCAL representing the capturing, we do not + // need to process the surrounding scope element as this has already + // been processed. 
+ done = true + someLocalNode + } + case _: BlockScopeElement => None + } + + localOrCapturedLocalNodeOption.foreach { localOrCapturedLocalNode => + diffGraph.addEdge(currentReference, localOrCapturedLocalNode, EdgeTypes.REF) + currentReference = nextReference + } + currentScope = currentScope.get.surroundingScope + end while } - currentScope = currentScope.get.surroundingScope - } - } - } - - private def createMethodLocalForUnresolvedReference( - methodScopeNodeId: NewNode, - variableName: String - ): (NewNode, ScopeType) = { - val local = newLocalNode(variableName, Defines.Any).order(0) - diffGraph.addEdge(methodScopeNodeId, local, EdgeTypes.AST) - (local, MethodScope) - } - -} + end createVariableReferenceLinks + + private def createMethodLocalForUnresolvedReference( + methodScopeNodeId: NewNode, + variableName: String + ): (NewNode, ScopeType) = + val local = newLocalNode(variableName, Defines.Any).order(0) + diffGraph.addEdge(methodScopeNodeId, local, EdgeTypes.AST) + (local, MethodScope) +end AstCreatorHelper diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala index 7f19a316..e2b390a6 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala @@ -13,734 +13,849 @@ import ujson.Value import scala.util.Try -trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - private val DefaultsKey = "default" - private val ExportKeyword = "exports" - private val RequireKeyword = "require" - private val ImportKeyword = "import" - - private def hasName(json: Value): Boolean = hasKey(json, "id") && !json("id").isNull - - protected def codeForBabelNodeInfo(obj: 
BabelNodeInfo): Seq[String] = { - val codes = obj.node match { - case Identifier => Seq(obj.code) - case NumericLiteral => Seq(obj.code) - case StringLiteral => Seq(obj.code) - case AssignmentExpression => Seq(code(obj.json("left"))) - case ClassDeclaration => Seq(code(obj.json("id"))) - case TSTypeAliasDeclaration => Seq(code(obj.json("id"))) - case TSInterfaceDeclaration => Seq(code(obj.json("id"))) - case TSEnumDeclaration => Seq(code(obj.json("id"))) - case TSModuleDeclaration => Seq(code(obj.json("id"))) - case TSDeclareFunction if hasName(obj.json) => Seq(obj.json("id")("name").str) - case FunctionDeclaration if hasName(obj.json) => Seq(obj.json("id")("name").str) - case FunctionExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) - case ClassExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) - case VariableDeclarator if hasName(obj.json) => - createBabelNodeInfo(obj.json("id")).node match { - case ArrayPattern => obj.json("id")("elements").arr.toSeq.map(createBabelNodeInfo).map(_.code) - case _ => Seq(obj.json("id")("name").str) +trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + private val DefaultsKey = "default" + private val ExportKeyword = "exports" + private val RequireKeyword = "require" + private val ImportKeyword = "import" + + private def hasName(json: Value): Boolean = hasKey(json, "id") && !json("id").isNull + + protected def codeForBabelNodeInfo(obj: BabelNodeInfo): Seq[String] = + val codes = obj.node match + case Identifier => Seq(obj.code) + case NumericLiteral => Seq(obj.code) + case StringLiteral => Seq(obj.code) + case AssignmentExpression => Seq(code(obj.json("left"))) + case ClassDeclaration => Seq(code(obj.json("id"))) + case TSTypeAliasDeclaration => Seq(code(obj.json("id"))) + case TSInterfaceDeclaration => Seq(code(obj.json("id"))) + case TSEnumDeclaration => Seq(code(obj.json("id"))) + case TSModuleDeclaration => Seq(code(obj.json("id"))) + case 
TSDeclareFunction if hasName(obj.json) => Seq(obj.json("id")("name").str) + case FunctionDeclaration if hasName(obj.json) => Seq(obj.json("id")("name").str) + case FunctionExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) + case ClassExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) + case VariableDeclarator if hasName(obj.json) => + createBabelNodeInfo(obj.json("id")).node match + case ArrayPattern => + obj.json("id")("elements").arr.toSeq.map(createBabelNodeInfo).map(_.code) + case _ => Seq(obj.json("id")("name").str) + case VariableDeclarator => Seq(code(obj.json("id"))) + case MemberExpression => Seq(code(obj.json("property"))) + case ObjectExpression => + obj.json("properties").arr.toSeq.flatMap(d => + codeForBabelNodeInfo(createBabelNodeInfo(d)) + ) + case VariableDeclaration => + obj.json("declarations").arr.toSeq.flatMap(d => + codeForBabelNodeInfo(createBabelNodeInfo(d)) + ) + case _ => Seq.empty + codes.map(_.replace("...", "")) + end codeForBabelNodeInfo + + private def createExportCallAst( + name: String, + exportName: String, + declaration: BabelNodeInfo + ): Ast = + val exportCallAst = if name == DefaultsKey then + createIndexAccessCallAst( + identifierNode(declaration, exportName), + literalNode(declaration, s"\"$DefaultsKey\"", Option(Defines.String)), + declaration.lineNumber, + declaration.columnNumber + ) + else + createFieldAccessCallAst( + identifierNode(declaration, exportName), + createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), + declaration.lineNumber, + declaration.columnNumber + ) + exportCallAst + end createExportCallAst + + private def createExportAssignmentCallAst( + name: String, + exportCallAst: Ast, + declaration: BabelNodeInfo, + from: Option[String] + ): Ast = + from match + case Some(value) => + val call = createFieldAccessCallAst( + identifierNode(declaration, value, Seq.empty), + createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), 
+ declaration.lineNumber, + declaration.columnNumber + ) + createAssignmentCallAst( + exportCallAst, + call, + s"${codeOf(exportCallAst.nodes.head)} = ${codeOf(call.nodes.head)}", + declaration.lineNumber, + declaration.columnNumber + ) + case None => + createAssignmentCallAst( + exportCallAst, + Ast(identifierNode(declaration, name)), + s"${codeOf(exportCallAst.nodes.head)} = $name", + declaration.lineNumber, + declaration.columnNumber + ) + + private def extractDeclarationsFromExportDecl( + declaration: BabelNodeInfo, + key: String + ): Option[(Ast, Seq[String])] = + safeObj(declaration.json, key) + .map { d => + val nodeInfo = createBabelNodeInfo(d) + val ast = astForNodeWithFunctionReferenceAndCall(d) + val defaultName = codeForNodes(ast.nodes.toSeq) + val codes = codeForBabelNodeInfo(nodeInfo) + val names = if codes.isEmpty then defaultName.toSeq else codes + (ast, names) + } + + private def extractExportFromNameFromExportDecl(declaration: BabelNodeInfo): String = + safeObj(declaration.json, "source") + .map(d => s"_${stripQuotes(code(d))}") + .getOrElse(ExportKeyword) + + private def cleanImportName(name: String): String = if name.contains("/") then + val stripped = name.stripSuffix("/") + stripped.substring(stripped.lastIndexOf("/") + 1) + else name + + private def createAstForFrom(fromName: String, declaration: BabelNodeInfo): Ast = + if fromName == ExportKeyword then + Ast() + else + val strippedCode = cleanImportName(fromName).stripPrefix("_") + val id = identifierNode(declaration, s"_$strippedCode") + val localNode = newLocalNode(id.code, Defines.Any).order(0) + scope.addVariable(id.code, localNode, BlockScope) + scope.addVariableReference(id.code, id) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + val sourceCallArgNode = + literalNode(declaration, s"\"${fromName.stripPrefix("_")}\"", None) + val sourceCall = + callNode( + declaration, + s"$RequireKeyword(${sourceCallArgNode.code})", + RequireKeyword, + 
DispatchTypes.STATIC_DISPATCH + ) + val sourceAst = + callAst(sourceCall, List(Ast(sourceCallArgNode))) + val assignmentCallAst = createAssignmentCallAst( + Ast(id), + sourceAst, + s"var ${codeOf(id)} = ${codeOf(sourceAst.nodes.head)}", + declaration.lineNumber, + declaration.columnNumber + ) + assignmentCallAst + + protected def astsForDecorators(elem: BabelNodeInfo): Seq[Ast] = + if hasKey(elem.json, "decorators") && !elem.json("decorators").isNull then + elem.json("decorators").arr.toList.map(d => astForDecorator(createBabelNodeInfo(d))) + else Seq.empty + + private def namesForDecoratorExpression(code: String): (String, String) = + val dotLastIndex = code.lastIndexOf(".") + if dotLastIndex != -1 then + (code.substring(dotLastIndex + 1), code) + else + (code, code) + + private def astForDecorator(decorator: BabelNodeInfo): Ast = + val exprNode = createBabelNodeInfo(decorator.json("expression")) + exprNode.node match + case Identifier | MemberExpression => + val (name, fullName) = namesForDecoratorExpression(code(exprNode.json)) + annotationAst(annotationNode(decorator, decorator.code, name, fullName), List.empty) + case CallExpression => + val (name, fullName) = namesForDecoratorExpression(code(exprNode.json("callee"))) + val node = annotationNode(decorator, decorator.code, name, fullName) + val assignmentAsts = exprNode.json("arguments").arr.toList.map { arg => + createBabelNodeInfo(arg).node match + case AssignmentExpression => + annotationAssignmentAst( + code(arg("left")), + code(arg), + astForNodeWithFunctionReference(arg("right")) + ) + case _ => + annotationAssignmentAst( + "value", + code(arg), + astForNodeWithFunctionReference(arg) + ) + } + annotationAst(node, assignmentAsts) + case _ => Ast() + end match + end astForDecorator + + protected def astForExportNamedDeclaration(declaration: BabelNodeInfo): Ast = + val specifiers = declaration + .json("specifiers") + .arr + .toList + .map { spec => + if createBabelNodeInfo(spec).node == 
ExportNamespaceSpecifier then + val exported = createBabelNodeInfo(spec("exported")) + (None, Option(exported)) + else + val exported = createBabelNodeInfo(spec("exported")) + val local = if hasKey(spec, "local") then + createBabelNodeInfo(spec("local")) + else + exported + (Option(local), Option(exported)) + } + + val exportName = extractExportFromNameFromExportDecl(declaration) + val fromAst = createAstForFrom(exportName, declaration) + val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") + val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode( + name, + exportName.stripPrefix("_"), + RequireKeyword + )) + val exportCallAst = createExportCallAst(name, exportName, declaration) + createExportAssignmentCallAst(name, exportCallAst, declaration, None) + } } - case VariableDeclarator => Seq(code(obj.json("id"))) - case MemberExpression => Seq(code(obj.json("property"))) - case ObjectExpression => - obj.json("properties").arr.toSeq.flatMap(d => codeForBabelNodeInfo(createBabelNodeInfo(d))) - case VariableDeclaration => - obj.json("declarations").arr.toSeq.flatMap(d => codeForBabelNodeInfo(createBabelNodeInfo(d))) - case _ => Seq.empty - } - codes.map(_.replace("...", "")) - } - - private def createExportCallAst(name: String, exportName: String, declaration: BabelNodeInfo): Ast = { - val exportCallAst = if (name == DefaultsKey) { - createIndexAccessCallAst( - identifierNode(declaration, exportName), - literalNode(declaration, s"\"$DefaultsKey\"", Option(Defines.String)), - declaration.lineNumber, - declaration.columnNumber - ) - } else { - createFieldAccessCallAst( - identifierNode(declaration, exportName), - createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), - declaration.lineNumber, - declaration.columnNumber - ) - } - exportCallAst - } - - private def createExportAssignmentCallAst( - 
name: String, - exportCallAst: Ast, - declaration: BabelNodeInfo, - from: Option[String] - ): Ast = { - from match { - case Some(value) => - val call = createFieldAccessCallAst( - identifierNode(declaration, value, Seq.empty), - createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), - declaration.lineNumber, - declaration.columnNumber - ) - createAssignmentCallAst( - exportCallAst, - call, - s"${codeOf(exportCallAst.nodes.head)} = ${codeOf(call.nodes.head)}", - declaration.lineNumber, - declaration.columnNumber - ) - case None => - createAssignmentCallAst( - exportCallAst, - Ast(identifierNode(declaration, name)), - s"${codeOf(exportCallAst.nodes.head)} = $name", - declaration.lineNumber, - declaration.columnNumber - ) - } - - } - - private def extractDeclarationsFromExportDecl(declaration: BabelNodeInfo, key: String): Option[(Ast, Seq[String])] = - safeObj(declaration.json, key) - .map { d => - val nodeInfo = createBabelNodeInfo(d) - val ast = astForNodeWithFunctionReferenceAndCall(d) - val defaultName = codeForNodes(ast.nodes.toSeq) - val codes = codeForBabelNodeInfo(nodeInfo) - val names = if (codes.isEmpty) defaultName.toSeq else codes - (ast, names) - } - - private def extractExportFromNameFromExportDecl(declaration: BabelNodeInfo): String = - safeObj(declaration.json, "source") - .map(d => s"_${stripQuotes(code(d))}") - .getOrElse(ExportKeyword) - - private def cleanImportName(name: String): String = if (name.contains("/")) { - val stripped = name.stripSuffix("/") - stripped.substring(stripped.lastIndexOf("/") + 1) - } else name - - private def createAstForFrom(fromName: String, declaration: BabelNodeInfo): Ast = { - if (fromName == ExportKeyword) { - Ast() - } else { - val strippedCode = cleanImportName(fromName).stripPrefix("_") - val id = identifierNode(declaration, s"_$strippedCode") - val localNode = newLocalNode(id.code, Defines.Any).order(0) - scope.addVariable(id.code, localNode, BlockScope) - 
scope.addVariableReference(id.code, id) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - val sourceCallArgNode = literalNode(declaration, s"\"${fromName.stripPrefix("_")}\"", None) - val sourceCall = - callNode( - declaration, - s"$RequireKeyword(${sourceCallArgNode.code})", - RequireKeyword, - DispatchTypes.STATIC_DISPATCH - ) - val sourceAst = - callAst(sourceCall, List(Ast(sourceCallArgNode))) - val assignmentCallAst = createAssignmentCallAst( - Ast(id), - sourceAst, - s"var ${codeOf(id)} = ${codeOf(sourceAst.nodes.head)}", - declaration.lineNumber, - declaration.columnNumber - ) - assignmentCallAst - } - } - - protected def astsForDecorators(elem: BabelNodeInfo): Seq[Ast] = { - if (hasKey(elem.json, "decorators") && !elem.json("decorators").isNull) { - elem.json("decorators").arr.toList.map(d => astForDecorator(createBabelNodeInfo(d))) - } else Seq.empty - } - - private def namesForDecoratorExpression(code: String): (String, String) = { - val dotLastIndex = code.lastIndexOf(".") - if (dotLastIndex != -1) { - (code.substring(dotLastIndex + 1), code) - } else { - (code, code) - } - } - - private def astForDecorator(decorator: BabelNodeInfo): Ast = { - val exprNode = createBabelNodeInfo(decorator.json("expression")) - exprNode.node match { - case Identifier | MemberExpression => - val (name, fullName) = namesForDecoratorExpression(code(exprNode.json)) - annotationAst(annotationNode(decorator, decorator.code, name, fullName), List.empty) - case CallExpression => - val (name, fullName) = namesForDecoratorExpression(code(exprNode.json("callee"))) - val node = annotationNode(decorator, decorator.code, name, fullName) - val assignmentAsts = exprNode.json("arguments").arr.toList.map { arg => - createBabelNodeInfo(arg).node match { - case AssignmentExpression => - annotationAssignmentAst(code(arg("left")), code(arg), astForNodeWithFunctionReference(arg("right"))) - case _ => - annotationAssignmentAst("value", code(arg), 
astForNodeWithFunctionReference(arg)) - } + + val specifierAsts = specifiers.map { + case (Some(name), Some(alias)) => + val strippedCode = cleanImportName(exportName).stripPrefix("_") + val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode( + alias.code, + exportName.stripPrefix("_"), + RequireKeyword + )) + createExportAssignmentCallAst( + name.code, + exportCallAst, + declaration, + Option(s"_$strippedCode") + ) + else + createExportAssignmentCallAst(name.code, exportCallAst, declaration, None) + case (None, Some(alias)) => + diffGraph.addNode(newDependencyNode( + alias.code, + exportName.stripPrefix("_"), + RequireKeyword + )) + val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) + createExportAssignmentCallAst(exportName, exportCallAst, declaration, None) + case _ => Ast() + } + + val asts = fromAst +: (specifierAsts ++ declAsts) + setArgumentIndices(asts) + blockAst(createBlockNode(declaration), asts) + end astForExportNamedDeclaration + + protected def astForExportAssignment(assignment: BabelNodeInfo): Ast = + val expressionAstWithNames = extractDeclarationsFromExportDecl(assignment, "expression") + val declAsts = expressionAstWithNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + val exportCallAst = createExportCallAst(name, ExportKeyword, assignment) + createExportAssignmentCallAst(name, exportCallAst, assignment, None) + } } - annotationAst(node, assignmentAsts) - case _ => Ast() - } - } - - protected def astForExportNamedDeclaration(declaration: BabelNodeInfo): Ast = { - val specifiers = declaration - .json("specifiers") - .arr - .toList - .map { spec => - if (createBabelNodeInfo(spec).node == ExportNamespaceSpecifier) { - val exported = createBabelNodeInfo(spec("exported")) - (None, Option(exported)) - } else { - val exported = createBabelNodeInfo(spec("exported")) - val local = if (hasKey(spec, 
"local")) { - createBabelNodeInfo(spec("local")) - } else { - exported - } - (Option(local), Option(exported)) + + setArgumentIndices(declAsts) + blockAst(createBlockNode(assignment), declAsts) + + protected def astForExportDefaultDeclaration(declaration: BabelNodeInfo): Ast = + val exportName = extractExportFromNameFromExportDecl(declaration) + val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") + val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + val exportCallAst = createExportCallAst(DefaultsKey, exportName, declaration) + createExportAssignmentCallAst(name, exportCallAst, declaration, None) + } } - } - - val exportName = extractExportFromNameFromExportDecl(declaration) - val fromAst = createAstForFrom(exportName, declaration) - val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") - val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - if (exportName != ExportKeyword) - diffGraph.addNode(newDependencyNode(name, exportName.stripPrefix("_"), RequireKeyword)) - val exportCallAst = createExportCallAst(name, exportName, declaration) - createExportAssignmentCallAst(name, exportCallAst, declaration, None) - } - } - - val specifierAsts = specifiers.map { - case (Some(name), Some(alias)) => - val strippedCode = cleanImportName(exportName).stripPrefix("_") - val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) - if (exportName != ExportKeyword) { - diffGraph.addNode(newDependencyNode(alias.code, exportName.stripPrefix("_"), RequireKeyword)) - createExportAssignmentCallAst(name.code, exportCallAst, declaration, Option(s"_$strippedCode")) - } else { - createExportAssignmentCallAst(name.code, exportCallAst, declaration, None) + setArgumentIndices(declAsts) + blockAst(createBlockNode(declaration), declAsts) + + protected def astForExportAllDeclaration(declaration: BabelNodeInfo): Ast = + val 
exportName = extractExportFromNameFromExportDecl(declaration) + val depGroupId = stripQuotes(code(declaration.json("source"))) + val name = cleanImportName(depGroupId) + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode(name, depGroupId, RequireKeyword)) + + val fromCallAst = createAstForFrom(exportName, declaration) + val exportCallAst = createExportCallAst(name, ExportKeyword, declaration) + val assignmentCallAst = + createExportAssignmentCallAst(s"_$name", exportCallAst, declaration, None) + + val childrenAsts = List(fromCallAst, assignmentCallAst) + setArgumentIndices(childrenAsts) + blockAst(createBlockNode(declaration), childrenAsts) + + protected def astForVariableDeclaration(declaration: BabelNodeInfo): Ast = + val kind = declaration.json("kind").str + val scopeType = if kind == "let" then + BlockScope + else + MethodScope + val declAsts = declaration.json("declarations").arr.toList.map(astForVariableDeclarator( + _, + scopeType, + kind + )) + declAsts match + case Nil => Ast() + case head :: Nil => head + case _ => blockAst(createBlockNode(declaration), declAsts) + + private def handleRequireCallForDependencies( + declarator: BabelNodeInfo, + lhs: Value, + rhs: Value, + call: Option[NewCall] + ): Unit = + val rhsCode = code(rhs) + val groupId = + rhsCode.substring(rhsCode.indexOf(s"$RequireKeyword(") + 9, rhsCode.indexOf(")") - 1) + val nodeInfo = createBabelNodeInfo(lhs) + val names = nodeInfo.node match + case ArrayPattern => nodeInfo.json("elements").arr.toList.map(code) + case ObjectPattern => nodeInfo.json("properties").arr.toList.map(code) + case _ => List(code(lhs)) + names.foreach { name => + val _dependencyNode = newDependencyNode(name, groupId, RequireKeyword) + diffGraph.addNode(_dependencyNode) + val importNode = createImportNodeAndAttachToCall(declarator, groupId, name, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) } - case (None, Some(alias)) => - 
diffGraph.addNode(newDependencyNode(alias.code, exportName.stripPrefix("_"), RequireKeyword)) - val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) - createExportAssignmentCallAst(exportName, exportCallAst, declaration, None) - case _ => Ast() - } - - val asts = fromAst +: (specifierAsts ++ declAsts) - setArgumentIndices(asts) - blockAst(createBlockNode(declaration), asts) - } - - protected def astForExportAssignment(assignment: BabelNodeInfo): Ast = { - val expressionAstWithNames = extractDeclarationsFromExportDecl(assignment, "expression") - val declAsts = expressionAstWithNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - val exportCallAst = createExportCallAst(name, ExportKeyword, assignment) - createExportAssignmentCallAst(name, exportCallAst, assignment, None) - } - } - - setArgumentIndices(declAsts) - blockAst(createBlockNode(assignment), declAsts) - } - - protected def astForExportDefaultDeclaration(declaration: BabelNodeInfo): Ast = { - val exportName = extractExportFromNameFromExportDecl(declaration) - val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") - val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - val exportCallAst = createExportCallAst(DefaultsKey, exportName, declaration) - createExportAssignmentCallAst(name, exportCallAst, declaration, None) - } - } - setArgumentIndices(declAsts) - blockAst(createBlockNode(declaration), declAsts) - } - - protected def astForExportAllDeclaration(declaration: BabelNodeInfo): Ast = { - val exportName = extractExportFromNameFromExportDecl(declaration) - val depGroupId = stripQuotes(code(declaration.json("source"))) - val name = cleanImportName(depGroupId) - if (exportName != ExportKeyword) { - diffGraph.addNode(newDependencyNode(name, depGroupId, RequireKeyword)) - } - - val fromCallAst = createAstForFrom(exportName, declaration) - val exportCallAst = createExportCallAst(name, 
ExportKeyword, declaration) - val assignmentCallAst = createExportAssignmentCallAst(s"_$name", exportCallAst, declaration, None) - - val childrenAsts = List(fromCallAst, assignmentCallAst) - setArgumentIndices(childrenAsts) - blockAst(createBlockNode(declaration), childrenAsts) - } - - protected def astForVariableDeclaration(declaration: BabelNodeInfo): Ast = { - val kind = declaration.json("kind").str - val scopeType = if (kind == "let") { - BlockScope - } else { - MethodScope - } - val declAsts = declaration.json("declarations").arr.toList.map(astForVariableDeclarator(_, scopeType, kind)) - declAsts match { - case Nil => Ast() - case head :: Nil => head - case _ => blockAst(createBlockNode(declaration), declAsts) - } - } - - private def handleRequireCallForDependencies( - declarator: BabelNodeInfo, - lhs: Value, - rhs: Value, - call: Option[NewCall] - ): Unit = { - val rhsCode = code(rhs) - val groupId = rhsCode.substring(rhsCode.indexOf(s"$RequireKeyword(") + 9, rhsCode.indexOf(")") - 1) - val nodeInfo = createBabelNodeInfo(lhs) - val names = nodeInfo.node match { - case ArrayPattern => nodeInfo.json("elements").arr.toList.map(code) - case ObjectPattern => nodeInfo.json("properties").arr.toList.map(code) - case _ => List(code(lhs)) - } - names.foreach { name => - val _dependencyNode = newDependencyNode(name, groupId, RequireKeyword) - diffGraph.addNode(_dependencyNode) - val importNode = createImportNodeAndAttachToCall(declarator, groupId, name, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - } - } - - private def astForVariableDeclarator(declarator: Value, scopeType: ScopeType, kind: String): Ast = { - val idNodeInfo = createBabelNodeInfo(declarator("id")) - val declNodeInfo = createBabelNodeInfo(declarator) - val initNodeInfo = Try(createBabelNodeInfo(declarator("init"))).toOption - val declaratorCode = s"$kind ${code(declarator)}" - val typeFullName = typeFor(declNodeInfo) - - val idName = idNodeInfo.node match { - case Identifier 
=> idNodeInfo.json("name").str - case _ => idNodeInfo.code - } - val localNode = newLocalNode(idName, typeFullName).order(0) - scope.addVariable(idName, localNode, scopeType) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - if (initNodeInfo.isEmpty) { - Ast() - } else { - val sourceAst = initNodeInfo.get match { - case requireCall if requireCall.code.startsWith(s"$RequireKeyword(") => - val call = astForNodeWithFunctionReference(requireCall.json) - handleRequireCallForDependencies( - createBabelNodeInfo(declarator), - idNodeInfo.json, - initNodeInfo.get.json, - call.root.map(_.asInstanceOf[NewCall]) - ) - call - case initExpr => - astForNodeWithFunctionReference(initExpr.json) - } - val nodeInfo = createBabelNodeInfo(idNodeInfo.json) - nodeInfo.node match { - case ObjectPattern | ArrayPattern => - astForDeconstruction(nodeInfo, sourceAst, declaratorCode) - case _ => - val destAst = idNodeInfo.node match { - case Identifier => astForIdentifier(idNodeInfo, Option(typeFullName)) - case _ => astForNode(idNodeInfo.json) - } - - val assignmentCallAst = + end handleRequireCallForDependencies + + private def astForVariableDeclarator( + declarator: Value, + scopeType: ScopeType, + kind: String + ): Ast = + val idNodeInfo = createBabelNodeInfo(declarator("id")) + val declNodeInfo = createBabelNodeInfo(declarator) + val initNodeInfo = Try(createBabelNodeInfo(declarator("init"))).toOption + val declaratorCode = s"$kind ${code(declarator)}" + val typeFullName = typeFor(declNodeInfo) + + val idName = idNodeInfo.node match + case Identifier => idNodeInfo.json("name").str + case _ => idNodeInfo.code + val localNode = newLocalNode(idName, typeFullName).order(0) + scope.addVariable(idName, localNode, scopeType) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + if initNodeInfo.isEmpty then + Ast() + else + val sourceAst = initNodeInfo.get match + case requireCall if requireCall.code.startsWith(s"$RequireKeyword(") => + val call = 
astForNodeWithFunctionReference(requireCall.json) + handleRequireCallForDependencies( + createBabelNodeInfo(declarator), + idNodeInfo.json, + initNodeInfo.get.json, + call.root.map(_.asInstanceOf[NewCall]) + ) + call + case initExpr => + astForNodeWithFunctionReference(initExpr.json) + val nodeInfo = createBabelNodeInfo(idNodeInfo.json) + nodeInfo.node match + case ObjectPattern | ArrayPattern => + astForDeconstruction(nodeInfo, sourceAst, declaratorCode) + case _ => + val destAst = idNodeInfo.node match + case Identifier => astForIdentifier(idNodeInfo, Option(typeFullName)) + case _ => astForNode(idNodeInfo.json) + + val assignmentCallAst = + createAssignmentCallAst( + destAst, + sourceAst, + declaratorCode, + line = line(declarator), + column = column(declarator) + ) + assignmentCallAst + end if + end astForVariableDeclarator + + protected def astForTSImportEqualsDeclaration(impDecl: BabelNodeInfo): Ast = + val name = impDecl.json("id")("name").str + val referenceNode = createBabelNodeInfo(impDecl.json("moduleReference")) + val referenceName = referenceNode.node match + case TSExternalModuleReference => referenceNode.json("expression")("value").str + case _ => referenceNode.code + val _dependencyNode = newDependencyNode(name, referenceName, ImportKeyword) + diffGraph.addNode(_dependencyNode) + val assignment = + astForRequireCallFromImport(name, None, referenceName, isImportN = false, impDecl) + val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = + createImportNodeAndAttachToCall(impDecl, referenceName, name, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + assignment + + private def astForRequireCallFromImport( + name: String, + alias: Option[String], + from: String, + isImportN: Boolean, + nodeInfo: BabelNodeInfo + ): Ast = + val destName = alias.getOrElse(name) + val destNode = identifierNode(nodeInfo, destName) + val localNode = newLocalNode(destName, Defines.Any).order(0) + 
scope.addVariable(destName, localNode, BlockScope) + scope.addVariableReference(destName, destNode) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + val destAst = Ast(destNode) + val sourceCallArgNode = literalNode(nodeInfo, s"\"$from\"", None) + val sourceCall = + callNode( + nodeInfo, + s"$RequireKeyword(${sourceCallArgNode.code})", + RequireKeyword, + DispatchTypes.DYNAMIC_DISPATCH + ) + + val receiverNode = identifierNode(nodeInfo, RequireKeyword) + val thisNode = + identifierNode(nodeInfo, "this").dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariableReference(thisNode.name, thisNode) + val cAst = callAst( + sourceCall, + List(Ast(sourceCallArgNode)), + receiver = Option(Ast(receiverNode)), + base = Option(Ast(thisNode)) + ) + val sourceAst = if isImportN then + val fieldAccessCall = createFieldAccessCallAst( + cAst, + createFieldIdentifierNode(name, nodeInfo.lineNumber, nodeInfo.columnNumber), + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + fieldAccessCall + else + cAst + val assigmentCallAst = createAssignmentCallAst( destAst, sourceAst, - declaratorCode, - line = line(declarator), - column = column(declarator) + s"var ${codeOf(destAst.nodes.head)} = ${codeOf(sourceAst.nodes.head)}", + nodeInfo.lineNumber, + nodeInfo.columnNumber ) - assignmentCallAst - } - } - } - - protected def astForTSImportEqualsDeclaration(impDecl: BabelNodeInfo): Ast = { - val name = impDecl.json("id")("name").str - val referenceNode = createBabelNodeInfo(impDecl.json("moduleReference")) - val referenceName = referenceNode.node match { - case TSExternalModuleReference => referenceNode.json("expression")("value").str - case _ => referenceNode.code - } - val _dependencyNode = newDependencyNode(name, referenceName, ImportKeyword) - diffGraph.addNode(_dependencyNode) - val assignment = astForRequireCallFromImport(name, None, referenceName, isImportN = false, impDecl) - val call = assignment.nodes.collectFirst { case x: NewCall if x.name == 
"require" => x } - val importNode = - createImportNodeAndAttachToCall(impDecl, referenceName, name, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - assignment - } - - private def astForRequireCallFromImport( - name: String, - alias: Option[String], - from: String, - isImportN: Boolean, - nodeInfo: BabelNodeInfo - ): Ast = { - val destName = alias.getOrElse(name) - val destNode = identifierNode(nodeInfo, destName) - val localNode = newLocalNode(destName, Defines.Any).order(0) - scope.addVariable(destName, localNode, BlockScope) - scope.addVariableReference(destName, destNode) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - val destAst = Ast(destNode) - val sourceCallArgNode = literalNode(nodeInfo, s"\"$from\"", None) - val sourceCall = - callNode(nodeInfo, s"$RequireKeyword(${sourceCallArgNode.code})", RequireKeyword, DispatchTypes.DYNAMIC_DISPATCH) - - val receiverNode = identifierNode(nodeInfo, RequireKeyword) - val thisNode = identifierNode(nodeInfo, "this").dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariableReference(thisNode.name, thisNode) - val cAst = callAst( - sourceCall, - List(Ast(sourceCallArgNode)), - receiver = Option(Ast(receiverNode)), - base = Option(Ast(thisNode)) - ) - val sourceAst = if (isImportN) { - val fieldAccessCall = createFieldAccessCallAst( - cAst, - createFieldIdentifierNode(name, nodeInfo.lineNumber, nodeInfo.columnNumber), - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - fieldAccessCall - } else { - cAst - } - val assigmentCallAst = - createAssignmentCallAst( - destAst, - sourceAst, - s"var ${codeOf(destAst.nodes.head)} = ${codeOf(sourceAst.nodes.head)}", - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - assigmentCallAst - } - - protected def astForImportDeclaration(impDecl: BabelNodeInfo): Ast = { - val source = impDecl.json("source")("value").str - val specifiers = impDecl.json("specifiers").arr - - if (specifiers.isEmpty) { - val _dependencyNode = 
newDependencyNode(source, source, ImportKeyword) - diffGraph.addNode(_dependencyNode) - val assignment = astForRequireCallFromImport(source, None, source, isImportN = false, impDecl) - val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } - val importNode = createImportNodeAndAttachToCall(impDecl, source, source, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - assignment - } else { - val specs = impDecl.json("specifiers").arr.toList - val requireCalls = specs.map { importSpecifier => - val isImportN = createBabelNodeInfo(importSpecifier).node match { - case ImportSpecifier => true - case _ => false - } - val name = importSpecifier("local")("name").str - val (alias, reqName) = reqNameFromImportSpecifier(importSpecifier, name) - val assignment = astForRequireCallFromImport(reqName, alias, source, isImportN = isImportN, impDecl) - val importedName = importSpecifier("local")("name").str - val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } - val importNode = createImportNodeAndAttachToCall(impDecl, s"$source:$reqName", importedName, call) - val _dependencyNode = newDependencyNode(importedName, source, ImportKeyword) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - diffGraph.addNode(_dependencyNode) - assignment - } - if (requireCalls.isEmpty) { - Ast() - } else if (requireCalls.sizeIs == 1) { - requireCalls.head - } else { - blockAst(createBlockNode(impDecl), requireCalls) - } - } - } - - private def reqNameFromImportSpecifier(importSpecifier: Value, name: String) = { - if (hasKey(importSpecifier, "imported")) { - (Option(name), importSpecifier("imported")("name").str) - } else { - (None, name) - } - } - - private def createImportNodeAndAttachToCall( - impDecl: BabelNodeInfo, - importedEntity: String, - importedAs: String, - call: Option[NewCall] - ): NewImport = { - createImportNodeAndAttachToCall(impDecl.code.stripSuffix(";"), importedEntity, 
importedAs, call) - } - - private def createImportNodeAndAttachToCall( - code: String, - importedEntity: String, - importedAs: String, - call: Option[NewCall] - ): NewImport = { - val impNode = NewImport() - .code(code) - .importedEntity(importedEntity) - .importedAs(importedAs) - .lineNumber(call.flatMap(_.lineNumber)) - .columnNumber(call.flatMap(_.lineNumber)) - call.foreach { c => diffGraph.addEdge(c, impNode, EdgeTypes.IS_CALL_FOR_IMPORT) } - impNode - } - - private def convertDestructingObjectElement(element: BabelNodeInfo, key: BabelNodeInfo, localTmpName: String): Ast = { - val valueAst = astForNode(element.json) - - val localNode = newLocalNode(element.code, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(element.code, localNode, MethodScope) - - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - val accessAst = createFieldAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - createAssignmentCallAst( - valueAst, - accessAst, - s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", - element.lineNumber, - element.columnNumber - ) - } - - private def convertDestructingArrayElement(element: BabelNodeInfo, index: Int, localTmpName: String): Ast = { - val valueAst = astForNode(element.json) - - val localNode = newLocalNode(element.code, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(element.code, localNode, MethodScope) - - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - val accessAst = createIndexAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - createAssignmentCallAst( - valueAst, - accessAst, - s"${codeOf(valueAst.nodes.head)} = 
${codeOf(accessAst.nodes.head)}", - element.lineNumber, - element.columnNumber - ) - } - - private def convertDestructingArrayElementWithDefault( - element: BabelNodeInfo, - index: Int, - localTmpName: String - ): Ast = { - val rhsElement = element.json("right") - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsElement = element.json("left") - val nodeInfo = createBabelNodeInfo(lhsElement) - val lhsAst = nodeInfo.node match { - case ObjectPattern | ArrayPattern => - val sourceAst = astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) - astForDeconstruction(nodeInfo, sourceAst, element.code) - case _ => astForNodeWithFunctionReference(lhsElement) - } - - val testAst = { - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - val accessAst = - createIndexAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) - createEqualsCallAst(accessAst, Ast(voidCallNode), element.lineNumber, element.columnNumber) - } - val falseAst = { - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - createIndexAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - } - val ternaryNodeAst = - createTernaryCallAst(testAst, rhsAst, falseAst, element.lineNumber, element.columnNumber) - createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", - element.lineNumber, - element.columnNumber - ) - } - - private def convertDestructingObjectElementWithDefault( - element: BabelNodeInfo, - key: BabelNodeInfo, - localTmpName: String - ): Ast = { - val rhsElement = element.json("right") - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsElement = element.json("left") 
- val nodeInfo = createBabelNodeInfo(lhsElement) - val lhsAst = nodeInfo.node match { - case ObjectPattern | ArrayPattern => - val sourceAst = astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) - astForDeconstruction(nodeInfo, sourceAst, element.code) - case _ => astForNodeWithFunctionReference(lhsElement) - } - - val testAst = { - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - val accessAst = - createFieldAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) - createEqualsCallAst(accessAst, Ast(voidCallNode), element.lineNumber, element.columnNumber) - } - val falseAst = { - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - createFieldAccessCallAst(fieldAccessTmpNode, keyNode, element.lineNumber, element.columnNumber) - } - val ternaryNodeAst = - createTernaryCallAst(testAst, rhsAst, falseAst, element.lineNumber, element.columnNumber) - createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", - element.lineNumber, - element.columnNumber - ) - } - - private def createParamAst(pattern: BabelNodeInfo, keyName: String, sourceAst: Ast): Ast = { - val testAst = { - val lhsNode = identifierNode(pattern, keyName) - scope.addVariableReference(keyName, lhsNode) - val rhsNode = - callNode(pattern, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) - createEqualsCallAst(Ast(lhsNode), Ast(rhsNode), pattern.lineNumber, pattern.columnNumber) - } - - val falseNode = { - val initNode = identifierNode(pattern, keyName) - scope.addVariableReference(keyName, initNode) - initNode - } - createTernaryCallAst(testAst, sourceAst, Ast(falseNode), pattern.lineNumber, pattern.columnNumber) - } 
- - protected def astForDeconstruction( - pattern: BabelNodeInfo, - sourceAst: Ast, - code: String, - paramName: Option[String] = None - ): Ast = { - val localTmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - - val blockNode = createBlockNode(pattern, Option(code)) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val localNode = newLocalNode(localTmpName, Defines.Any).order(0) - val tmpNode = identifierNode(pattern, localTmpName) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(localTmpName, localNode, BlockScope) - scope.addVariableReference(localTmpName, tmpNode) - - val rhsAssignmentAst = paramName.map(createParamAst(pattern, _, sourceAst)).getOrElse(sourceAst) - val assignmentTmpCallAst = - createAssignmentCallAst( - Ast(tmpNode), - rhsAssignmentAst, - s"$localTmpName = ${codeOf(rhsAssignmentAst.nodes.head)}", - pattern.lineNumber, - pattern.columnNumber - ) - - val subTreeAsts = pattern.node match { - case ObjectPattern => - pattern.json("properties").arr.toList.map { element => - val nodeInfo = createBabelNodeInfo(element) - nodeInfo.node match { - case RestElement => - val arg1Ast = Ast(identifierNode(nodeInfo, localTmpName)) - astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) - case _ => - val nodeInfo = createBabelNodeInfo(element("value")) - nodeInfo.node match { - case Identifier => - convertDestructingObjectElement(nodeInfo, createBabelNodeInfo(element("key")), localTmpName) - case AssignmentPattern => - convertDestructingObjectElementWithDefault( - nodeInfo, - createBabelNodeInfo(element("key")), - localTmpName - ) - case _ => astForNodeWithFunctionReference(nodeInfo.json) - } - } - } - case ArrayPattern => - pattern.json("elements").arr.toList.zipWithIndex.map { - case (element, index) if !element.isNull => - val nodeInfo = createBabelNodeInfo(element) - nodeInfo.node match { - case RestElement => - val fieldAccessTmpNode = identifierNode(nodeInfo, 
localTmpName) - val keyNode = literalNode(nodeInfo, index.toString, Option(Defines.Number)) - val accessAst = - createIndexAccessCallAst(fieldAccessTmpNode, keyNode, nodeInfo.lineNumber, nodeInfo.columnNumber) - astForSpreadOrRestElement(nodeInfo, Option(accessAst)) - case Identifier => - convertDestructingArrayElement(nodeInfo, index, localTmpName) - case AssignmentPattern => - convertDestructingArrayElementWithDefault(nodeInfo, index, localTmpName) - case _ => astForNodeWithFunctionReference(nodeInfo.json) + assigmentCallAst + end astForRequireCallFromImport + + protected def astForImportDeclaration(impDecl: BabelNodeInfo): Ast = + val source = impDecl.json("source")("value").str + val specifiers = impDecl.json("specifiers").arr + + if specifiers.isEmpty then + val _dependencyNode = newDependencyNode(source, source, ImportKeyword) + diffGraph.addNode(_dependencyNode) + val assignment = + astForRequireCallFromImport(source, None, source, isImportN = false, impDecl) + val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = createImportNodeAndAttachToCall(impDecl, source, source, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + assignment + else + val specs = impDecl.json("specifiers").arr.toList + val requireCalls = specs.map { importSpecifier => + val isImportN = createBabelNodeInfo(importSpecifier).node match + case ImportSpecifier => true + case _ => false + val name = importSpecifier("local")("name").str + val (alias, reqName) = reqNameFromImportSpecifier(importSpecifier, name) + val assignment = astForRequireCallFromImport( + reqName, + alias, + source, + isImportN = isImportN, + impDecl + ) + val importedName = importSpecifier("local")("name").str + val call = + assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = createImportNodeAndAttachToCall( + impDecl, + s"$source:$reqName", + importedName, + call + ) + val _dependencyNode = 
newDependencyNode(importedName, source, ImportKeyword) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + diffGraph.addNode(_dependencyNode) + assignment } - case _ => Ast() - } - case _ => - List(convertDestructingObjectElement(pattern, pattern, localTmpName)) - } + if requireCalls.isEmpty then + Ast() + else if requireCalls.sizeIs == 1 then + requireCalls.head + else + blockAst(createBlockNode(impDecl), requireCalls) + end if + end astForImportDeclaration + + private def reqNameFromImportSpecifier(importSpecifier: Value, name: String) = + if hasKey(importSpecifier, "imported") then + (Option(name), importSpecifier("imported")("name").str) + else + (None, name) + + private def createImportNodeAndAttachToCall( + impDecl: BabelNodeInfo, + importedEntity: String, + importedAs: String, + call: Option[NewCall] + ): NewImport = + createImportNodeAndAttachToCall( + impDecl.code.stripSuffix(";"), + importedEntity, + importedAs, + call + ) + + private def createImportNodeAndAttachToCall( + code: String, + importedEntity: String, + importedAs: String, + call: Option[NewCall] + ): NewImport = + val impNode = NewImport() + .code(code) + .importedEntity(importedEntity) + .importedAs(importedAs) + .lineNumber(call.flatMap(_.lineNumber)) + .columnNumber(call.flatMap(_.lineNumber)) + call.foreach { c => diffGraph.addEdge(c, impNode, EdgeTypes.IS_CALL_FOR_IMPORT) } + impNode + + private def convertDestructingObjectElement( + element: BabelNodeInfo, + key: BabelNodeInfo, + localTmpName: String + ): Ast = + val valueAst = astForNode(element.json) + + val localNode = newLocalNode(element.code, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(element.code, localNode, MethodScope) + + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + val accessAst = createFieldAccessCallAst( + fieldAccessTmpNode, + 
keyNode, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + valueAst, + accessAst, + s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingObjectElement + + private def convertDestructingArrayElement( + element: BabelNodeInfo, + index: Int, + localTmpName: String + ): Ast = + val valueAst = astForNode(element.json) + + val localNode = newLocalNode(element.code, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(element.code, localNode, MethodScope) + + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + val accessAst = createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + valueAst, + accessAst, + s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingArrayElement + + private def convertDestructingArrayElementWithDefault( + element: BabelNodeInfo, + index: Int, + localTmpName: String + ): Ast = + val rhsElement = element.json("right") + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsElement = element.json("left") + val nodeInfo = createBabelNodeInfo(lhsElement) + val lhsAst = nodeInfo.node match + case ObjectPattern | ArrayPattern => + val sourceAst = + astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) + astForDeconstruction(nodeInfo, sourceAst, element.code) + case _ => astForNodeWithFunctionReference(lhsElement) + + val testAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + val accessAst = + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) 
+ val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) + createEqualsCallAst( + accessAst, + Ast(voidCallNode), + element.lineNumber, + element.columnNumber + ) + val falseAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + falseAst, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingArrayElementWithDefault + + private def convertDestructingObjectElementWithDefault( + element: BabelNodeInfo, + key: BabelNodeInfo, + localTmpName: String + ): Ast = + val rhsElement = element.json("right") + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsElement = element.json("left") + val nodeInfo = createBabelNodeInfo(lhsElement) + val lhsAst = nodeInfo.node match + case ObjectPattern | ArrayPattern => + val sourceAst = + astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) + astForDeconstruction(nodeInfo, sourceAst, element.code) + case _ => astForNodeWithFunctionReference(lhsElement) + + val testAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + val accessAst = + createFieldAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) + createEqualsCallAst( + accessAst, + Ast(voidCallNode), + element.lineNumber, + element.columnNumber + ) + val falseAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val 
keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + createFieldAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + falseAst, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingObjectElementWithDefault + + private def createParamAst(pattern: BabelNodeInfo, keyName: String, sourceAst: Ast): Ast = + val testAst = + val lhsNode = identifierNode(pattern, keyName) + scope.addVariableReference(keyName, lhsNode) + val rhsNode = + callNode(pattern, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) + createEqualsCallAst( + Ast(lhsNode), + Ast(rhsNode), + pattern.lineNumber, + pattern.columnNumber + ) + + val falseNode = + val initNode = identifierNode(pattern, keyName) + scope.addVariableReference(keyName, initNode) + initNode + createTernaryCallAst( + testAst, + sourceAst, + Ast(falseNode), + pattern.lineNumber, + pattern.columnNumber + ) + end createParamAst + + protected def astForDeconstruction( + pattern: BabelNodeInfo, + sourceAst: Ast, + code: String, + paramName: Option[String] = None + ): Ast = + val localTmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + + val blockNode = createBlockNode(pattern, Option(code)) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val localNode = newLocalNode(localTmpName, Defines.Any).order(0) + val tmpNode = identifierNode(pattern, localTmpName) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(localTmpName, localNode, BlockScope) + scope.addVariableReference(localTmpName, tmpNode) + + val rhsAssignmentAst = + paramName.map(createParamAst(pattern, _, sourceAst)).getOrElse(sourceAst) + val assignmentTmpCallAst = 
+ createAssignmentCallAst( + Ast(tmpNode), + rhsAssignmentAst, + s"$localTmpName = ${codeOf(rhsAssignmentAst.nodes.head)}", + pattern.lineNumber, + pattern.columnNumber + ) - val returnTmpNode = identifierNode(pattern, localTmpName) - scope.popScope() - localAstParentStack.pop() + val subTreeAsts = pattern.node match + case ObjectPattern => + pattern.json("properties").arr.toList.map { element => + val nodeInfo = createBabelNodeInfo(element) + nodeInfo.node match + case RestElement => + val arg1Ast = Ast(identifierNode(nodeInfo, localTmpName)) + astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) + case _ => + val nodeInfo = createBabelNodeInfo(element("value")) + nodeInfo.node match + case Identifier => + convertDestructingObjectElement( + nodeInfo, + createBabelNodeInfo(element("key")), + localTmpName + ) + case AssignmentPattern => + convertDestructingObjectElementWithDefault( + nodeInfo, + createBabelNodeInfo(element("key")), + localTmpName + ) + case _ => astForNodeWithFunctionReference(nodeInfo.json) + end match + } + case ArrayPattern => + pattern.json("elements").arr.toList.zipWithIndex.map { + case (element, index) if !element.isNull => + val nodeInfo = createBabelNodeInfo(element) + nodeInfo.node match + case RestElement => + val fieldAccessTmpNode = identifierNode(nodeInfo, localTmpName) + val keyNode = + literalNode(nodeInfo, index.toString, Option(Defines.Number)) + val accessAst = + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + astForSpreadOrRestElement(nodeInfo, Option(accessAst)) + case Identifier => + convertDestructingArrayElement(nodeInfo, index, localTmpName) + case AssignmentPattern => + convertDestructingArrayElementWithDefault( + nodeInfo, + index, + localTmpName + ) + case _ => astForNodeWithFunctionReference(nodeInfo.json) + end match + case _ => Ast() + } + case _ => + List(convertDestructingObjectElement(pattern, pattern, localTmpName)) - val blockChildren = 
assignmentTmpCallAst +: subTreeAsts :+ Ast(returnTmpNode) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - } + val returnTmpNode = identifierNode(pattern, localTmpName) + scope.popScope() + localAstParentStack.pop() -} + val blockChildren = assignmentTmpCallAst +: subTreeAsts :+ Ast(returnTmpNode) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForDeconstruction +end AstForDeclarationsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala index cfe1c107..cc55acd8 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -11,482 +11,575 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Opera import scala.util.Try -trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - protected def astForExpressionStatement(exprStmt: BabelNodeInfo): Ast = - astForNodeWithFunctionReference(exprStmt.json("expression")) - - private def createBuiltinStaticCall(callExpr: BabelNodeInfo, callee: BabelNodeInfo, fullName: String): Ast = { - val callName = callee.node match { - case MemberExpression => code(callee.json("property")) - case _ => callee.code - } - val callNode = - createStaticCallNode(callExpr.code, callName, fullName, callee.lineNumber, callee.columnNumber) - val argAsts = astForNodes(callExpr.json("arguments").arr.toList) - callAst(callNode, argAsts) - } - - private def handleCallNodeArgs( - callExpr: BabelNodeInfo, - receiverAst: Ast, - baseNode: NewNode, - callName: String - ): Ast = { - val args = astForNodes(callExpr.json("arguments").arr.toList) - val callNode_ = 
callNode(callExpr, callExpr.code, callName, DispatchTypes.DYNAMIC_DISPATCH) - // If the callee is a function itself, e.g. closure, then resolve this locally, if possible - callExpr.json.obj - .get("callee") - .map(createBabelNodeInfo) - .flatMap { - case callee if callee.node.isInstanceOf[FunctionLike] => functionNodeToNameAndFullName.get(callee) - case _ => None - } - .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } - callAst(callNode_, args, receiver = Option(receiverAst), base = Option(Ast(baseNode))) - } - - protected def astForCallExpression(callExpr: BabelNodeInfo): Ast = { - val callee = createBabelNodeInfo(callExpr.json("callee")) - val calleeCode = callee.code - if (GlobalBuiltins.builtins.contains(calleeCode)) { - createBuiltinStaticCall(callExpr, callee, calleeCode) - } else { - val (receiverAst, baseNode, callName) = callee.node match { - case MemberExpression => - val base = createBabelNodeInfo(callee.json("object")) - val member = createBabelNodeInfo(callee.json("property")) - base.node match { - case ThisExpression => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code) - .dynamicTypeHintFullName(this.rootTypeDecl.map(_.fullName).toSeq) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) - case Identifier => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) +trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForExpressionStatement(exprStmt: BabelNodeInfo): Ast = + astForNodeWithFunctionReference(exprStmt.json("expression")) + + private def createBuiltinStaticCall( + callExpr: BabelNodeInfo, + callee: BabelNodeInfo, + fullName: String + ): Ast = + val callName = callee.node match + case 
MemberExpression => code(callee.json("property")) + case _ => callee.code + val callNode = + createStaticCallNode( + callExpr.code, + callName, + fullName, + callee.lineNumber, + callee.columnNumber + ) + val argAsts = astForNodes(callExpr.json("arguments").arr.toList) + callAst(callNode, argAsts) + + private def handleCallNodeArgs( + callExpr: BabelNodeInfo, + receiverAst: Ast, + baseNode: NewNode, + callName: String + ): Ast = + val args = astForNodes(callExpr.json("arguments").arr.toList) + val callNode_ = callNode(callExpr, callExpr.code, callName, DispatchTypes.DYNAMIC_DISPATCH) + // If the callee is a function itself, e.g. closure, then resolve this locally, if possible + callExpr.json.obj + .get("callee") + .map(createBabelNodeInfo) + .flatMap { + case callee if callee.node.isInstanceOf[FunctionLike] => + functionNodeToNameAndFullName.get(callee) + case _ => None + } + .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } + callAst(callNode_, args, receiver = Option(receiverAst), base = Option(Ast(baseNode))) + end handleCallNodeArgs + + protected def astForCallExpression(callExpr: BabelNodeInfo): Ast = + val callee = createBabelNodeInfo(callExpr.json("callee")) + val calleeCode = callee.code + if GlobalBuiltins.builtins.contains(calleeCode) then + createBuiltinStaticCall(callExpr, callee, calleeCode) + else + val (receiverAst, baseNode, callName) = callee.node match + case MemberExpression => + val base = createBabelNodeInfo(callee.json("object")) + val member = createBabelNodeInfo(callee.json("property")) + base.node match + case ThisExpression => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val baseNode = identifierNode(base, base.code) + .dynamicTypeHintFullName(this.rootTypeDecl.map(_.fullName).toSeq) + scope.addVariableReference(base.code, baseNode) + (receiverAst, baseNode, member.code) + case Identifier => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val baseNode = 
identifierNode(base, base.code) + scope.addVariableReference(base.code, baseNode) + (receiverAst, baseNode, member.code) + case _ => + val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") + val baseTmpNode = identifierNode(base, tmpVarName) + scope.addVariableReference(tmpVarName, baseTmpNode) + val baseAst = astForNodeWithFunctionReference(base.json) + val code = s"(${codeOf(baseTmpNode)} = ${base.code})" + val tmpAssignmentAst = + createAssignmentCallAst( + Ast(baseTmpNode), + baseAst, + code, + base.lineNumber, + base.columnNumber + ) + val memberNode = createFieldIdentifierNode( + member.code, + member.lineNumber, + member.columnNumber + ) + val fieldAccessAst = + createFieldAccessCallAst( + tmpAssignmentAst, + memberNode, + callee.lineNumber, + callee.columnNumber + ) + val thisTmpNode = identifierNode(callee, tmpVarName) + scope.addVariableReference(tmpVarName, thisTmpNode) + + (fieldAccessAst, thisTmpNode, member.code) + end match + case _ => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val thisNode = identifierNode(callee, "this").dynamicTypeHintFullName( + typeHintForThisExpression() + ) + scope.addVariableReference(thisNode.name, thisNode) + (receiverAst, thisNode, calleeCode) + handleCallNodeArgs(callExpr, receiverAst, baseNode, callName) + end if + end astForCallExpression + + protected def astForThisExpression(thisExpr: BabelNodeInfo): Ast = + val dynamicTypeOption = typeHintForThisExpression(Option(thisExpr)).headOption + val thisNode = identifierNode(thisExpr, thisExpr.code, dynamicTypeOption.toList) + scope.addVariableReference(thisExpr.code, thisNode) + Ast(thisNode) + + protected def astForNewExpression(newExpr: BabelNodeInfo): Ast = + val callee = newExpr.json("callee") + val blockNode = createBlockNode(newExpr) + + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpAllocName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localTmpAllocNode = 
newLocalNode(tmpAllocName, Defines.Any).order(0) + val tmpAllocNode1 = identifierNode(newExpr, tmpAllocName) + diffGraph.addEdge(localAstParentStack.head, localTmpAllocNode, EdgeTypes.AST) + scope.addVariableReference(tmpAllocName, tmpAllocNode1) + + val allocCallNode = + callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) + + val assignmentTmpAllocCallNode = + createAssignmentCallAst( + tmpAllocNode1, + allocCallNode, + s"$tmpAllocName = ${allocCallNode.code}", + newExpr.lineNumber, + newExpr.columnNumber + ) + + val tmpAllocNode2 = identifierNode(newExpr, tmpAllocName) + + val receiverNode = astForNodeWithFunctionReference(callee) + + // TODO: place ".new" into the schema + val callAst = handleCallNodeArgs(newExpr, receiverNode, tmpAllocNode2, ".new") + + val tmpAllocReturnNode = Ast(identifierNode(newExpr, tmpAllocName)) + + scope.popScope() + localAstParentStack.pop() + + val blockChildren = List(assignmentTmpAllocCallNode, callAst, tmpAllocReturnNode) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForNewExpression + + protected def astForMetaProperty(metaProperty: BabelNodeInfo): Ast = + val metaAst = astForIdentifier(createBabelNodeInfo(metaProperty.json("meta"))) + val memberNodeInfo = createBabelNodeInfo(metaProperty.json("property")) + val memberAst = Ast( + createFieldIdentifierNode( + memberNodeInfo.code, + memberNodeInfo.lineNumber, + memberNodeInfo.columnNumber + ) + ) + createFieldAccessCallAst( + metaAst, + memberAst.nodes.head, + metaProperty.lineNumber, + metaProperty.columnNumber + ) + + protected def astForMemberExpression(memberExpr: BabelNodeInfo): Ast = + val baseAst = astForNodeWithFunctionReference(memberExpr.json("object")) + val memberIsComputed = memberExpr.json("computed").bool + val memberNodeInfo = createBabelNodeInfo(memberExpr.json("property")) + if memberIsComputed then + val memberAst = astForNode(memberNodeInfo.json) + createIndexAccessCallAst( + baseAst, + memberAst, + 
memberExpr.lineNumber, + memberExpr.columnNumber + ) + else + val memberAst = Ast( + createFieldIdentifierNode( + memberNodeInfo.code, + memberNodeInfo.lineNumber, + memberNodeInfo.columnNumber + ) + ) + createFieldAccessCallAst( + baseAst, + memberAst.nodes.head, + memberExpr.lineNumber, + memberExpr.columnNumber + ) + end if + end astForMemberExpression + + protected def astForAssignmentExpression(assignment: BabelNodeInfo): Ast = + val op = if hasKey(assignment.json, "operator") then + assignment.json("operator").str match + case "=" => Operators.assignment + case "+=" => Operators.assignmentPlus + case "-=" => Operators.assignmentMinus + case "*=" => Operators.assignmentMultiplication + case "/=" => Operators.assignmentDivision + case "%=" => Operators.assignmentModulo + case "**=" => Operators.assignmentExponentiation + case "&=" => Operators.assignmentAnd + case "&&=" => Operators.assignmentAnd + case "|=" => Operators.assignmentOr + case "||=" => Operators.assignmentOr + case "^=" => Operators.assignmentXor + case "<<=" => Operators.assignmentShiftLeft + case ">>=" => Operators.assignmentArithmeticShiftRight + case ">>>=" => Operators.assignmentLogicalShiftRight + case "??=" => Operators.notNullAssert + case other => + logger.warn(s"Unknown assignment operator: '$other'") + Operators.assignment + else Operators.assignment + + val nodeInfo = createBabelNodeInfo(assignment.json("left")) + nodeInfo.node match + case ObjectPattern | ArrayPattern => + val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) + astForDeconstruction(nodeInfo, rhsAst, assignment.code) case _ => - val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") - val baseTmpNode = identifierNode(base, tmpVarName) - scope.addVariableReference(tmpVarName, baseTmpNode) - val baseAst = astForNodeWithFunctionReference(base.json) - val code = s"(${codeOf(baseTmpNode)} = ${base.code})" - val tmpAssignmentAst = - createAssignmentCallAst(Ast(baseTmpNode), baseAst, code, 
base.lineNumber, base.columnNumber) - val memberNode = createFieldIdentifierNode(member.code, member.lineNumber, member.columnNumber) - val fieldAccessAst = - createFieldAccessCallAst(tmpAssignmentAst, memberNode, callee.lineNumber, callee.columnNumber) - val thisTmpNode = identifierNode(callee, tmpVarName) - scope.addVariableReference(tmpVarName, thisTmpNode) - - (fieldAccessAst, thisTmpNode, member.code) - } - case _ => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val thisNode = identifierNode(callee, "this").dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariableReference(thisNode.name, thisNode) - (receiverAst, thisNode, calleeCode) - } - handleCallNodeArgs(callExpr, receiverAst, baseNode, callName) - } - } - - protected def astForThisExpression(thisExpr: BabelNodeInfo): Ast = { - val dynamicTypeOption = typeHintForThisExpression(Option(thisExpr)).headOption - val thisNode = identifierNode(thisExpr, thisExpr.code, dynamicTypeOption.toList) - scope.addVariableReference(thisExpr.code, thisNode) - Ast(thisNode) - } - - protected def astForNewExpression(newExpr: BabelNodeInfo): Ast = { - val callee = newExpr.json("callee") - val blockNode = createBlockNode(newExpr) - - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpAllocName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localTmpAllocNode = newLocalNode(tmpAllocName, Defines.Any).order(0) - val tmpAllocNode1 = identifierNode(newExpr, tmpAllocName) - diffGraph.addEdge(localAstParentStack.head, localTmpAllocNode, EdgeTypes.AST) - scope.addVariableReference(tmpAllocName, tmpAllocNode1) - - val allocCallNode = - callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) - - val assignmentTmpAllocCallNode = - createAssignmentCallAst( - tmpAllocNode1, - allocCallNode, - s"$tmpAllocName = ${allocCallNode.code}", - newExpr.lineNumber, - newExpr.columnNumber - ) - - val tmpAllocNode2 = identifierNode(newExpr, 
tmpAllocName) - - val receiverNode = astForNodeWithFunctionReference(callee) - - // TODO: place ".new" into the schema - val callAst = handleCallNodeArgs(newExpr, receiverNode, tmpAllocNode2, ".new") - - val tmpAllocReturnNode = Ast(identifierNode(newExpr, tmpAllocName)) - - scope.popScope() - localAstParentStack.pop() - - val blockChildren = List(assignmentTmpAllocCallNode, callAst, tmpAllocReturnNode) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - } - - protected def astForMetaProperty(metaProperty: BabelNodeInfo): Ast = { - val metaAst = astForIdentifier(createBabelNodeInfo(metaProperty.json("meta"))) - val memberNodeInfo = createBabelNodeInfo(metaProperty.json("property")) - val memberAst = Ast( - createFieldIdentifierNode(memberNodeInfo.code, memberNodeInfo.lineNumber, memberNodeInfo.columnNumber) - ) - createFieldAccessCallAst(metaAst, memberAst.nodes.head, metaProperty.lineNumber, metaProperty.columnNumber) - } - - protected def astForMemberExpression(memberExpr: BabelNodeInfo): Ast = { - val baseAst = astForNodeWithFunctionReference(memberExpr.json("object")) - val memberIsComputed = memberExpr.json("computed").bool - val memberNodeInfo = createBabelNodeInfo(memberExpr.json("property")) - if (memberIsComputed) { - val memberAst = astForNode(memberNodeInfo.json) - createIndexAccessCallAst(baseAst, memberAst, memberExpr.lineNumber, memberExpr.columnNumber) - } else { - val memberAst = Ast( - createFieldIdentifierNode(memberNodeInfo.code, memberNodeInfo.lineNumber, memberNodeInfo.columnNumber) - ) - createFieldAccessCallAst(baseAst, memberAst.nodes.head, memberExpr.lineNumber, memberExpr.columnNumber) - } - } - - protected def astForAssignmentExpression(assignment: BabelNodeInfo): Ast = { - val op = if (hasKey(assignment.json, "operator")) { - assignment.json("operator").str match { - case "=" => Operators.assignment - case "+=" => Operators.assignmentPlus - case "-=" => Operators.assignmentMinus - case "*=" => 
Operators.assignmentMultiplication - case "/=" => Operators.assignmentDivision - case "%=" => Operators.assignmentModulo - case "**=" => Operators.assignmentExponentiation - case "&=" => Operators.assignmentAnd - case "&&=" => Operators.assignmentAnd - case "|=" => Operators.assignmentOr - case "||=" => Operators.assignmentOr - case "^=" => Operators.assignmentXor - case "<<=" => Operators.assignmentShiftLeft - case ">>=" => Operators.assignmentArithmeticShiftRight - case ">>>=" => Operators.assignmentLogicalShiftRight - case "??=" => Operators.notNullAssert - case other => - logger.warn(s"Unknown assignment operator: '$other'") - Operators.assignment - } - } else Operators.assignment - - val nodeInfo = createBabelNodeInfo(assignment.json("left")) - nodeInfo.node match { - case ObjectPattern | ArrayPattern => - val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) - astForDeconstruction(nodeInfo, rhsAst, assignment.code) - case _ => - val lhsAst = astForNode(assignment.json("left")) - val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) + val lhsAst = astForNode(assignment.json("left")) + val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) + val callNode_ = + callNode(assignment, assignment.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(lhsAst, rhsAst) + callAst(callNode_, argAsts) + end astForAssignmentExpression + + protected def astForConditionalExpression(ternary: BabelNodeInfo): Ast = + val testAst = astForNodeWithFunctionReference(ternary.json("test")) + val consequentAst = astForNodeWithFunctionReference(ternary.json("consequent")) + val alternateAst = astForNodeWithFunctionReference(ternary.json("alternate")) + createTernaryCallAst( + testAst, + consequentAst, + alternateAst, + ternary.lineNumber, + ternary.columnNumber + ) + + protected def astForLogicalExpression(logicalExpr: BabelNodeInfo): Ast = + astForBinaryExpression(logicalExpr) + + protected def 
astForTSNonNullExpression(nonNullExpr: BabelNodeInfo): Ast = + val op = Operators.notNullAssert val callNode_ = - callNode(assignment, assignment.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) + callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) callAst(callNode_, argAsts) - } - } - - protected def astForConditionalExpression(ternary: BabelNodeInfo): Ast = { - val testAst = astForNodeWithFunctionReference(ternary.json("test")) - val consequentAst = astForNodeWithFunctionReference(ternary.json("consequent")) - val alternateAst = astForNodeWithFunctionReference(ternary.json("alternate")) - createTernaryCallAst(testAst, consequentAst, alternateAst, ternary.lineNumber, ternary.columnNumber) - } - - protected def astForLogicalExpression(logicalExpr: BabelNodeInfo): Ast = - astForBinaryExpression(logicalExpr) - - protected def astForTSNonNullExpression(nonNullExpr: BabelNodeInfo): Ast = { - val op = Operators.notNullAssert - val callNode_ = - callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) - callAst(callNode_, argAsts) - } - - protected def astForCastExpression(castExpr: BabelNodeInfo): Ast = { - val op = Operators.cast - val typ = typeFor(castExpr) match { - case t if GlobalBuiltins.builtins.contains(t) => s"__ecma.$t" - case t => t - } - val lhsNode = castExpr.json("typeAnnotation") - val lhsAst = Ast(literalNode(castExpr, code(lhsNode), None).dynamicTypeHintFullName(Seq(typ))) - val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) - val node = - callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH) - .dynamicTypeHintFullName(Seq(typ)) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - } - - protected def astForBinaryExpression(binExpr: BabelNodeInfo): Ast = { - val op = 
binExpr.json("operator").str match { - case "+" => Operators.addition - case "-" => Operators.subtraction - case "/" => Operators.division - case "%" => Operators.modulo - case "*" => Operators.multiplication - case "**" => Operators.exponentiation - case "&" => Operators.and - case ">>" => Operators.arithmeticShiftRight - case ">>>" => Operators.arithmeticShiftRight - case "<<" => Operators.shiftLeft - case "^" => Operators.xor - case "==" => Operators.equals - case "===" => Operators.equals - case "!=" => Operators.notEquals - case "!==" => Operators.notEquals - case "in" => Operators.in - case ">" => Operators.greaterThan - case "<" => Operators.lessThan - case ">=" => Operators.greaterEqualsThan - case "<=" => Operators.lessEqualsThan - case "instanceof" => Operators.instanceOf - case "||" => Operators.logicalOr - case "|" => Operators.or - case "&&" => Operators.logicalAnd - // special case (see: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Nullish_coalescing_operator) - case "??" 
=> Operators.logicalOr - case "case" => ".case" - case other => - logger.warn(s"Unknown binary operator: '$other'") - Operators.assignment - } - - val lhsAst = astForNodeWithFunctionReference(binExpr.json("left")) - val rhsAst = astForNodeWithFunctionReference(binExpr.json("right")) - - val node = - callNode(binExpr, binExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - } - - protected def astForUpdateExpression(updateExpr: BabelNodeInfo): Ast = { - val isPrefix = updateExpr.json("prefix").bool - val op = updateExpr.json("operator").str match { - case "++" if isPrefix => Operators.preIncrement - case "++" => Operators.postIncrement - case "--" if isPrefix => Operators.preIncrement - case "--" => Operators.postIncrement - case other => - logger.warn(s"Unknown update operator: '$other'") - Operators.assignment - } - - val argumentAst = astForNodeWithFunctionReference(updateExpr.json("argument")) - - val node = callNode(updateExpr, updateExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(node, argAsts) - } - - protected def astForUnaryExpression(unaryExpr: BabelNodeInfo): Ast = { - val op = unaryExpr.json("operator").str match { - case "void" => ".void" - case "throw" => ".throw" - case "delete" => Operators.delete - case "!" 
=> Operators.logicalNot - case "+" => Operators.plus - case "-" => Operators.minus - case "~" => ".bitNot" - case "typeof" => Operators.instanceOf - case other => - logger.warn(s"Unknown update operator: '$other'") - Operators.assignment - } - - val argumentAst = astForNodeWithFunctionReference(unaryExpr.json("argument")) - - val node = callNode(unaryExpr, unaryExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(node, argAsts) - } - - protected def astForSequenceExpression(seq: BabelNodeInfo): Ast = { - val blockNode = createBlockNode(seq) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val sequenceExpressionAsts = createBlockStatementAsts(seq.json("expressions")) - setArgumentIndices(sequenceExpressionAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, sequenceExpressionAsts) - } - - protected def astForAwaitExpression(awaitExpr: BabelNodeInfo): Ast = { - val node = - callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) - val argAsts = List(astForNodeWithFunctionReference(awaitExpr.json("argument"))) - callAst(node, argAsts) - } - - protected def astForArrayExpression(arrExpr: BabelNodeInfo): Ast = { - val elements = Try(arrExpr.json("elements").arr).toOption.toList.flatten - if (elements.isEmpty) { - Ast( - callNode(arrExpr, s"${EcmaBuiltins.arrayFactory}()", EcmaBuiltins.arrayFactory, DispatchTypes.STATIC_DISPATCH) - ) - } else { - val blockNode = createBlockNode(arrExpr) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localTmpNode = newLocalNode(tmpName, Defines.Any).order(0) - val tmpArrayNode = identifierNode(arrExpr, tmpName) - diffGraph.addEdge(localAstParentStack.head, localTmpNode, EdgeTypes.AST) - scope.addVariableReference(tmpName, tmpArrayNode) - - val arrayCallNode = - callNode(arrExpr, s"${EcmaBuiltins.arrayFactory}()", 
EcmaBuiltins.arrayFactory, DispatchTypes.STATIC_DISPATCH) - - val lineNumber = arrExpr.lineNumber - val columnNumber = arrExpr.columnNumber - val assignmentCode = s"${localTmpNode.code} = ${arrayCallNode.code}" - val assignmentTmpArrayCallNode = - createAssignmentCallAst(tmpArrayNode, arrayCallNode, assignmentCode, lineNumber, columnNumber) - - val elementAsts = elements.flatMap { - case element if !element.isNull => - val elementNodeInfo = createBabelNodeInfo(element) - val elementLineNumber = elementNodeInfo.lineNumber - val elementColumnNumber = elementNodeInfo.columnNumber - val elementCode = elementNodeInfo.code - val elementNode = elementNodeInfo.node match { - case RestElement => - val arg1Ast = Ast(identifierNode(arrExpr, tmpName)) - astForSpreadOrRestElement(elementNodeInfo, Option(arg1Ast)) - case _ => - astForNodeWithFunctionReference(element) - } - - val pushCallNode = - callNode(elementNodeInfo, s"$tmpName.push($elementCode)", "", DispatchTypes.DYNAMIC_DISPATCH) - - val baseNode = identifierNode(elementNodeInfo, tmpName) - val memberNode = createFieldIdentifierNode("push", elementLineNumber, elementColumnNumber) - val receiverNode = createFieldAccessCallAst(baseNode, memberNode, elementLineNumber, elementColumnNumber) - val thisPushNode = identifierNode(elementNodeInfo, tmpName) - - Option( - callAst(pushCallNode, List(elementNode), receiver = Option(receiverNode), base = Option(Ast(thisPushNode))) - ) - case _ => None // skip - } - - val tmpArrayReturnNode = identifierNode(arrExpr, tmpName) - - scope.popScope() - localAstParentStack.pop() - - val blockChildrenAsts = assignmentTmpArrayCallNode +: elementAsts :+ Ast(tmpArrayReturnNode) - setArgumentIndices(blockChildrenAsts) - blockAst(blockNode, blockChildrenAsts) - } - } - - def astForTemplateExpression(templateExpr: BabelNodeInfo): Ast = { - val argumentAst = astForNodeWithFunctionReference(templateExpr.json("quasi")) - val callName = code(templateExpr.json("tag")) - val callCode = 
s"$callName(${codeOf(argumentAst.nodes.head)})" - val templateExprCall = - callNode(templateExpr, callCode, callName, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(templateExprCall, argAsts) - } - - protected def astForObjectExpression(objExpr: BabelNodeInfo): Ast = { - val blockNode = createBlockNode(objExpr) - - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localNode = newLocalNode(tmpName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - val propertiesAsts = objExpr.json("properties").arr.toList.map { property => - val nodeInfo = createBabelNodeInfo(property) - nodeInfo.node match { - case SpreadElement | RestElement => - val arg1Ast = Ast(identifierNode(nodeInfo, tmpName)) - astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) - case _ => - val (lhsNode, rhsAst) = nodeInfo.node match { - case ObjectMethod => - val keyName = - if (hasKey(nodeInfo.json("key"), "name")) nodeInfo.json("key")("name").str - else code(nodeInfo.json("key")) - val keyNode = createFieldIdentifierNode(keyName, nodeInfo.lineNumber, nodeInfo.columnNumber) - (keyNode, astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true)) - case ObjectProperty => - val key = createBabelNodeInfo(nodeInfo.json("key")) - val keyName = key.node match { - case Identifier if nodeInfo.json("computed").bool => - key.code - case _ if nodeInfo.json("computed").bool => - generateUnusedVariableName(usedVariableNames, "_computed_object_property") - case _ => key.code - } - val keyNode = createFieldIdentifierNode(keyName, nodeInfo.lineNumber, nodeInfo.columnNumber) - val ast = astForNodeWithFunctionReference(nodeInfo.json("value")) - (keyNode, ast) - case _ => - // can't happen as per https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts#L573 - // just to make the compiler happy 
here. - ??? - } - - val leftHandSideTmpNode = identifierNode(nodeInfo, tmpName) - val leftHandSideFieldAccessAst = - createFieldAccessCallAst(leftHandSideTmpNode, lhsNode, nodeInfo.lineNumber, nodeInfo.columnNumber) - - createAssignmentCallAst( - leftHandSideFieldAccessAst, - rhsAst, - s"$tmpName.${lhsNode.canonicalName} = ${codeOf(rhsAst.nodes.head)}", - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - } - } - - val tmpNode = identifierNode(objExpr, tmpName) - - scope.popScope() - localAstParentStack.pop() - val childrenAsts = propertiesAsts :+ Ast(tmpNode) - setArgumentIndices(childrenAsts) - blockAst(blockNode, childrenAsts) - } - - protected def astForTSSatisfiesExpression(satisfiesExpr: BabelNodeInfo): Ast = { - // Ignores the type, i.e. `x satisfies T` is understood as `x`. - astForNode(satisfiesExpr.json("expression")) - } -} + protected def astForCastExpression(castExpr: BabelNodeInfo): Ast = + val op = Operators.cast + val typ = typeFor(castExpr) match + case t if GlobalBuiltins.builtins.contains(t) => s"__ecma.$t" + case t => t + val lhsNode = castExpr.json("typeAnnotation") + val lhsAst = + Ast(literalNode(castExpr, code(lhsNode), None).dynamicTypeHintFullName(Seq(typ))) + val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) + val node = + callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH) + .dynamicTypeHintFullName(Seq(typ)) + val argAsts = List(lhsAst, rhsAst) + callAst(node, argAsts) + + protected def astForBinaryExpression(binExpr: BabelNodeInfo): Ast = + val op = binExpr.json("operator").str match + case "+" => Operators.addition + case "-" => Operators.subtraction + case "/" => Operators.division + case "%" => Operators.modulo + case "*" => Operators.multiplication + case "**" => Operators.exponentiation + case "&" => Operators.and + case ">>" => Operators.arithmeticShiftRight + case ">>>" => Operators.arithmeticShiftRight + case "<<" => Operators.shiftLeft + case "^" => Operators.xor + case "==" => 
Operators.equals + case "===" => Operators.equals + case "!=" => Operators.notEquals + case "!==" => Operators.notEquals + case "in" => Operators.in + case ">" => Operators.greaterThan + case "<" => Operators.lessThan + case ">=" => Operators.greaterEqualsThan + case "<=" => Operators.lessEqualsThan + case "instanceof" => Operators.instanceOf + case "||" => Operators.logicalOr + case "|" => Operators.or + case "&&" => Operators.logicalAnd + // special case (see: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Nullish_coalescing_operator) + case "??" => Operators.logicalOr + case "case" => ".case" + case other => + logger.warn(s"Unknown binary operator: '$other'") + Operators.assignment + + val lhsAst = astForNodeWithFunctionReference(binExpr.json("left")) + val rhsAst = astForNodeWithFunctionReference(binExpr.json("right")) + + val node = + callNode(binExpr, binExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(lhsAst, rhsAst) + callAst(node, argAsts) + end astForBinaryExpression + + protected def astForUpdateExpression(updateExpr: BabelNodeInfo): Ast = + val isPrefix = updateExpr.json("prefix").bool + val op = updateExpr.json("operator").str match + case "++" if isPrefix => Operators.preIncrement + case "++" => Operators.postIncrement + case "--" if isPrefix => Operators.preIncrement + case "--" => Operators.postIncrement + case other => + logger.warn(s"Unknown update operator: '$other'") + Operators.assignment + + val argumentAst = astForNodeWithFunctionReference(updateExpr.json("argument")) + + val node = callNode(updateExpr, updateExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(node, argAsts) + + protected def astForUnaryExpression(unaryExpr: BabelNodeInfo): Ast = + val op = unaryExpr.json("operator").str match + case "void" => ".void" + case "throw" => ".throw" + case "delete" => Operators.delete + case "!" 
=> Operators.logicalNot + case "+" => Operators.plus + case "-" => Operators.minus + case "~" => ".bitNot" + case "typeof" => Operators.instanceOf + case other => + logger.warn(s"Unknown update operator: '$other'") + Operators.assignment + + val argumentAst = astForNodeWithFunctionReference(unaryExpr.json("argument")) + + val node = callNode(unaryExpr, unaryExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(node, argAsts) + end astForUnaryExpression + + protected def astForSequenceExpression(seq: BabelNodeInfo): Ast = + val blockNode = createBlockNode(seq) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val sequenceExpressionAsts = createBlockStatementAsts(seq.json("expressions")) + setArgumentIndices(sequenceExpressionAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, sequenceExpressionAsts) + + protected def astForAwaitExpression(awaitExpr: BabelNodeInfo): Ast = + val node = + callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) + val argAsts = List(astForNodeWithFunctionReference(awaitExpr.json("argument"))) + callAst(node, argAsts) + + protected def astForArrayExpression(arrExpr: BabelNodeInfo): Ast = + val elements = Try(arrExpr.json("elements").arr).toOption.toList.flatten + if elements.isEmpty then + Ast( + callNode( + arrExpr, + s"${EcmaBuiltins.arrayFactory}()", + EcmaBuiltins.arrayFactory, + DispatchTypes.STATIC_DISPATCH + ) + ) + else + val blockNode = createBlockNode(arrExpr) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localTmpNode = newLocalNode(tmpName, Defines.Any).order(0) + val tmpArrayNode = identifierNode(arrExpr, tmpName) + diffGraph.addEdge(localAstParentStack.head, localTmpNode, EdgeTypes.AST) + scope.addVariableReference(tmpName, tmpArrayNode) + + val arrayCallNode = + callNode( + arrExpr, + 
s"${EcmaBuiltins.arrayFactory}()", + EcmaBuiltins.arrayFactory, + DispatchTypes.STATIC_DISPATCH + ) + + val lineNumber = arrExpr.lineNumber + val columnNumber = arrExpr.columnNumber + val assignmentCode = s"${localTmpNode.code} = ${arrayCallNode.code}" + val assignmentTmpArrayCallNode = + createAssignmentCallAst( + tmpArrayNode, + arrayCallNode, + assignmentCode, + lineNumber, + columnNumber + ) + + val elementAsts = elements.flatMap { + case element if !element.isNull => + val elementNodeInfo = createBabelNodeInfo(element) + val elementLineNumber = elementNodeInfo.lineNumber + val elementColumnNumber = elementNodeInfo.columnNumber + val elementCode = elementNodeInfo.code + val elementNode = elementNodeInfo.node match + case RestElement => + val arg1Ast = Ast(identifierNode(arrExpr, tmpName)) + astForSpreadOrRestElement(elementNodeInfo, Option(arg1Ast)) + case _ => + astForNodeWithFunctionReference(element) + + val pushCallNode = + callNode( + elementNodeInfo, + s"$tmpName.push($elementCode)", + "", + DispatchTypes.DYNAMIC_DISPATCH + ) + + val baseNode = identifierNode(elementNodeInfo, tmpName) + val memberNode = + createFieldIdentifierNode("push", elementLineNumber, elementColumnNumber) + val receiverNode = createFieldAccessCallAst( + baseNode, + memberNode, + elementLineNumber, + elementColumnNumber + ) + val thisPushNode = identifierNode(elementNodeInfo, tmpName) + + Option( + callAst( + pushCallNode, + List(elementNode), + receiver = Option(receiverNode), + base = Option(Ast(thisPushNode)) + ) + ) + case _ => None // skip + } + + val tmpArrayReturnNode = identifierNode(arrExpr, tmpName) + + scope.popScope() + localAstParentStack.pop() + + val blockChildrenAsts = + assignmentTmpArrayCallNode +: elementAsts :+ Ast(tmpArrayReturnNode) + setArgumentIndices(blockChildrenAsts) + blockAst(blockNode, blockChildrenAsts) + end if + end astForArrayExpression + + def astForTemplateExpression(templateExpr: BabelNodeInfo): Ast = + val argumentAst = 
astForNodeWithFunctionReference(templateExpr.json("quasi")) + val callName = code(templateExpr.json("tag")) + val callCode = s"$callName(${codeOf(argumentAst.nodes.head)})" + val templateExprCall = + callNode(templateExpr, callCode, callName, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(templateExprCall, argAsts) + + protected def astForObjectExpression(objExpr: BabelNodeInfo): Ast = + val blockNode = createBlockNode(objExpr) + + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localNode = newLocalNode(tmpName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + val propertiesAsts = objExpr.json("properties").arr.toList.map { property => + val nodeInfo = createBabelNodeInfo(property) + nodeInfo.node match + case SpreadElement | RestElement => + val arg1Ast = Ast(identifierNode(nodeInfo, tmpName)) + astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) + case _ => + val (lhsNode, rhsAst) = nodeInfo.node match + case ObjectMethod => + val keyName = + if hasKey(nodeInfo.json("key"), "name") then + nodeInfo.json("key")("name").str + else code(nodeInfo.json("key")) + val keyNode = createFieldIdentifierNode( + keyName, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + ( + keyNode, + astForFunctionDeclaration( + nodeInfo, + shouldCreateFunctionReference = true + ) + ) + case ObjectProperty => + val key = createBabelNodeInfo(nodeInfo.json("key")) + val keyName = key.node match + case Identifier if nodeInfo.json("computed").bool => + key.code + case _ if nodeInfo.json("computed").bool => + generateUnusedVariableName( + usedVariableNames, + "_computed_object_property" + ) + case _ => key.code + val keyNode = createFieldIdentifierNode( + keyName, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + val ast = astForNodeWithFunctionReference(nodeInfo.json("value")) + (keyNode, ast) + case _ => + // 
can't happen as per https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts#L573 + // just to make the compiler happy here. + ??? + + val leftHandSideTmpNode = identifierNode(nodeInfo, tmpName) + val leftHandSideFieldAccessAst = + createFieldAccessCallAst( + leftHandSideTmpNode, + lhsNode, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + + createAssignmentCallAst( + leftHandSideFieldAccessAst, + rhsAst, + s"$tmpName.${lhsNode.canonicalName} = ${codeOf(rhsAst.nodes.head)}", + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + end match + } + + val tmpNode = identifierNode(objExpr, tmpName) + + scope.popScope() + localAstParentStack.pop() + + val childrenAsts = propertiesAsts :+ Ast(tmpNode) + setArgumentIndices(childrenAsts) + blockAst(blockNode, childrenAsts) + end astForObjectExpression + + protected def astForTSSatisfiesExpression(satisfiesExpr: BabelNodeInfo): Ast = + // Ignores the type, i.e. `x satisfies T` is understood as `x`. + astForNode(satisfiesExpr.json("expression")) +end AstForExpressionsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala index f0fa2f7e..2e3eaa67 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -7,460 +7,560 @@ import io.appthreat.x2cpg.{Ast, ValidationMode} import io.appthreat.x2cpg.datastructures.Stack.* import io.appthreat.x2cpg.utils.NodeBuilders.{newBindingNode, newLocalNode} import io.shiftleft.codepropertygraph.generated.nodes.{Identifier as _, *} -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, EvaluationStrategies, ModifierTypes} +import io.shiftleft.codepropertygraph.generated.{ + 
DispatchTypes, + EdgeTypes, + EvaluationStrategies, + ModifierTypes +} import ujson.Value import scala.collection.mutable import scala.util.Try -trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - case class MethodAst(ast: Ast, methodNode: NewMethod, methodAst: Ast) - - private def handleRestInParameters( - elementNodeInfo: BabelNodeInfo, - paramNodeInfo: BabelNodeInfo, - paramName: String - ): Ast = { - val ast = astForNodeWithFunctionReferenceAndCall(elementNodeInfo.json("argument")) - val defaultName = codeForNodes(ast.nodes.toSeq) - val restName = nameForBabelNodeInfo(paramNodeInfo, defaultName) - ast.root match { - case Some(_: NewIdentifier) => - val keyNode = createFieldIdentifierNode(restName, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber) - val paramNode = identifierNode(elementNodeInfo, paramName) - val accessAst = - createFieldAccessCallAst(paramNode, keyNode, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber) - createAssignmentCallAst( - ast, - accessAst, - s"$restName = ${codeOf(accessAst.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - case _ => - val localParamNode = identifierNode(elementNodeInfo, restName) - createAssignmentCallAst( - Ast(localParamNode), - ast, - s"$restName = ${codeOf(ast.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - } - } - - private def handleParameters( - parameters: Seq[Value], - additionalBlockStatements: mutable.ArrayBuffer[Ast], - createLocals: Boolean = true - ): Seq[NewMethodParameterIn] = - withIndex(parameters) { case (param, index) => - val nodeInfo = createBabelNodeInfo(param) - val paramNode = nodeInfo.node match { - case RestElement => - val paramName = nodeInfo.code.replace("...", "") - val tpe = typeFor(nodeInfo) - if (createLocals) { - val localNode = newLocalNode(paramName, tpe).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - } - 
parameterInNode(nodeInfo, paramName, nodeInfo.code, index, true, EvaluationStrategies.BY_VALUE, Option(tpe)) - case AssignmentPattern => - val lhsElement = nodeInfo.json("left") - val rhsElement = nodeInfo.json("right") - val lhsNodeInfo = createBabelNodeInfo(lhsElement) - lhsNodeInfo.node match { - case ObjectPattern | ArrayPattern => - val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") - val param = parameterInNode( - nodeInfo, - paramName, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE - ) - scope.addVariable(paramName, param, MethodScope) - - val rhsAst = astForNodeWithFunctionReference(rhsElement) - additionalBlockStatements.addOne( - astForDeconstruction(lhsNodeInfo, rhsAst, nodeInfo.code, Option(paramName)) - ) - param - case _ => - additionalBlockStatements.addOne(convertParamWithDefault(nodeInfo)) - val tpe = typeFor(lhsNodeInfo) - parameterInNode( - lhsNodeInfo, - lhsNodeInfo.code, - nodeInfo.code, - index, - false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - } - case ArrayPattern => - val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") - val tpe = typeFor(nodeInfo) - val param = parameterInNode( - nodeInfo, - paramName, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - additionalBlockStatements.addAll(nodeInfo.json("elements").arr.toList.map { - case element if !element.isNull => - val elementNodeInfo = createBabelNodeInfo(element) - elementNodeInfo.node match { - case Identifier => - val elemName = code(elementNodeInfo.json) - val tpe = typeFor(elementNodeInfo) - val localParamNode = identifierNode(elementNodeInfo, elemName) - localParamNode.typeFullName = tpe - - val localNode = newLocalNode(elemName, tpe).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(elemName, localNode, MethodScope) - - val paramNode = identifierNode(elementNodeInfo, paramName) - 
scope.addVariableReference(paramName, paramNode) - - val keyNode = - createFieldIdentifierNode(elemName, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber) - val accessAst = +trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + case class MethodAst(ast: Ast, methodNode: NewMethod, methodAst: Ast) + + private def handleRestInParameters( + elementNodeInfo: BabelNodeInfo, + paramNodeInfo: BabelNodeInfo, + paramName: String + ): Ast = + val ast = astForNodeWithFunctionReferenceAndCall(elementNodeInfo.json("argument")) + val defaultName = codeForNodes(ast.nodes.toSeq) + val restName = nameForBabelNodeInfo(paramNodeInfo, defaultName) + ast.root match + case Some(_: NewIdentifier) => + val keyNode = createFieldIdentifierNode( + restName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val paramNode = identifierNode(elementNodeInfo, paramName) + val accessAst = createFieldAccessCallAst( paramNode, keyNode, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber ) - createAssignmentCallAst( - Ast(localParamNode), - accessAst, - s"$elemName = ${codeOf(accessAst.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - case RestElement => handleRestInParameters(elementNodeInfo, nodeInfo, paramName) - case _ => astForNodeWithFunctionReference(elementNodeInfo.json) - } - case _ => Ast() - }) - param - case ObjectPattern => - val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") - // Handle de-structured parameters declared as `{ username: string; password: string; }` - val typeDecl = astForTypeAlias(nodeInfo) - val tpe = typeDecl.root.collect { case t: NewTypeDecl => t.fullName }.getOrElse(typeFor(nodeInfo)) - val param = parameterInNode( - nodeInfo, - paramName, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - Ast.storeInDiffGraph(typeDecl, diffGraph) - scope.addVariable(paramName, param, MethodScope) - 
- additionalBlockStatements.addAll(nodeInfo.json("properties").arr.toList.map { element => - val elementNodeInfo = createBabelNodeInfo(element) - elementNodeInfo.node match { - case ObjectProperty => - val elemName = code(elementNodeInfo.json("key")) - val tpe = typeFor(elementNodeInfo) - val localParamNode = identifierNode(elementNodeInfo, elemName) - localParamNode.typeFullName = tpe - - val localNode = newLocalNode(elemName, tpe).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(elemName, localNode, MethodScope) - - val paramNode = identifierNode(elementNodeInfo, paramName) - scope.addVariableReference(paramName, paramNode) - - val keyNode = - createFieldIdentifierNode(elemName, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber) - val accessAst = - createFieldAccessCallAst(paramNode, keyNode, elementNodeInfo.lineNumber, elementNodeInfo.columnNumber) - val assignmentCallAst = createAssignmentCallAst( - Ast(localParamNode), + createAssignmentCallAst( + ast, accessAst, - s"$elemName = ${codeOf(accessAst.nodes.head)}", + s"$restName = ${codeOf(accessAst.nodes.head)}", elementNodeInfo.lineNumber, elementNodeInfo.columnNumber ) - // Handle identifiers referring to locals created by destructured parameters - assignmentCallAst.nodes - .collect { case i: NewIdentifier if localNode.name == i.name => i } - .map { i => assignmentCallAst.withRefEdge(i, localNode) } - .reduce(_ merge _) - case RestElement => handleRestInParameters(elementNodeInfo, nodeInfo, paramName) - case _ => astForNodeWithFunctionReference(elementNodeInfo.json) + case _ => + val localParamNode = identifierNode(elementNodeInfo, restName) + createAssignmentCallAst( + Ast(localParamNode), + ast, + s"$restName = ${codeOf(ast.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + end match + end handleRestInParameters + + private def handleParameters( + parameters: Seq[Value], + additionalBlockStatements: 
mutable.ArrayBuffer[Ast], + createLocals: Boolean = true + ): Seq[NewMethodParameterIn] = + withIndex(parameters) { case (param, index) => + val nodeInfo = createBabelNodeInfo(param) + val paramNode = nodeInfo.node match + case RestElement => + val paramName = nodeInfo.code.replace("...", "") + val tpe = typeFor(nodeInfo) + if createLocals then + val localNode = newLocalNode(paramName, tpe).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + true, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + case AssignmentPattern => + val lhsElement = nodeInfo.json("left") + val rhsElement = nodeInfo.json("right") + val lhsNodeInfo = createBabelNodeInfo(lhsElement) + lhsNodeInfo.node match + case ObjectPattern | ArrayPattern => + val paramName = + generateUnusedVariableName(usedVariableNames, s"param$index") + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE + ) + scope.addVariable(paramName, param, MethodScope) + + val rhsAst = astForNodeWithFunctionReference(rhsElement) + additionalBlockStatements.addOne( + astForDeconstruction( + lhsNodeInfo, + rhsAst, + nodeInfo.code, + Option(paramName) + ) + ) + param + case _ => + additionalBlockStatements.addOne(convertParamWithDefault(nodeInfo)) + val tpe = typeFor(lhsNodeInfo) + parameterInNode( + lhsNodeInfo, + lhsNodeInfo.code, + nodeInfo.code, + index, + false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + end match + case ArrayPattern => + val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") + val tpe = typeFor(nodeInfo) + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + additionalBlockStatements.addAll(nodeInfo.json("elements").arr.toList.map { + case element if !element.isNull => + val elementNodeInfo = 
createBabelNodeInfo(element) + elementNodeInfo.node match + case Identifier => + val elemName = code(elementNodeInfo.json) + val tpe = typeFor(elementNodeInfo) + val localParamNode = identifierNode(elementNodeInfo, elemName) + localParamNode.typeFullName = tpe + + val localNode = newLocalNode(elemName, tpe).order(0) + diffGraph.addEdge( + localAstParentStack.head, + localNode, + EdgeTypes.AST + ) + scope.addVariable(elemName, localNode, MethodScope) + + val paramNode = identifierNode(elementNodeInfo, paramName) + scope.addVariableReference(paramName, paramNode) + + val keyNode = + createFieldIdentifierNode( + elemName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val accessAst = + createFieldAccessCallAst( + paramNode, + keyNode, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + createAssignmentCallAst( + Ast(localParamNode), + accessAst, + s"$elemName = ${codeOf(accessAst.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + case RestElement => + handleRestInParameters(elementNodeInfo, nodeInfo, paramName) + case _ => astForNodeWithFunctionReference(elementNodeInfo.json) + end match + case _ => Ast() + }) + param + case ObjectPattern => + val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") + // Handle de-structured parameters declared as `{ username: string; password: string; }` + val typeDecl = astForTypeAlias(nodeInfo) + val tpe = typeDecl.root.collect { case t: NewTypeDecl => t.fullName }.getOrElse( + typeFor(nodeInfo) + ) + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + Ast.storeInDiffGraph(typeDecl, diffGraph) + scope.addVariable(paramName, param, MethodScope) + + additionalBlockStatements.addAll(nodeInfo.json("properties").arr.toList.map { + element => + val elementNodeInfo = createBabelNodeInfo(element) + elementNodeInfo.node match + case ObjectProperty => + 
val elemName = code(elementNodeInfo.json("key")) + val tpe = typeFor(elementNodeInfo) + val localParamNode = identifierNode(elementNodeInfo, elemName) + localParamNode.typeFullName = tpe + + val localNode = newLocalNode(elemName, tpe).order(0) + diffGraph.addEdge( + localAstParentStack.head, + localNode, + EdgeTypes.AST + ) + scope.addVariable(elemName, localNode, MethodScope) + + val paramNode = identifierNode(elementNodeInfo, paramName) + scope.addVariableReference(paramName, paramNode) + + val keyNode = + createFieldIdentifierNode( + elemName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val accessAst = + createFieldAccessCallAst( + paramNode, + keyNode, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val assignmentCallAst = createAssignmentCallAst( + Ast(localParamNode), + accessAst, + s"$elemName = ${codeOf(accessAst.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + // Handle identifiers referring to locals created by destructured parameters + assignmentCallAst.nodes + .collect { + case i: NewIdentifier if localNode.name == i.name => i + } + .map { i => assignmentCallAst.withRefEdge(i, localNode) } + .reduce(_ merge _) + case RestElement => + handleRestInParameters(elementNodeInfo, nodeInfo, paramName) + case _ => astForNodeWithFunctionReference(elementNodeInfo.json) + end match + }) + param + case Identifier => + // Handle types declared as `credentials: { username: string; password: string; }` + val tpe = + Try(createBabelNodeInfo(nodeInfo.json("typeAnnotation")("typeAnnotation"))) + .map(x => + x.node match + case TSTypeLiteral => + val typeDecl = astForTypeAlias(x) + Ast.storeInDiffGraph(typeDecl, diffGraph) + typeDecl.root.collect { case t: NewTypeDecl => + t.fullName + }.getOrElse(typeFor(nodeInfo)) + case _ => typeFor(nodeInfo) + ) + .getOrElse(typeFor(nodeInfo)) + + val name = nodeInfo.json("name").str + val node = + parameterInNode( + nodeInfo, + name, + nodeInfo.code, + index, + 
false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + scope.addVariable(name, node, MethodScope) + node + case TSParameterProperty => + val unpackedParam = createBabelNodeInfo(nodeInfo.json("parameter")) + val tpe = typeFor(unpackedParam) + + val name = unpackedParam.node match + case AssignmentPattern => + createBabelNodeInfo(unpackedParam.json("left")).code + case _ => unpackedParam.json("name").str + val node = + parameterInNode( + nodeInfo, + name, + nodeInfo.code, + index, + false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + scope.addVariable(name, node, MethodScope) + node + case _ => + val tpe = typeFor(nodeInfo) + val node = + parameterInNode( + nodeInfo, + nodeInfo.code, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + scope.addVariable(nodeInfo.code, node, MethodScope) + node + val decoratorAsts = astsForDecorators(nodeInfo) + decoratorAsts.foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(paramNode, _, EdgeTypes.AST)) } - }) - param - case Identifier => - // Handle types declared as `credentials: { username: string; password: string; }` - val tpe = Try(createBabelNodeInfo(nodeInfo.json("typeAnnotation")("typeAnnotation"))) - .map(x => - x.node match { - case TSTypeLiteral => - val typeDecl = astForTypeAlias(x) - Ast.storeInDiffGraph(typeDecl, diffGraph) - typeDecl.root.collect { case t: NewTypeDecl => t.fullName }.getOrElse(typeFor(nodeInfo)) - case _ => typeFor(nodeInfo) - } + paramNode + } + + private def convertParamWithDefault(element: BabelNodeInfo): Ast = + val lhsElement = element.json("left") + val rhsElement = element.json("right") + + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsAst = astForNode(lhsElement) + + val testAst = + val keyNode = identifierNode(element, codeOf(lhsAst.nodes.head)) + val voidCallNode = + callNode(element, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) + val 
equalsCallAst = createEqualsCallAst( + Ast(keyNode), + Ast(voidCallNode), + element.lineNumber, + element.columnNumber ) - .getOrElse(typeFor(nodeInfo)) - - val name = nodeInfo.json("name").str - val node = - parameterInNode(nodeInfo, name, nodeInfo.code, index, false, EvaluationStrategies.BY_VALUE, Option(tpe)) - scope.addVariable(name, node, MethodScope) - node - case TSParameterProperty => - val unpackedParam = createBabelNodeInfo(nodeInfo.json("parameter")) - val tpe = typeFor(unpackedParam) - - val name = unpackedParam.node match { - case AssignmentPattern => createBabelNodeInfo(unpackedParam.json("left")).code - case _ => unpackedParam.json("name").str - } - val node = - parameterInNode(nodeInfo, name, nodeInfo.code, index, false, EvaluationStrategies.BY_VALUE, Option(tpe)) - scope.addVariable(name, node, MethodScope) - node - case _ => - val tpe = typeFor(nodeInfo) - val node = - parameterInNode( - nodeInfo, - nodeInfo.code, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE, - Option(tpe) + equalsCallAst + val falseNode = identifierNode(element, codeOf(lhsAst.nodes.head)) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + Ast(falseNode), + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertParamWithDefault + + private def getParentTypeDecl: NewTypeDecl = + methodAstParentStack.collectFirst { case n: NewTypeDecl => n }.getOrElse(rootTypeDecl.head) + + protected def astForTSDeclareFunction(func: BabelNodeInfo): Ast = + val functionNode = createMethodDefinitionNode(func) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(getParentTypeDecl, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) + addModifier(functionNode, func.json) + Ast(functionNode) + + protected def 
createMethodDefinitionNode( + func: BabelNodeInfo, + methodBlockContent: List[Ast] = List.empty + ): NewMethod = + val (methodName, methodFullName) = calcMethodNameAndFullName(func) + val methodNode_ = + methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) + val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) + methodAstParentStack.push(methodNode_) + + val thisNode = + parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisNode, MethodScope) + + val paramNodes = if hasKey(func.json, "parameters") then + handleParameters( + func.json("parameters").arr.toSeq, + mutable.ArrayBuffer.empty[Ast], + createLocals = false + ) + else + handleParameters( + func.json("params").arr.toSeq, + mutable.ArrayBuffer.empty[Ast], + createLocals = false ) - scope.addVariable(nodeInfo.code, node, MethodScope) - node - } - val decoratorAsts = astsForDecorators(nodeInfo) - decoratorAsts.foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(paramNode, _, EdgeTypes.AST)) - } - paramNode - } - - private def convertParamWithDefault(element: BabelNodeInfo): Ast = { - val lhsElement = element.json("left") - val rhsElement = element.json("right") - - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsAst = astForNode(lhsElement) - - val testAst = { - val keyNode = identifierNode(element, codeOf(lhsAst.nodes.head)) - val voidCallNode = - callNode(element, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) - val equalsCallAst = createEqualsCallAst(Ast(keyNode), Ast(voidCallNode), element.lineNumber, element.columnNumber) - equalsCallAst - } - val falseNode = identifierNode(element, codeOf(lhsAst.nodes.head)) - val ternaryNodeAst = - createTernaryCallAst(testAst, rhsAst, Ast(falseNode), element.lineNumber, element.columnNumber) - 
createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", - element.lineNumber, - element.columnNumber - ) - } - - private def getParentTypeDecl: NewTypeDecl = - methodAstParentStack.collectFirst { case n: NewTypeDecl => n }.getOrElse(rootTypeDecl.head) - - protected def astForTSDeclareFunction(func: BabelNodeInfo): Ast = { - val functionNode = createMethodDefinitionNode(func) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(getParentTypeDecl, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) - addModifier(functionNode, func.json) - Ast(functionNode) - } - - protected def createMethodDefinitionNode( - func: BabelNodeInfo, - methodBlockContent: List[Ast] = List.empty - ): NewMethod = { - val (methodName, methodFullName) = calcMethodNameAndFullName(func) - val methodNode_ = methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) - val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) - methodAstParentStack.push(methodNode_) - - val thisNode = - parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisNode, MethodScope) - - val paramNodes = if (hasKey(func.json, "parameters")) { - handleParameters(func.json("parameters").arr.toSeq, mutable.ArrayBuffer.empty[Ast], createLocals = false) - } else { - handleParameters(func.json("params").arr.toSeq, mutable.ArrayBuffer.empty[Ast], createLocals = false) - } - - val methodReturnNode = createMethodReturnNode(func) - - methodAstParentStack.pop() - - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst( - func, - methodNode_, - methodAstParentStack.head, - methodName, - methodFullName, - parserResult.filename - ) - - val mAst = if (methodBlockContent.isEmpty) { - methodStubAst(methodNode_, thisNode +: paramNodes, 
methodReturnNode, List(virtualModifierNode)) - } else { - setArgumentIndices(methodBlockContent) - val bodyAst = blockAst(NewBlock(), methodBlockContent) - methodAstWithAnnotations( - methodNode_, - (thisNode +: paramNodes).map(Ast(_)), - bodyAst, - methodReturnNode, - annotations = astsForDecorators(func) - ) - } - - Ast.storeInDiffGraph(mAst, diffGraph) - Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) - diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) - - methodNode_ - } - - protected def createMethodAstAndNode( - func: BabelNodeInfo, - shouldCreateFunctionReference: Boolean = false, - shouldCreateAssignmentCall: Boolean = false, - methodBlockContent: List[Ast] = List.empty - ): MethodAst = { - val (methodName, methodFullName) = calcMethodNameAndFullName(func) - val methodRefNode_ = if (!shouldCreateFunctionReference) { - None - } else { Option(methodRefNode(func, methodName, methodFullName, methodFullName)) } - - val callAst = if (shouldCreateAssignmentCall && shouldCreateFunctionReference) { - val idNode = identifierNode(func, methodName) - val idLocal = newLocalNode(methodName, methodFullName).order(0) - diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) - scope.addVariable(methodName, idLocal, BlockScope) - scope.addVariableReference(methodName, idNode) - val code = s"function $methodName = ${func.code}" - val assignment = createAssignmentCallAst(idNode, methodRefNode_.get, code, func.lineNumber, func.columnNumber) - assignment - } else { - Ast() - } - - val methodNode_ = methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) - val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) - - methodAstParentStack.push(methodNode_) - - val bodyJson = func.json("body") - val bodyNodeInfo = createBabelNodeInfo(bodyJson) - val blockNode = createBlockNode(bodyNodeInfo) - val additionalBlockStatements = mutable.ArrayBuffer.empty[Ast] - - val capturingRefNode = - if 
(shouldCreateFunctionReference) { - methodRefNode_ - } else { - typeRefIdStack.headOption - } - scope.pushNewMethodScope(methodFullName, methodName, blockNode, capturingRefNode) - localAstParentStack.push(blockNode) - - val thisNode = - parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisNode, MethodScope) - - val paramNodes = handleParameters(func.json("params").arr.toSeq, additionalBlockStatements) - - val bodyStmtAsts = func.node match { - case ArrowFunctionExpression => - bodyNodeInfo.node match { - case BlockStatement => - // when body contains more than one statement, use bodyJson("body")) to avoid double Block node - createBlockStatementAsts(bodyJson("body")) - case _ => - // when body is just one expression like const foo = () => 42, generate a Return node - val retCode = bodyNodeInfo.code.stripSuffix(";") - returnAst(returnNode(bodyNodeInfo, retCode), List(astForNodeWithFunctionReference(bodyJson))) :: Nil - } - case _ => createBlockStatementAsts(bodyJson("body")) - } - val methodBlockChildren = methodBlockContent ++ additionalBlockStatements.toList ++ bodyStmtAsts - setArgumentIndices(methodBlockChildren) - - val methodReturnNode = createMethodReturnNode(func) - - localAstParentStack.pop() - scope.popScope() - methodAstParentStack.pop() - - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst( - func, - methodNode_, - methodAstParentStack.head, - methodName, - methodFullName, - parserResult.filename - ) - - val mAst = - methodAstWithAnnotations( - methodNode_, - (thisNode +: paramNodes).map(Ast(_)), - blockAst(blockNode, methodBlockChildren), - methodReturnNode, - List(virtualModifierNode), - astsForDecorators(func) - ) - Ast.storeInDiffGraph(mAst, diffGraph) - Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) - diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) - - methodRefNode_ match { - case 
Some(ref) if callAst.nodes.isEmpty => - MethodAst(Ast(ref), methodNode_, mAst) - case _ => - MethodAst(callAst, methodNode_, mAst) - } - } - - protected def astForFunctionDeclaration( - func: BabelNodeInfo, - shouldCreateFunctionReference: Boolean = false, - shouldCreateAssignmentCall: Boolean = false - ): Ast = createMethodAstAndNode(func, shouldCreateFunctionReference, shouldCreateAssignmentCall).ast -} + val methodReturnNode = createMethodReturnNode(func) + + methodAstParentStack.pop() + + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + func, + methodNode_, + methodAstParentStack.head, + methodName, + methodFullName, + parserResult.filename + ) + + val mAst = if methodBlockContent.isEmpty then + methodStubAst( + methodNode_, + thisNode +: paramNodes, + methodReturnNode, + List(virtualModifierNode) + ) + else + setArgumentIndices(methodBlockContent) + val bodyAst = blockAst(NewBlock(), methodBlockContent) + methodAstWithAnnotations( + methodNode_, + (thisNode +: paramNodes).map(Ast(_)), + bodyAst, + methodReturnNode, + annotations = astsForDecorators(func) + ) + + Ast.storeInDiffGraph(mAst, diffGraph) + Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) + diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) + + methodNode_ + end createMethodDefinitionNode + + protected def createMethodAstAndNode( + func: BabelNodeInfo, + shouldCreateFunctionReference: Boolean = false, + shouldCreateAssignmentCall: Boolean = false, + methodBlockContent: List[Ast] = List.empty + ): MethodAst = + val (methodName, methodFullName) = calcMethodNameAndFullName(func) + val methodRefNode_ = if !shouldCreateFunctionReference then + None + else Option(methodRefNode(func, methodName, methodFullName, methodFullName)) + + val callAst = if shouldCreateAssignmentCall && shouldCreateFunctionReference then + val idNode = identifierNode(func, methodName) + val idLocal = newLocalNode(methodName, methodFullName).order(0) + 
diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) + scope.addVariable(methodName, idLocal, BlockScope) + scope.addVariableReference(methodName, idNode) + val code = s"function $methodName = ${func.code}" + val assignment = createAssignmentCallAst( + idNode, + methodRefNode_.get, + code, + func.lineNumber, + func.columnNumber + ) + assignment + else + Ast() + + val methodNode_ = + methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) + val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) + + methodAstParentStack.push(methodNode_) + + val bodyJson = func.json("body") + val bodyNodeInfo = createBabelNodeInfo(bodyJson) + val blockNode = createBlockNode(bodyNodeInfo) + val additionalBlockStatements = mutable.ArrayBuffer.empty[Ast] + + val capturingRefNode = + if shouldCreateFunctionReference then + methodRefNode_ + else + typeRefIdStack.headOption + scope.pushNewMethodScope(methodFullName, methodName, blockNode, capturingRefNode) + localAstParentStack.push(blockNode) + + val thisNode = + parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisNode, MethodScope) + + val paramNodes = handleParameters(func.json("params").arr.toSeq, additionalBlockStatements) + + val bodyStmtAsts = func.node match + case ArrowFunctionExpression => + bodyNodeInfo.node match + case BlockStatement => + // when body contains more than one statement, use bodyJson("body")) to avoid double Block node + createBlockStatementAsts(bodyJson("body")) + case _ => + // when body is just one expression like const foo = () => 42, generate a Return node + val retCode = bodyNodeInfo.code.stripSuffix(";") + returnAst( + returnNode(bodyNodeInfo, retCode), + List(astForNodeWithFunctionReference(bodyJson)) + ) :: Nil + case _ => createBlockStatementAsts(bodyJson("body")) + val methodBlockChildren = + methodBlockContent ++ 
additionalBlockStatements.toList ++ bodyStmtAsts + setArgumentIndices(methodBlockChildren) + + val methodReturnNode = createMethodReturnNode(func) + + localAstParentStack.pop() + scope.popScope() + methodAstParentStack.pop() + + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + func, + methodNode_, + methodAstParentStack.head, + methodName, + methodFullName, + parserResult.filename + ) + + val mAst = + methodAstWithAnnotations( + methodNode_, + (thisNode +: paramNodes).map(Ast(_)), + blockAst(blockNode, methodBlockChildren), + methodReturnNode, + List(virtualModifierNode), + astsForDecorators(func) + ) + Ast.storeInDiffGraph(mAst, diffGraph) + Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) + diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) + + methodRefNode_ match + case Some(ref) if callAst.nodes.isEmpty => + MethodAst(Ast(ref), methodNode_, mAst) + case _ => + MethodAst(callAst, methodNode_, mAst) + end createMethodAstAndNode + + protected def astForFunctionDeclaration( + func: BabelNodeInfo, + shouldCreateFunctionReference: Boolean = false, + shouldCreateAssignmentCall: Boolean = false + ): Ast = + createMethodAstAndNode(func, shouldCreateFunctionReference, shouldCreateAssignmentCall).ast +end AstForFunctionsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala index 739923a5..3edb9e8b 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala @@ -5,89 +5,102 @@ import io.appthreat.jssrc2cpg.passes.Defines import io.appthreat.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} -trait 
AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - protected def astForIdentifier(ident: BabelNodeInfo, typeFullName: Option[String] = None): Ast = { - val name = ident.json("name").str - val identNode = identifierNode(ident, name) - val tpe = typeFullName match { - case Some(Defines.Any) => typeFor(ident) - case Some(otherType) => otherType - case None => typeFor(ident) - } - identNode.typeFullName = tpe - scope.addVariableReference(name, identNode) - Ast(identNode) - } - - protected def astForSuperKeyword(superKeyword: BabelNodeInfo): Ast = - Ast(identifierNode(superKeyword, "super")) - - protected def astForImportKeyword(importKeyword: BabelNodeInfo): Ast = - Ast(identifierNode(importKeyword, "import")) - - protected def astForNullLiteral(nullLiteral: BabelNodeInfo): Ast = - Ast(literalNode(nullLiteral, nullLiteral.code, Option(Defines.Null))) - - protected def astForStringLiteral(stringLiteral: BabelNodeInfo): Ast = { - val code = s"\"${stringLiteral.json("value").str}\"" - Ast(literalNode(stringLiteral, code, Option(Defines.String))) - } - - protected def astForPrivateName(privateName: BabelNodeInfo): Ast = - astForIdentifier(createBabelNodeInfo(privateName.json("id"))) - - protected def astForSpreadOrRestElement(spreadElement: BabelNodeInfo, arg1Ast: Option[Ast] = None): Ast = { - val ast = astForNodeWithFunctionReference(spreadElement.json("argument")) - val callNode_ = - callNode(spreadElement, spreadElement.code, ".spread", DispatchTypes.STATIC_DISPATCH) - callAst(callNode_, arg1Ast.toList :+ ast) - } - - protected def astForTemplateElement(templateElement: BabelNodeInfo): Ast = - Ast(literalNode(templateElement, s"\"${templateElement.json("value")("raw").str}\"", Option(Defines.String))) - - protected def astForRegExpLiteral(regExpLiteral: BabelNodeInfo): Ast = - Ast(literalNode(regExpLiteral, regExpLiteral.code, Option(Defines.String))) - - protected def astForRegexLiteral(regexLiteral: BabelNodeInfo): Ast = 
- Ast(literalNode(regexLiteral, regexLiteral.code, Option(Defines.String))) - - protected def astForNumberLiteral(numberLiteral: BabelNodeInfo): Ast = - Ast(literalNode(numberLiteral, numberLiteral.code, Option(Defines.Number))) - - protected def astForNumericLiteral(numericLiteral: BabelNodeInfo): Ast = - Ast(literalNode(numericLiteral, numericLiteral.code, Option(Defines.Number))) - - protected def astForDecimalLiteral(decimalLiteral: BabelNodeInfo): Ast = - Ast(literalNode(decimalLiteral, decimalLiteral.code, Option(Defines.Number))) - - protected def astForBigIntLiteral(bigIntLiteral: BabelNodeInfo): Ast = - Ast(literalNode(bigIntLiteral, bigIntLiteral.code, Option(Defines.Number))) - - protected def astForBooleanLiteral(booleanLiteral: BabelNodeInfo): Ast = - Ast(literalNode(booleanLiteral, booleanLiteral.code, Option(Defines.Boolean))) - - protected def astForTemplateLiteral(templateLiteral: BabelNodeInfo): Ast = { - val expressions = templateLiteral.json("expressions").arr.toList - val quasis = templateLiteral.json("quasis").arr.toList.filterNot(_("tail").bool) - val quasisTail = templateLiteral.json("quasis").arr.toList.filter(_("tail").bool).head - - if (expressions.isEmpty && quasis.isEmpty) { - astForTemplateElement(createBabelNodeInfo(quasisTail)) - } else { - val callName = Operators.formatString - val argsCodes = expressions.zip(quasis).flatMap { case (expression, quasi) => - List(s"\"${quasi("value")("raw").str}\"", code(expression)) - } - val callCode = s"$callName${(argsCodes :+ s"\"${quasisTail("value")("raw").str}\"").mkString("(", ", ", ")")}" - val templateCall = - callNode(templateLiteral, callCode, callName, DispatchTypes.STATIC_DISPATCH) - val argumentAsts = expressions.zip(quasis).flatMap { case (expression, quasi) => - List(astForNodeWithFunctionReference(quasi), astForNodeWithFunctionReference(expression)) - } - val argAsts = argumentAsts :+ astForNodeWithFunctionReference(quasisTail) - callAst(templateCall, argAsts) - } - } -} +trait 
AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForIdentifier(ident: BabelNodeInfo, typeFullName: Option[String] = None): Ast = + val name = ident.json("name").str + val identNode = identifierNode(ident, name) + val tpe = typeFullName match + case Some(Defines.Any) => typeFor(ident) + case Some(otherType) => otherType + case None => typeFor(ident) + identNode.typeFullName = tpe + scope.addVariableReference(name, identNode) + Ast(identNode) + + protected def astForSuperKeyword(superKeyword: BabelNodeInfo): Ast = + Ast(identifierNode(superKeyword, "super")) + + protected def astForImportKeyword(importKeyword: BabelNodeInfo): Ast = + Ast(identifierNode(importKeyword, "import")) + + protected def astForNullLiteral(nullLiteral: BabelNodeInfo): Ast = + Ast(literalNode(nullLiteral, nullLiteral.code, Option(Defines.Null))) + + protected def astForStringLiteral(stringLiteral: BabelNodeInfo): Ast = + val code = s"\"${stringLiteral.json("value").str}\"" + Ast(literalNode(stringLiteral, code, Option(Defines.String))) + + protected def astForPrivateName(privateName: BabelNodeInfo): Ast = + astForIdentifier(createBabelNodeInfo(privateName.json("id"))) + + protected def astForSpreadOrRestElement( + spreadElement: BabelNodeInfo, + arg1Ast: Option[Ast] = None + ): Ast = + val ast = astForNodeWithFunctionReference(spreadElement.json("argument")) + val callNode_ = + callNode( + spreadElement, + spreadElement.code, + ".spread", + DispatchTypes.STATIC_DISPATCH + ) + callAst(callNode_, arg1Ast.toList :+ ast) + + protected def astForTemplateElement(templateElement: BabelNodeInfo): Ast = + Ast(literalNode( + templateElement, + s"\"${templateElement.json("value")("raw").str}\"", + Option(Defines.String) + )) + + protected def astForRegExpLiteral(regExpLiteral: BabelNodeInfo): Ast = + Ast(literalNode(regExpLiteral, regExpLiteral.code, Option(Defines.String))) + + protected def astForRegexLiteral(regexLiteral: BabelNodeInfo): 
Ast = + Ast(literalNode(regexLiteral, regexLiteral.code, Option(Defines.String))) + + protected def astForNumberLiteral(numberLiteral: BabelNodeInfo): Ast = + Ast(literalNode(numberLiteral, numberLiteral.code, Option(Defines.Number))) + + protected def astForNumericLiteral(numericLiteral: BabelNodeInfo): Ast = + Ast(literalNode(numericLiteral, numericLiteral.code, Option(Defines.Number))) + + protected def astForDecimalLiteral(decimalLiteral: BabelNodeInfo): Ast = + Ast(literalNode(decimalLiteral, decimalLiteral.code, Option(Defines.Number))) + + protected def astForBigIntLiteral(bigIntLiteral: BabelNodeInfo): Ast = + Ast(literalNode(bigIntLiteral, bigIntLiteral.code, Option(Defines.Number))) + + protected def astForBooleanLiteral(booleanLiteral: BabelNodeInfo): Ast = + Ast(literalNode(booleanLiteral, booleanLiteral.code, Option(Defines.Boolean))) + + protected def astForTemplateLiteral(templateLiteral: BabelNodeInfo): Ast = + val expressions = templateLiteral.json("expressions").arr.toList + val quasis = templateLiteral.json("quasis").arr.toList.filterNot(_("tail").bool) + val quasisTail = templateLiteral.json("quasis").arr.toList.filter(_("tail").bool).head + + if expressions.isEmpty && quasis.isEmpty then + astForTemplateElement(createBabelNodeInfo(quasisTail)) + else + val callName = Operators.formatString + val argsCodes = expressions.zip(quasis).flatMap { case (expression, quasi) => + List(s"\"${quasi("value")("raw").str}\"", code(expression)) + } + val callCode = + s"$callName${(argsCodes :+ s"\"${quasisTail("value")("raw").str}\"").mkString("(", ", ", ")")}" + val templateCall = + callNode(templateLiteral, callCode, callName, DispatchTypes.STATIC_DISPATCH) + val argumentAsts = expressions.zip(quasis).flatMap { case (expression, quasi) => + List( + astForNodeWithFunctionReference(quasi), + astForNodeWithFunctionReference(expression) + ) + } + val argAsts = argumentAsts :+ astForNodeWithFunctionReference(quasisTail) + callAst(templateCall, argAsts) + end if + 
end astForTemplateLiteral +end AstForPrimitivesCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala index 64489b03..6c403305 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala @@ -15,905 +15,1066 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewJumpTarget import ujson.Obj import ujson.Value -trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - /** Sort all block statements with the following result: - * - all function declarations go first - * - all type aliases that are not plain type references go last - * - all remaining type aliases go before that - * - all remaining statements go second - * - * We do this to get TypeDecls created at the right spot so we can make use of them for the type aliases. 
- */ - private def sortBlockStatements(blockStatements: List[BabelNodeInfo]): List[BabelNodeInfo] = - blockStatements.sortBy { nodeInfo => - nodeInfo.node match { - case ImportDeclaration => 0 - case FunctionDeclaration => 1 - case DeclareTypeAlias if isPlainTypeAlias(nodeInfo) => 4 - case TypeAlias if isPlainTypeAlias(nodeInfo) => 4 - case TSTypeAliasDeclaration if isPlainTypeAlias(nodeInfo) => 4 - case DeclareTypeAlias => 3 - case TypeAlias => 3 - case TSTypeAliasDeclaration => 3 - case _ => 2 - } - } - - protected def createBlockStatementAsts(json: Value): List[Ast] = { - val blockStmts = sortBlockStatements(json.arr.map(createBabelNodeInfo).toList) - val blockAsts = blockStmts.map(stmt => astForNodeWithFunctionReferenceAndCall(stmt.json)) - setArgumentIndices(blockAsts) - blockAsts - } - - protected def astForWithStatement(withStatement: BabelNodeInfo): Ast = { - val blockNode = createBlockNode(withStatement) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val objectAst = astForNodeWithFunctionReferenceAndCall(withStatement.json("object")) - val bodyNodeInfo = createBabelNodeInfo(withStatement.json("body")) - val bodyAsts = bodyNodeInfo.node match { - case BlockStatement => createBlockStatementAsts(bodyNodeInfo.json("body")) - case _ => List(astForNodeWithFunctionReferenceAndCall(bodyNodeInfo.json)) - } - val blockStatementAsts = objectAst +: bodyAsts - setArgumentIndices(blockStatementAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, blockStatementAsts) - } - - protected def astForBlockStatement(block: BabelNodeInfo): Ast = { - val blockNode = createBlockNode(block) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val blockStatementAsts = createBlockStatementAsts(block.json("body")) - setArgumentIndices(blockStatementAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, blockStatementAsts) - } - - protected def astForReturnStatement(ret: BabelNodeInfo): Ast = 
{ - val retCode = ret.code.stripSuffix(";") - val retNode = returnNode(ret, retCode) - safeObj(ret.json, "argument") - .map { argument => - val argAst = astForNodeWithFunctionReference(Obj(argument)) - returnAst(retNode, List(argAst)) - } - .getOrElse(Ast(retNode)) - } - - private def astForCatchClause(catchClause: BabelNodeInfo): Ast = - astForNodeWithFunctionReference(catchClause.json("body")) - - protected def astForTryStatement(tryStmt: BabelNodeInfo): Ast = { - val tryNode = createControlStructureNode(tryStmt, ControlStructureTypes.TRY) - val bodyAst = astForNodeWithFunctionReference(tryStmt.json("block")) - val catchAst = safeObj(tryStmt.json, "handler") - .map { handler => - astForCatchClause(createBabelNodeInfo(Obj(handler))) - } - .getOrElse(Ast()) - val finalizerAst = safeObj(tryStmt.json, "finalizer") - .map { finalizer => - astForNodeWithFunctionReference(Obj(finalizer)) - } - .getOrElse(Ast()) - // The semantics of try statement children is defined by there order value. - // Thus we set the here explicitly and do not rely on the usual consecutive - // ordering. - setOrderExplicitly(bodyAst, 1) - setOrderExplicitly(catchAst, 2) - setOrderExplicitly(finalizerAst, 3) - Ast(tryNode).withChildren(List(bodyAst, catchAst, finalizerAst)) - } - - def astForIfStatement(ifStmt: BabelNodeInfo): Ast = { - val ifNode = createControlStructureNode(ifStmt, ControlStructureTypes.IF) - val testAst = astForNodeWithFunctionReference(ifStmt.json("test")) - val consequentAst = astForNodeWithFunctionReference(ifStmt.json("consequent")) - val alternateAst = safeObj(ifStmt.json, "alternate") - .map { alternate => - astForNodeWithFunctionReference(Obj(alternate)) - } - .getOrElse(Ast()) - // The semantics of if statement children is partially defined by there order value. - // The consequentAst must have order == 2 and alternateAst must have order == 3. - // Only to avoid collision we set testAst to 1 - // because the semantics of it is already indicated via the condition edge. 
- setOrderExplicitly(testAst, 1) - setOrderExplicitly(consequentAst, 2) - setOrderExplicitly(alternateAst, 3) - Ast(ifNode) - .withChild(testAst) - .withConditionEdge(ifNode, testAst.nodes.head) - .withChild(consequentAst) - .withChild(alternateAst) - } - - protected def astForDoWhileStatement(doWhileStmt: BabelNodeInfo): Ast = { - val whileNode = createControlStructureNode(doWhileStmt, ControlStructureTypes.DO) - val testAst = astForNodeWithFunctionReference(doWhileStmt.json("test")) - val bodyAst = astForNodeWithFunctionReference(doWhileStmt.json("body")) - // The semantics of do-while statement children is partially defined by there order value. - // The bodyAst must have order == 1. Only to avoid collision we set testAst to 2 - // because the semantics of it is already indicated via the condition edge. - setOrderExplicitly(bodyAst, 1) - setOrderExplicitly(testAst, 2) - Ast(whileNode).withChild(bodyAst).withChild(testAst).withConditionEdge(whileNode, testAst.nodes.head) - } - - protected def astForWhileStatement(whileStmt: BabelNodeInfo): Ast = { - val whileNode = createControlStructureNode(whileStmt, ControlStructureTypes.WHILE) - val testAst = astForNodeWithFunctionReference(whileStmt.json("test")) - val bodyAst = astForNodeWithFunctionReference(whileStmt.json("body")) - // The semantics of while statement children is partially defined by there order value. - // The bodyAst must have order == 2. Only to avoid collision we set testAst to 1 - // because the semantics of it is already indicated via the condition edge. 
- setOrderExplicitly(testAst, 1) - setOrderExplicitly(bodyAst, 2) - Ast(whileNode).withChild(testAst).withConditionEdge(whileNode, testAst.nodes.head).withChild(bodyAst) - } - - protected def astForForStatement(forStmt: BabelNodeInfo): Ast = { - val forNode = createControlStructureNode(forStmt, ControlStructureTypes.FOR) - val initAst = safeObj(forStmt.json, "init") - .map { init => - astForNodeWithFunctionReference(Obj(init)) - } - .getOrElse(Ast()) - val testAst = safeObj(forStmt.json, "test") - .map { test => - astForNodeWithFunctionReference(Obj(test)) - } - .getOrElse(Ast(literalNode(forStmt, "true", Option(Defines.Boolean)))) - val updateAst = safeObj(forStmt.json, "update") - .map { update => - astForNodeWithFunctionReference(Obj(update)) - } - .getOrElse(Ast()) - val bodyAst = astForNodeWithFunctionReference(forStmt.json("body")) - - // The semantics of for statement children is defined by there order value. - // Thus we set the here explicitly and do not rely on the usual consecutive - // ordering. 
- setOrderExplicitly(initAst, 1) - setOrderExplicitly(testAst, 2) - setOrderExplicitly(updateAst, 3) - setOrderExplicitly(bodyAst, 4) - Ast(forNode).withChild(initAst).withChild(testAst).withChild(updateAst).withChild(bodyAst) - } - - protected def astForLabeledStatement(labelStmt: BabelNodeInfo): Ast = { - val labelName = code(labelStmt.json("label")) - val labeledNode = NewJumpTarget() - .parserTypeName(labelStmt.node.toString) - .name(labelName) - .code(s"$labelName:") - .lineNumber(labelStmt.lineNumber) - .columnNumber(labelStmt.columnNumber) - - val blockNode = createBlockNode(labelStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val bodyAst = astForNodeWithFunctionReference(labelStmt.json("body")) - scope.popScope() - localAstParentStack.pop() - - val labelAsts = List(Ast(labeledNode), bodyAst) - setArgumentIndices(labelAsts) - blockAst(blockNode, labelAsts) - } - - protected def astForBreakStatement(breakStmt: BabelNodeInfo): Ast = { - val labelAst = safeObj(breakStmt.json, "label") - .map { label => - val labelNode = Obj(label) - val labelCode = code(labelNode) - Ast( - NewJumpLabel() - .parserTypeName(breakStmt.node.toString) - .name(labelCode) - .code(labelCode) - .lineNumber(breakStmt.lineNumber) - .columnNumber(breakStmt.columnNumber) - .order(1) - ) - } - .getOrElse(Ast()) - Ast(createControlStructureNode(breakStmt, ControlStructureTypes.BREAK)).withChild(labelAst) - } - - protected def astForContinueStatement(continueStmt: BabelNodeInfo): Ast = { - val labelAst = safeObj(continueStmt.json, "label") - .map { label => - val labelNode = Obj(label) - val labelCode = code(labelNode) - Ast( - NewJumpLabel() - .parserTypeName(continueStmt.node.toString) - .name(labelCode) - .code(labelCode) - .lineNumber(continueStmt.lineNumber) - .columnNumber(continueStmt.columnNumber) - .order(1) - ) - } - .getOrElse(Ast()) - Ast(createControlStructureNode(continueStmt, ControlStructureTypes.CONTINUE)).withChild(labelAst) - } - - 
protected def astForThrowStatement(throwStmt: BabelNodeInfo): Ast = { - val argumentAst = astForNodeWithFunctionReference(throwStmt.json("argument")) - val throwCallNode = - callNode(throwStmt, throwStmt.code, ".throw", DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(throwCallNode, argAsts) - } - - private def astsForSwitchCase(switchCase: BabelNodeInfo): List[Ast] = { - val labelAst = Ast(createJumpTarget(switchCase)) - val testAsts = safeObj(switchCase.json, "test").map(t => astForNodeWithFunctionReference(Obj(t))).toList - val consequentAsts = astForNodes(switchCase.json("consequent").arr.toList) - labelAst +: (testAsts ++ consequentAsts) - } - - protected def astForSwitchStatement(switchStmt: BabelNodeInfo): Ast = { - val switchNode = createControlStructureNode(switchStmt, ControlStructureTypes.SWITCH) - - // The semantics of switch statement children is partially defined by there order value. - // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 - // because the semantics of it is already indicated via the condition edge. 
- val switchExpressionAst = astForNodeWithFunctionReference(switchStmt.json("discriminant")) - setOrderExplicitly(switchExpressionAst, 1) - - val blockNode = createBlockNode(switchStmt).order(2) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val casesAsts = switchStmt.json("cases").arr.flatMap(c => astsForSwitchCase(createBabelNodeInfo(c))) - setArgumentIndices(casesAsts.toList) - - scope.popScope() - localAstParentStack.pop() - - Ast(switchNode) - .withChild(switchExpressionAst) - .withConditionEdge(switchNode, switchExpressionAst.nodes.head) - .withChild(blockAst(blockNode, casesAsts.toList)) - } - - /** De-sugaring from: - * - * for (var i in/of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; var i; while (!(_result = _iterator.next()).done) { i = - * _result.value; body } } - */ - private def astForInOfStatementWithIdentifier(forInOfStmt: BabelNodeInfo, idNodeInfo: BabelNodeInfo): Ast = { - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - - val iteratorCall = - // TODO: add operator to schema - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = 
.iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // loop variable: - val loopVariableName = idNodeInfo.code - - val loopVariableLocalNode = newLocalNode(loopVariableName, Defines.Any).order(0) - val loopVariableNode = identifierNode(forInOfStmt, loopVariableName) - diffGraph.addEdge(localAstParentStack.head, loopVariableLocalNode, EdgeTypes.AST) - scope.addVariableReference(loopVariableName, loopVariableNode) - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode(forInOfStmt, s"!($resultName = $iteratorName.next()).done", Operators.not, DispatchTypes.STATIC_DISPATCH) - - val doneBaseNode = - callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val lhsNode = identifierNode(forInOfStmt, resultName) - - val rhsNode = callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - - val nextMemberNode = createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val nextReceiverNode = - createFieldAccessCallAst(nextBaseNode, nextMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val thisNextNode = identifierNode(forInOfStmt, iteratorName) - - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = 
callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) - - val doneMemberNode = createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testNode = - createFieldAccessCallAst(doneBaseNode, doneMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) - - val whileLoopAst = Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - - // while loop variable assignment: - val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) - - val baseNode = identifierNode(forInOfStmt, resultName) - - val memberNode = createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val accessAst = createFieldAccessCallAst(baseNode, memberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val loopVariableAssignmentNode = callNode( - forInOfStmt, - s"$loopVariableName = $resultName.value", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), accessAst) - val loopVariableAssignmentAst = callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - - scope.popScope() - localAstParentStack.pop() - - // end surrounding block: - scope.popScope() - localAstParentStack.pop() - - val blockChildren = - 
List(iteratorAssignmentAst, Ast(resultNode), Ast(loopVariableNode), whileLoopAst.withChild(whileLoopBlockAst)) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - } - - /** De-sugaring from: - * - * for (expr in/of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; while (!(_result = _iterator.next()).done) { expr = - * _result.value; body } } - */ - private def astForInOfStatementWithExpression(forInOfStmt: BabelNodeInfo, idNodeInfo: BabelNodeInfo): Ast = { - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) +trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + /** Sort all block statements with the following result: + * - all function declarations go first + * - all type aliases that are not plain type references go last + * - all remaining type aliases go before that + * - all remaining statements go second + * + * We do this to get TypeDecls created at the right spot so we can make use of them for the + * type aliases. 
+ */ + private def sortBlockStatements(blockStatements: List[BabelNodeInfo]): List[BabelNodeInfo] = + blockStatements.sortBy { nodeInfo => + nodeInfo.node match + case ImportDeclaration => 0 + case FunctionDeclaration => 1 + case DeclareTypeAlias if isPlainTypeAlias(nodeInfo) => 4 + case TypeAlias if isPlainTypeAlias(nodeInfo) => 4 + case TSTypeAliasDeclaration if isPlainTypeAlias(nodeInfo) => 4 + case DeclareTypeAlias => 3 + case TypeAlias => 3 + case TSTypeAliasDeclaration => 3 + case _ => 2 + } - val collection = forInOfStmt.json("right") - val collectionName = code(collection) + protected def createBlockStatementAsts(json: Value): List[Ast] = + val blockStmts = sortBlockStatements(json.arr.map(createBabelNodeInfo).toList) + val blockAsts = blockStmts.map(stmt => astForNodeWithFunctionReferenceAndCall(stmt.json)) + setArgumentIndices(blockAsts) + blockAsts + + protected def astForWithStatement(withStatement: BabelNodeInfo): Ast = + val blockNode = createBlockNode(withStatement) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val objectAst = astForNodeWithFunctionReferenceAndCall(withStatement.json("object")) + val bodyNodeInfo = createBabelNodeInfo(withStatement.json("body")) + val bodyAsts = bodyNodeInfo.node match + case BlockStatement => createBlockStatementAsts(bodyNodeInfo.json("body")) + case _ => List(astForNodeWithFunctionReferenceAndCall(bodyNodeInfo.json)) + val blockStatementAsts = objectAst +: bodyAsts + setArgumentIndices(blockStatementAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, blockStatementAsts) + + protected def astForBlockStatement(block: BabelNodeInfo): Ast = + val blockNode = createBlockNode(block) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val blockStatementAsts = createBlockStatementAsts(block.json("body")) + setArgumentIndices(blockStatementAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, blockStatementAsts) + + protected 
def astForReturnStatement(ret: BabelNodeInfo): Ast = + val retCode = ret.code.stripSuffix(";") + val retNode = returnNode(ret, retCode) + safeObj(ret.json, "argument") + .map { argument => + val argAst = astForNodeWithFunctionReference(Obj(argument)) + returnAst(retNode, List(argAst)) + } + .getOrElse(Ast(retNode)) + + private def astForCatchClause(catchClause: BabelNodeInfo): Ast = + astForNodeWithFunctionReference(catchClause.json("body")) + + protected def astForTryStatement(tryStmt: BabelNodeInfo): Ast = + val tryNode = createControlStructureNode(tryStmt, ControlStructureTypes.TRY) + val bodyAst = astForNodeWithFunctionReference(tryStmt.json("block")) + val catchAst = safeObj(tryStmt.json, "handler") + .map { handler => + astForCatchClause(createBabelNodeInfo(Obj(handler))) + } + .getOrElse(Ast()) + val finalizerAst = safeObj(tryStmt.json, "finalizer") + .map { finalizer => + astForNodeWithFunctionReference(Obj(finalizer)) + } + .getOrElse(Ast()) + // The semantics of try statement children is defined by there order value. + // Thus we set the here explicitly and do not rely on the usual consecutive + // ordering. + setOrderExplicitly(bodyAst, 1) + setOrderExplicitly(catchAst, 2) + setOrderExplicitly(finalizerAst, 3) + Ast(tryNode).withChildren(List(bodyAst, catchAst, finalizerAst)) + end astForTryStatement + + def astForIfStatement(ifStmt: BabelNodeInfo): Ast = + val ifNode = createControlStructureNode(ifStmt, ControlStructureTypes.IF) + val testAst = astForNodeWithFunctionReference(ifStmt.json("test")) + val consequentAst = astForNodeWithFunctionReference(ifStmt.json("consequent")) + val alternateAst = safeObj(ifStmt.json, "alternate") + .map { alternate => + astForNodeWithFunctionReference(Obj(alternate)) + } + .getOrElse(Ast()) + // The semantics of if statement children is partially defined by there order value. + // The consequentAst must have order == 2 and alternateAst must have order == 3. 
+ // Only to avoid collision we set testAst to 1 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(testAst, 1) + setOrderExplicitly(consequentAst, 2) + setOrderExplicitly(alternateAst, 3) + Ast(ifNode) + .withChild(testAst) + .withConditionEdge(ifNode, testAst.nodes.head) + .withChild(consequentAst) + .withChild(alternateAst) + end astForIfStatement + + protected def astForDoWhileStatement(doWhileStmt: BabelNodeInfo): Ast = + val whileNode = createControlStructureNode(doWhileStmt, ControlStructureTypes.DO) + val testAst = astForNodeWithFunctionReference(doWhileStmt.json("test")) + val bodyAst = astForNodeWithFunctionReference(doWhileStmt.json("body")) + // The semantics of do-while statement children is partially defined by there order value. + // The bodyAst must have order == 1. Only to avoid collision we set testAst to 2 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(bodyAst, 1) + setOrderExplicitly(testAst, 2) + Ast(whileNode).withChild(bodyAst).withChild(testAst).withConditionEdge( + whileNode, + testAst.nodes.head + ) - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) + protected def astForWhileStatement(whileStmt: BabelNodeInfo): Ast = + val whileNode = createControlStructureNode(whileStmt, ControlStructureTypes.WHILE) + val testAst = astForNodeWithFunctionReference(whileStmt.json("test")) + val bodyAst = astForNodeWithFunctionReference(whileStmt.json("body")) + // The semantics of while statement children is partially defined by there order value. + // The bodyAst must have order == 2. 
Only to avoid collision we set testAst to 1 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(testAst, 1) + setOrderExplicitly(bodyAst, 2) + Ast(whileNode).withChild(testAst).withConditionEdge( + whileNode, + testAst.nodes.head + ).withChild(bodyAst) + + protected def astForForStatement(forStmt: BabelNodeInfo): Ast = + val forNode = createControlStructureNode(forStmt, ControlStructureTypes.FOR) + val initAst = safeObj(forStmt.json, "init") + .map { init => + astForNodeWithFunctionReference(Obj(init)) + } + .getOrElse(Ast()) + val testAst = safeObj(forStmt.json, "test") + .map { test => + astForNodeWithFunctionReference(Obj(test)) + } + .getOrElse(Ast(literalNode(forStmt, "true", Option(Defines.Boolean)))) + val updateAst = safeObj(forStmt.json, "update") + .map { update => + astForNodeWithFunctionReference(Obj(update)) + } + .getOrElse(Ast()) + val bodyAst = astForNodeWithFunctionReference(forStmt.json("body")) + + // The semantics of for statement children is defined by there order value. + // Thus we set the here explicitly and do not rely on the usual consecutive + // ordering. 
+ setOrderExplicitly(initAst, 1) + setOrderExplicitly(testAst, 2) + setOrderExplicitly(updateAst, 3) + setOrderExplicitly(bodyAst, 4) + Ast(forNode).withChild(initAst).withChild(testAst).withChild(updateAst).withChild(bodyAst) + end astForForStatement + + protected def astForLabeledStatement(labelStmt: BabelNodeInfo): Ast = + val labelName = code(labelStmt.json("label")) + val labeledNode = NewJumpTarget() + .parserTypeName(labelStmt.node.toString) + .name(labelName) + .code(s"$labelName:") + .lineNumber(labelStmt.lineNumber) + .columnNumber(labelStmt.columnNumber) + + val blockNode = createBlockNode(labelStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val bodyAst = astForNodeWithFunctionReference(labelStmt.json("body")) + scope.popScope() + localAstParentStack.pop() + + val labelAsts = List(Ast(labeledNode), bodyAst) + setArgumentIndices(labelAsts) + blockAst(blockNode, labelAsts) + end astForLabeledStatement + + protected def astForBreakStatement(breakStmt: BabelNodeInfo): Ast = + val labelAst = safeObj(breakStmt.json, "label") + .map { label => + val labelNode = Obj(label) + val labelCode = code(labelNode) + Ast( + NewJumpLabel() + .parserTypeName(breakStmt.node.toString) + .name(labelCode) + .code(labelCode) + .lineNumber(breakStmt.lineNumber) + .columnNumber(breakStmt.columnNumber) + .order(1) + ) + } + .getOrElse(Ast()) + Ast(createControlStructureNode(breakStmt, ControlStructureTypes.BREAK)).withChild(labelAst) + + protected def astForContinueStatement(continueStmt: BabelNodeInfo): Ast = + val labelAst = safeObj(continueStmt.json, "label") + .map { label => + val labelNode = Obj(label) + val labelCode = code(labelNode) + Ast( + NewJumpLabel() + .parserTypeName(continueStmt.node.toString) + .name(labelCode) + .code(labelCode) + .lineNumber(continueStmt.lineNumber) + .columnNumber(continueStmt.columnNumber) + .order(1) + ) + } + .getOrElse(Ast()) + Ast(createControlStructureNode(continueStmt, 
ControlStructureTypes.CONTINUE)).withChild( + labelAst + ) + end astForContinueStatement + + protected def astForThrowStatement(throwStmt: BabelNodeInfo): Ast = + val argumentAst = astForNodeWithFunctionReference(throwStmt.json("argument")) + val throwCallNode = + callNode(throwStmt, throwStmt.code, ".throw", DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(throwCallNode, argAsts) + + private def astsForSwitchCase(switchCase: BabelNodeInfo): List[Ast] = + val labelAst = Ast(createJumpTarget(switchCase)) + val testAsts = safeObj(switchCase.json, "test").map(t => + astForNodeWithFunctionReference(Obj(t)) + ).toList + val consequentAsts = astForNodes(switchCase.json("consequent").arr.toList) + labelAst +: (testAsts ++ consequentAsts) + + protected def astForSwitchStatement(switchStmt: BabelNodeInfo): Ast = + val switchNode = createControlStructureNode(switchStmt, ControlStructureTypes.SWITCH) + + // The semantics of switch statement children is partially defined by there order value. + // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 + // because the semantics of it is already indicated via the condition edge. 
+ val switchExpressionAst = astForNodeWithFunctionReference(switchStmt.json("discriminant")) + setOrderExplicitly(switchExpressionAst, 1) + + val blockNode = createBlockNode(switchStmt).order(2) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val casesAsts = + switchStmt.json("cases").arr.flatMap(c => astsForSwitchCase(createBabelNodeInfo(c))) + setArgumentIndices(casesAsts.toList) + + scope.popScope() + localAstParentStack.pop() + + Ast(switchNode) + .withChild(switchExpressionAst) + .withConditionEdge(switchNode, switchExpressionAst.nodes.head) + .withChild(blockAst(blockNode, casesAsts.toList)) + end astForSwitchStatement + + /** De-sugaring from: + * + * for (var i in/of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; var i; while (!(_result = + * _iterator.next()).done) { i = _result.value; body } } + */ + private def astForInOfStatementWithIdentifier( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = + // TODO: add operator to schema + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) + + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( + forInOfStmt, + 
s"$iteratorName = .iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // loop variable: + val loopVariableName = idNodeInfo.code + + val loopVariableLocalNode = newLocalNode(loopVariableName, Defines.Any).order(0) + val loopVariableNode = identifierNode(forInOfStmt, loopVariableName) + diffGraph.addEdge(localAstParentStack.head, loopVariableLocalNode, EdgeTypes.AST) + scope.addVariableReference(loopVariableName, loopVariableNode) + + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) + + val doneBaseNode = + callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val lhsNode = identifierNode(forInOfStmt, resultName) + + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val thisNextNode = identifierNode(forInOfStmt, iteratorName) + + val rhsArgs = 
List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) + + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) + + val baseNode = identifierNode(forInOfStmt, resultName) + + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val iteratorCall = - // TODO: add operator to schema - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"$loopVariableName = $resultName.value", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), accessAst) + val loopVariableAssignmentAst = + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = 
astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockChildren = + List( + iteratorAssignmentAst, + Ast(resultNode), + Ast(loopVariableNode), + whileLoopAst.withChild(whileLoopBlockAst) + ) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForInOfStatementWithIdentifier + + /** De-sugaring from: + * + * for (expr in/of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; while (!(_result = + * _iterator.next()).done) { expr = _result.value; body } } + */ + private def astForInOfStatementWithExpression( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = + // TODO: add operator to schema + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) + + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( + forInOfStmt, + s"$iteratorName = 
.iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) + + val doneBaseNode = + callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val lhsNode = identifierNode(forInOfStmt, resultName) + + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val thisNextNode = identifierNode(forInOfStmt, iteratorName) + + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) + + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val testNode = + 
createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val whileLoopVariableNode = astForNode(idNodeInfo.json) + + val baseNode = identifierNode(forInOfStmt, resultName) + + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"${idNodeInfo.code} = $resultName.value", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + val loopVariableAssignmentArgs = List(whileLoopVariableNode, accessAst) + val loopVariableAssignmentAst = + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockChildren = + 
List(iteratorAssignmentAst, Ast(resultNode), whileLoopAst.withChild(whileLoopBlockAst)) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForInOfStatementWithExpression + + /** De-sugaring from: + * + * for(var {a, b, c} of obj) { body } + * + * to: + * + * { var _iterator = .iterator(obj); var _result; var a; var b; var c; while + * (!(_result = _iterator.next()).done) { a = _result.value.a; b = _result.value.b; c = + * _result.value.c; body } } + */ + private def astForInOfStatementWithObject( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + // TODO: add operator to schema + val iteratorCall = + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) + + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( + forInOfStmt, + s"$iteratorName = .iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = 
newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // loop variable: + val loopVariableNames = idNodeInfo.json("properties").arr.toList.map(code) + + val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) + val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) + loopVariableLocalNodes.foreach(diffGraph.addEdge( + localAstParentStack.head, + _, + EdgeTypes.AST + )) + loopVariableNames.zip(loopVariableNodes).foreach { + case (loopVariableName, loopVariableNode) => + scope.addVariableReference(loopVariableName, loopVariableNode) + } - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode(forInOfStmt, s"!($resultName = $iteratorName.next()).done", Operators.not, DispatchTypes.STATIC_DISPATCH) - - val doneBaseNode = - callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val lhsNode = identifierNode(forInOfStmt, resultName) - - val rhsNode = callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - - val nextMemberNode = createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val nextReceiverNode = - createFieldAccessCallAst(nextBaseNode, nextMemberNode, 
forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val thisNextNode = identifierNode(forInOfStmt, iteratorName) - - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) - - val doneMemberNode = createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testNode = - createFieldAccessCallAst(doneBaseNode, doneMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) - - val whileLoopAst = Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - - // while loop variable assignment: - val whileLoopVariableNode = astForNode(idNodeInfo.json) - - val baseNode = identifierNode(forInOfStmt, resultName) - - val memberNode = createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val accessAst = createFieldAccessCallAst(baseNode, memberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val loopVariableAssignmentNode = callNode( - forInOfStmt, - s"${idNodeInfo.code} = $resultName.value", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val loopVariableAssignmentArgs = List(whileLoopVariableNode, accessAst) - val loopVariableAssignmentAst = callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - 
- scope.popScope() - localAstParentStack.pop() + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) + + val doneBaseNode = callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - // end surrounding block: - scope.popScope() - localAstParentStack.pop() + val lhsNode = identifierNode(forInOfStmt, resultName) + + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val thisNextNode = identifierNode(forInOfStmt, iteratorName) + + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) + + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val loopVariableAssignmentAsts = loopVariableNames.map { loopVariableName => + val whileLoopVariableNode = 
identifierNode(forInOfStmt, loopVariableName) + val baseNode = identifierNode(forInOfStmt, resultName) + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val variableMemberNode = + createFieldIdentifierNode( + loopVariableName, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val variableAccessAst = + createFieldAccessCallAst( + accessAst, + variableMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"$loopVariableName = $resultName.value.$loopVariableName", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + } - val blockChildren = List(iteratorAssignmentAst, Ast(resultNode), whileLoopAst.withChild(whileLoopBlockAst)) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - } + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockNodeChildren = + List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( + Ast(_) + ) :+ whileLoopAst.withChild( + whileLoopBlockAst + ) + setArgumentIndices(blockNodeChildren) + blockAst(blockNode, blockNodeChildren) + end astForInOfStatementWithObject + + /** 
De-sugaring from: + * + * for(var [a, b, c] of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; var a; var b; var c; while + * (!(_result = _iterator.next()).done) { a = _result.value[0]; b = _result.value[1]; c = + * _result.value[2]; body } } + */ + private def astForInOfStatementWithArray( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) - /** De-sugaring from: - * - * for(var {a, b, c} of obj) { body } - * - * to: - * - * { var _iterator = .iterator(obj); var _result; var a; var b; var c; while (!(_result = - * _iterator.next()).done) { a = _result.value.a; b = _result.value.b; c = _result.value.c; body } } - */ - private def astForInOfStatementWithObject(forInOfStmt: BabelNodeInfo, idNodeInfo: BabelNodeInfo): Ast = { - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( + forInOfStmt, + s"$iteratorName = .iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + 
val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // loop variable: + val loopVariableNames = idNodeInfo.json("elements").arr.toList.map(code) + + val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) + val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) + loopVariableLocalNodes.foreach(diffGraph.addEdge( + localAstParentStack.head, + _, + EdgeTypes.AST + )) + loopVariableNames.zip(loopVariableNodes).foreach { + case (loopVariableName, loopVariableNode) => + scope.addVariableReference(loopVariableName, loopVariableNode) + } - val collection = forInOfStmt.json("right") - val collectionName = code(collection) + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) + + val doneBaseNode = callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - // TODO: add operator to schema - val 
iteratorCall = - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // loop variable: - val loopVariableNames = idNodeInfo.json("properties").arr.toList.map(code) - - val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) - val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) - loopVariableLocalNodes.foreach(diffGraph.addEdge(localAstParentStack.head, _, EdgeTypes.AST)) - loopVariableNames.zip(loopVariableNodes).foreach { case (loopVariableName, loopVariableNode) => - scope.addVariableReference(loopVariableName, loopVariableNode) - } - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode(forInOfStmt, s"!($resultName = $iteratorName.next()).done", Operators.not, DispatchTypes.STATIC_DISPATCH) - - val doneBaseNode = callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val lhsNode = identifierNode(forInOfStmt, resultName) - - val rhsNode = 
callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - - val nextMemberNode = createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val nextReceiverNode = - createFieldAccessCallAst(nextBaseNode, nextMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val thisNextNode = identifierNode(forInOfStmt, iteratorName) - - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) - - val doneMemberNode = createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testNode = - createFieldAccessCallAst(doneBaseNode, doneMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) - - val whileLoopAst = Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - - // while loop variable assignment: - val loopVariableAssignmentAsts = loopVariableNames.map { loopVariableName => - val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) - val baseNode = identifierNode(forInOfStmt, resultName) - val memberNode = createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val accessAst = createFieldAccessCallAst(baseNode, memberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val variableMemberNode = - createFieldIdentifierNode(loopVariableName, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val variableAccessAst = - createFieldAccessCallAst(accessAst, variableMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val loopVariableAssignmentNode = callNode( - forInOfStmt, - s"$loopVariableName = 
$resultName.value.$loopVariableName", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) - callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - } - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - - scope.popScope() - localAstParentStack.pop() - - // end surrounding block: - scope.popScope() - localAstParentStack.pop() - - val blockNodeChildren = - List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map(Ast(_)) :+ whileLoopAst.withChild( - whileLoopBlockAst - ) - setArgumentIndices(blockNodeChildren) - blockAst(blockNode, blockNodeChildren) - } - - /** De-sugaring from: - * - * for(var [a, b, c] of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; var a; var b; var c; while (!(_result = - * _iterator.next()).done) { a = _result.value[0]; b = _result.value[1]; c = _result.value[2]; body } } - */ - private def astForInOfStatementWithArray(forInOfStmt: BabelNodeInfo, idNodeInfo: BabelNodeInfo): Ast = { - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, 
iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - - val iteratorCall = callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // loop variable: - val loopVariableNames = idNodeInfo.json("elements").arr.toList.map(code) - - val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) - val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) - loopVariableLocalNodes.foreach(diffGraph.addEdge(localAstParentStack.head, _, EdgeTypes.AST)) - loopVariableNames.zip(loopVariableNodes).foreach { case (loopVariableName, loopVariableNode) => - scope.addVariableReference(loopVariableName, loopVariableNode) - } - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode(forInOfStmt, s"!($resultName = $iteratorName.next()).done", Operators.not, DispatchTypes.STATIC_DISPATCH) - - val doneBaseNode = callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - 
DispatchTypes.STATIC_DISPATCH - ) - - val lhsNode = identifierNode(forInOfStmt, resultName) - - val rhsNode = callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - - val nextMemberNode = createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val nextReceiverNode = - createFieldAccessCallAst(nextBaseNode, nextMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val thisNextNode = identifierNode(forInOfStmt, iteratorName) - - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) - - val doneMemberNode = createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testNode = - createFieldAccessCallAst(doneBaseNode, doneMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) - - val whileLoopAst = Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - - // while loop variable assignment: - val loopVariableAssignmentAsts = loopVariableNames.zipWithIndex.map { case (loopVariableName, index) => - val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) - val baseNode = identifierNode(forInOfStmt, resultName) - val memberNode = createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val accessAst = createFieldAccessCallAst(baseNode, memberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val variableMemberNode = literalNode(forInOfStmt, index.toString, dynamicTypeOption = Some(Defines.Number)) - val variableAccessAst = - createIndexAccessCallAst(accessAst, Ast(variableMemberNode), 
forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val loopVariableAssignmentNode = callNode( - forInOfStmt, - s"$loopVariableName = $resultName.value[$index]", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) - callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - } - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - - scope.popScope() - localAstParentStack.pop() - - // end surrounding block: - scope.popScope() - localAstParentStack.pop() - - val blockNodeChildren = - List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map(Ast(_)) :+ whileLoopAst.withChild( - whileLoopBlockAst - ) - setArgumentIndices(blockNodeChildren) - blockAst(blockNode, blockNodeChildren) - } - - private def extractLoopVariableNodeInfo(nodeInfo: BabelNodeInfo): Option[BabelNodeInfo] = - nodeInfo.node match { - case AssignmentPattern => - Option(createBabelNodeInfo(nodeInfo.json("left"))) - case VariableDeclaration => - val varDeclNodeInfo = createBabelNodeInfo(nodeInfo.json("declarations").arr.head) - if (varDeclNodeInfo.node == VariableDeclarator) Option(createBabelNodeInfo(varDeclNodeInfo.json("id"))) - else None - case _ => None - } - - protected def astForInOfStatement(forInOfStmt: BabelNodeInfo): Ast = { - val loopVariableNodeInfo = createBabelNodeInfo(forInOfStmt.json("left")) - // check iteration loop variable type: - loopVariableNodeInfo.node match { - case VariableDeclaration | AssignmentPattern => - val idNodeInfo = extractLoopVariableNodeInfo(loopVariableNodeInfo) - 
idNodeInfo.map(_.node) match { - case Some(ObjectPattern) => astForInOfStatementWithObject(forInOfStmt, idNodeInfo.get) - case Some(ArrayPattern) => astForInOfStatementWithArray(forInOfStmt, idNodeInfo.get) - case Some(Identifier) => astForInOfStatementWithIdentifier(forInOfStmt, idNodeInfo.get) - case _ => notHandledYet(forInOfStmt) - } - case ObjectPattern => astForInOfStatementWithObject(forInOfStmt, loopVariableNodeInfo) - case ArrayPattern => astForInOfStatementWithArray(forInOfStmt, loopVariableNodeInfo) - case Identifier => astForInOfStatementWithIdentifier(forInOfStmt, loopVariableNodeInfo) - case _: Expression => astForInOfStatementWithExpression(forInOfStmt, loopVariableNodeInfo) - case _ => notHandledYet(forInOfStmt) - } - } - -} + val lhsNode = identifierNode(forInOfStmt, resultName) + + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val thisNextNode = identifierNode(forInOfStmt, iteratorName) + + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) + + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + 
Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val loopVariableAssignmentAsts = + loopVariableNames.zipWithIndex.map { case (loopVariableName, index) => + val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) + val baseNode = identifierNode(forInOfStmt, resultName) + val memberNode = createFieldIdentifierNode( + "value", + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val variableMemberNode = literalNode( + forInOfStmt, + index.toString, + dynamicTypeOption = Some(Defines.Number) + ) + val variableAccessAst = + createIndexAccessCallAst( + accessAst, + Ast(variableMemberNode), + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"$loopVariableName = $resultName.value[$index]", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + } + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockNodeChildren = + List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( + Ast(_) + ) :+ whileLoopAst.withChild( + whileLoopBlockAst + ) + setArgumentIndices(blockNodeChildren) + 
blockAst(blockNode, blockNodeChildren) + end astForInOfStatementWithArray + + private def extractLoopVariableNodeInfo(nodeInfo: BabelNodeInfo): Option[BabelNodeInfo] = + nodeInfo.node match + case AssignmentPattern => + Option(createBabelNodeInfo(nodeInfo.json("left"))) + case VariableDeclaration => + val varDeclNodeInfo = createBabelNodeInfo(nodeInfo.json("declarations").arr.head) + if varDeclNodeInfo.node == VariableDeclarator then + Option(createBabelNodeInfo(varDeclNodeInfo.json("id"))) + else None + case _ => None + + protected def astForInOfStatement(forInOfStmt: BabelNodeInfo): Ast = + val loopVariableNodeInfo = createBabelNodeInfo(forInOfStmt.json("left")) + // check iteration loop variable type: + loopVariableNodeInfo.node match + case VariableDeclaration | AssignmentPattern => + val idNodeInfo = extractLoopVariableNodeInfo(loopVariableNodeInfo) + idNodeInfo.map(_.node) match + case Some(ObjectPattern) => + astForInOfStatementWithObject(forInOfStmt, idNodeInfo.get) + case Some(ArrayPattern) => + astForInOfStatementWithArray(forInOfStmt, idNodeInfo.get) + case Some(Identifier) => + astForInOfStatementWithIdentifier(forInOfStmt, idNodeInfo.get) + case _ => notHandledYet(forInOfStmt) + case ObjectPattern => astForInOfStatementWithObject(forInOfStmt, loopVariableNodeInfo) + case ArrayPattern => astForInOfStatementWithArray(forInOfStmt, loopVariableNodeInfo) + case Identifier => astForInOfStatementWithIdentifier(forInOfStmt, loopVariableNodeInfo) + case _: Expression => + astForInOfStatementWithExpression(forInOfStmt, loopVariableNodeInfo) + case _ => notHandledYet(forInOfStmt) + end astForInOfStatement +end AstForStatementsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala index c448186a..320a811f 100644 --- 
a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala @@ -5,108 +5,115 @@ import io.appthreat.jssrc2cpg.parser.BabelAst.* import io.appthreat.x2cpg.{Ast, ValidationMode} import ujson.Obj -trait AstForTemplateDomCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => +trait AstForTemplateDomCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => - protected def astForJsxElement(jsxElem: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode(jsxElem.node.toString, jsxElem.code, jsxElem.lineNumber, jsxElem.columnNumber) - val openingAst = astForNodeWithFunctionReference(jsxElem.json("openingElement")) - val childrenAsts = astForNodes(jsxElem.json("children").arr.toList) - val closingAst = - safeObj(jsxElem.json, "closingElement") - .map(e => astForNodeWithFunctionReference(Obj(e))) - .getOrElse(Ast()) - val allChildrenAsts = openingAst +: childrenAsts :+ closingAst - setArgumentIndices(allChildrenAsts) - Ast(domNode).withChildren(allChildrenAsts) - } + protected def astForJsxElement(jsxElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxElem.node.toString, + jsxElem.code, + jsxElem.lineNumber, + jsxElem.columnNumber + ) + val openingAst = astForNodeWithFunctionReference(jsxElem.json("openingElement")) + val childrenAsts = astForNodes(jsxElem.json("children").arr.toList) + val closingAst = + safeObj(jsxElem.json, "closingElement") + .map(e => astForNodeWithFunctionReference(Obj(e))) + .getOrElse(Ast()) + val allChildrenAsts = openingAst +: childrenAsts :+ closingAst + setArgumentIndices(allChildrenAsts) + Ast(domNode).withChildren(allChildrenAsts) - protected def astForJsxFragment(jsxFragment: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode( - jsxFragment.node.toString, - jsxFragment.code, - jsxFragment.lineNumber, 
- jsxFragment.columnNumber - ) - val childrenAsts = astForNodes(jsxFragment.json("children").arr.toList) - setArgumentIndices(childrenAsts) - Ast(domNode).withChildren(childrenAsts) - } + protected def astForJsxFragment(jsxFragment: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxFragment.node.toString, + jsxFragment.code, + jsxFragment.lineNumber, + jsxFragment.columnNumber + ) + val childrenAsts = astForNodes(jsxFragment.json("children").arr.toList) + setArgumentIndices(childrenAsts) + Ast(domNode).withChildren(childrenAsts) - protected def astForJsxAttribute(jsxAttr: BabelNodeInfo): Ast = { - // A colon in front of a JSXAttribute cant be parsed by Babel. - // Hence, we strip it away with astgen and restore it here. - // parserResult.fileContent contains the unmodified Vue.js source code for the current file. - // We look at the previous character there and re-add the colon if needed. - val colon = pos(jsxAttr.json) - .collect { - case position if position > 0 && parserResult.fileContent.substring(position - 1, position) == ":" => ":" - } - .getOrElse("") - val domNode = - createTemplateDomNode( - jsxAttr.node.toString, - s"$colon${jsxAttr.code}", - jsxAttr.lineNumber, - jsxAttr.columnNumber.map(_ - colon.length) - ) - val valueAst = safeObj(jsxAttr.json, "value") - .map(e => astForNodeWithFunctionReference(Obj(e))) - .getOrElse(Ast()) - setArgumentIndices(List(valueAst)) - Ast(domNode).withChild(valueAst) - } + protected def astForJsxAttribute(jsxAttr: BabelNodeInfo): Ast = + // A colon in front of a JSXAttribute cant be parsed by Babel. + // Hence, we strip it away with astgen and restore it here. + // parserResult.fileContent contains the unmodified Vue.js source code for the current file. + // We look at the previous character there and re-add the colon if needed. 
+ val colon = pos(jsxAttr.json) + .collect { + case position + if position > 0 && parserResult.fileContent.substring( + position - 1, + position + ) == ":" => ":" + } + .getOrElse("") + val domNode = + createTemplateDomNode( + jsxAttr.node.toString, + s"$colon${jsxAttr.code}", + jsxAttr.lineNumber, + jsxAttr.columnNumber.map(_ - colon.length) + ) + val valueAst = safeObj(jsxAttr.json, "value") + .map(e => astForNodeWithFunctionReference(Obj(e))) + .getOrElse(Ast()) + setArgumentIndices(List(valueAst)) + Ast(domNode).withChild(valueAst) + end astForJsxAttribute - protected def astForJsxOpeningElement(jsxOpeningElem: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode( - jsxOpeningElem.node.toString, - jsxOpeningElem.code, - jsxOpeningElem.lineNumber, - jsxOpeningElem.columnNumber - ) - val childrenAsts = astForNodes(jsxOpeningElem.json("attributes").arr.toList) - setArgumentIndices(childrenAsts) - Ast(domNode).withChildren(childrenAsts) - } + protected def astForJsxOpeningElement(jsxOpeningElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxOpeningElem.node.toString, + jsxOpeningElem.code, + jsxOpeningElem.lineNumber, + jsxOpeningElem.columnNumber + ) + val childrenAsts = astForNodes(jsxOpeningElem.json("attributes").arr.toList) + setArgumentIndices(childrenAsts) + Ast(domNode).withChildren(childrenAsts) - protected def astForJsxClosingElement(jsxClosingElem: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode( - jsxClosingElem.node.toString, - jsxClosingElem.code, - jsxClosingElem.lineNumber, - jsxClosingElem.columnNumber - ) - Ast(domNode) - } + protected def astForJsxClosingElement(jsxClosingElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxClosingElem.node.toString, + jsxClosingElem.code, + jsxClosingElem.lineNumber, + jsxClosingElem.columnNumber + ) + Ast(domNode) - protected def astForJsxText(jsxText: BabelNodeInfo): Ast = - Ast(createTemplateDomNode(jsxText.node.toString, jsxText.code, 
jsxText.lineNumber, jsxText.columnNumber)) + protected def astForJsxText(jsxText: BabelNodeInfo): Ast = + Ast(createTemplateDomNode( + jsxText.node.toString, + jsxText.code, + jsxText.lineNumber, + jsxText.columnNumber + )) - protected def astForJsxExprContainer(jsxExprContainer: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode( - jsxExprContainer.node.toString, - jsxExprContainer.code, - jsxExprContainer.lineNumber, - jsxExprContainer.columnNumber - ) - val nodeInfo = createBabelNodeInfo(jsxExprContainer.json("expression")) - val exprAst = nodeInfo.node match { - case JSXEmptyExpression => Ast() - case _ => astForNodeWithFunctionReference(nodeInfo.json) - } - setArgumentIndices(List(exprAst)) - Ast(domNode).withChild(exprAst) - } + protected def astForJsxExprContainer(jsxExprContainer: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxExprContainer.node.toString, + jsxExprContainer.code, + jsxExprContainer.lineNumber, + jsxExprContainer.columnNumber + ) + val nodeInfo = createBabelNodeInfo(jsxExprContainer.json("expression")) + val exprAst = nodeInfo.node match + case JSXEmptyExpression => Ast() + case _ => astForNodeWithFunctionReference(nodeInfo.json) + setArgumentIndices(List(exprAst)) + Ast(domNode).withChild(exprAst) - protected def astForJsxSpreadAttribute(jsxSpreadAttr: BabelNodeInfo): Ast = { - val domNode = createTemplateDomNode( - jsxSpreadAttr.node.toString, - jsxSpreadAttr.code, - jsxSpreadAttr.lineNumber, - jsxSpreadAttr.columnNumber - ) - val argAst = astForNodeWithFunctionReference(jsxSpreadAttr.json("argument")) - setArgumentIndices(List(argAst)) - Ast(domNode).withChild(argAst) - } - -} + protected def astForJsxSpreadAttribute(jsxSpreadAttr: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxSpreadAttr.node.toString, + jsxSpreadAttr.code, + jsxSpreadAttr.lineNumber, + jsxSpreadAttr.columnNumber + ) + val argAst = astForNodeWithFunctionReference(jsxSpreadAttr.json("argument")) + 
setArgumentIndices(List(argAst)) + Ast(domNode).withChild(argAst) +end AstForTemplateDomCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala index 0b9eb126..8847705e 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala @@ -13,86 +13,95 @@ import ujson.Value import scala.util.Try -trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - - protected def astForTypeAlias(alias: BabelNodeInfo): Ast = { - val (aliasName, aliasFullName) = calcTypeNameAndFullName(alias) - val name = if (hasKey(alias.json, "right")) { - typeFor(createBabelNodeInfo(alias.json("right"))) - } else { - typeFor(createBabelNodeInfo(alias.json)) - } - registerType(aliasName, aliasFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val aliasTypeDeclNode = - typeDeclNode(alias, aliasName, aliasFullName, parserResult.filename, alias.code, astParentType, astParentFullName) - seenAliasTypes.add(aliasTypeDeclNode) - - val typeDeclNodeAst = - if (!Defines.JsTypes.contains(name) && !seenAliasTypes.exists(_.name == name)) { - val (typeName, typeFullName) = calcTypeNameAndFullName(alias, Option(name)) - val typeDeclNode_ = typeDeclNode( - alias, - typeName, - typeFullName, - parserResult.filename, - alias.code, - astParentType, - astParentFullName, - alias = Option(aliasFullName) +trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + + protected def astForTypeAlias(alias: BabelNodeInfo): Ast = + val (aliasName, aliasFullName) = calcTypeNameAndFullName(alias) + val name = if 
hasKey(alias.json, "right") then + typeFor(createBabelNodeInfo(alias.json("right"))) + else + typeFor(createBabelNodeInfo(alias.json)) + registerType(aliasName, aliasFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val aliasTypeDeclNode = + typeDeclNode( + alias, + aliasName, + aliasFullName, + parserResult.filename, + alias.code, + astParentType, + astParentFullName + ) + seenAliasTypes.add(aliasTypeDeclNode) + + val typeDeclNodeAst = + if !Defines.JsTypes.contains(name) && !seenAliasTypes.exists(_.name == name) then + val (typeName, typeFullName) = calcTypeNameAndFullName(alias, Option(name)) + val typeDeclNode_ = typeDeclNode( + alias, + typeName, + typeFullName, + parserResult.filename, + alias.code, + astParentType, + astParentFullName, + alias = Option(aliasFullName) + ) + registerType(typeName, typeFullName) + Ast(typeDeclNode_) + else + seenAliasTypes + .collectFirst { + case typeDecl if typeDecl.name == name => + Ast(typeDecl.aliasTypeFullName(aliasFullName)) + } + .getOrElse(Ast()) + + // adding all class methods / functions and uninitialized, non-static members + val membersAndInitializers = (alias.node match + case TSTypeLiteral => classMembersForTypeAlias(alias) + case ObjectPattern => Try(alias.json("properties").arr).toOption.toSeq.flatten + case _ => classMembersForTypeAlias(createBabelNodeInfo(alias.json("typeAnnotation"))) + ).filter(member => + isClassMethodOrUninitializedMemberOrObjectProperty(member) && !isStaticMember(member) ) - registerType(typeName, typeFullName) - Ast(typeDeclNode_) - } else { - seenAliasTypes - .collectFirst { case typeDecl if typeDecl.name == name => Ast(typeDecl.aliasTypeFullName(aliasFullName)) } - .getOrElse(Ast()) - } - - // adding all class methods / functions and uninitialized, non-static members - val membersAndInitializers = (alias.node match { - case TSTypeLiteral => classMembersForTypeAlias(alias) - case 
ObjectPattern => Try(alias.json("properties").arr).toOption.toSeq.flatten - case _ => classMembersForTypeAlias(createBabelNodeInfo(alias.json("typeAnnotation"))) - }).filter(member => isClassMethodOrUninitializedMemberOrObjectProperty(member) && !isStaticMember(member)) - .map(m => astForClassMember(m, aliasTypeDeclNode)) - typeDeclNodeAst.root.foreach(diffGraph.addEdge(methodAstParentStack.head, _, EdgeTypes.AST)) - Ast(aliasTypeDeclNode).withChildren(membersAndInitializers) - } - - private def isConstructor(json: Value): Boolean = createBabelNodeInfo(json).node match { - case TSConstructSignatureDeclaration => true - case _ => safeStr(json, "kind").contains("constructor") - } - - private def classMembers(clazz: BabelNodeInfo, withConstructor: Boolean = true): Seq[Value] = { - val allMembers = Try(clazz.json("body")("body").arr).toOption.toSeq.flatten - val dynamicallyDeclaredMembers = - allMembers - .find(isConstructor) - .flatMap(c => Try(c("body")("body").arr).toOption) - .toSeq - .flatten - .filter(isInitializedMember) - if (withConstructor) { - allMembers ++ dynamicallyDeclaredMembers - } else { - allMembers.filterNot(isConstructor) ++ dynamicallyDeclaredMembers - } - } - - private def classMembersForTypeAlias(alias: BabelNodeInfo): Seq[Value] = - Try(alias.json("members").arr).toOption.toSeq.flatten - - private def createFakeConstructor( - code: String, - forElem: BabelNodeInfo, - methodBlockContent: List[Ast] = List.empty - ): MethodAst = { - val fakeConstructorCode = s"""{ + .map(m => astForClassMember(m, aliasTypeDeclNode)) + typeDeclNodeAst.root.foreach(diffGraph.addEdge(methodAstParentStack.head, _, EdgeTypes.AST)) + Ast(aliasTypeDeclNode).withChildren(membersAndInitializers) + end astForTypeAlias + + private def isConstructor(json: Value): Boolean = createBabelNodeInfo(json).node match + case TSConstructSignatureDeclaration => true + case _ => safeStr(json, "kind").contains("constructor") + + private def classMembers(clazz: BabelNodeInfo, 
withConstructor: Boolean = true): Seq[Value] = + val allMembers = Try(clazz.json("body")("body").arr).toOption.toSeq.flatten + val dynamicallyDeclaredMembers = + allMembers + .find(isConstructor) + .flatMap(c => Try(c("body")("body").arr).toOption) + .toSeq + .flatten + .filter(isInitializedMember) + if withConstructor then + allMembers ++ dynamicallyDeclaredMembers + else + allMembers.filterNot(isConstructor) ++ dynamicallyDeclaredMembers + + private def classMembersForTypeAlias(alias: BabelNodeInfo): Seq[Value] = + Try(alias.json("members").arr).toOption.toSeq.flatten + + private def createFakeConstructor( + code: String, + forElem: BabelNodeInfo, + methodBlockContent: List[Ast] = List.empty + ): MethodAst = + val fakeConstructorCode = s"""{ | "type": "ClassMethod", | "key": { | "type": "Identifier", @@ -112,426 +121,512 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: | "body": [] | } |}""".stripMargin - val result = createMethodAstAndNode( - createBabelNodeInfo(ujson.read(fakeConstructorCode)), - methodBlockContent = methodBlockContent - ) - result.methodNode.code(code) - result - } - - private def findClassConstructor(clazz: BabelNodeInfo): Option[Value] = - classMembers(clazz).find(isConstructor) - - private def createClassConstructor(classExpr: BabelNodeInfo, constructorContent: List[Ast]): MethodAst = - findClassConstructor(classExpr) match { - case Some(classConstructor) if hasKey(classConstructor, "body") => - val result = - createMethodAstAndNode(createBabelNodeInfo(classConstructor), methodBlockContent = constructorContent) - diffGraph.addEdge(result.methodNode, NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), EdgeTypes.AST) - result - case Some(classConstructor) => - val methodNode = - createMethodDefinitionNode(createBabelNodeInfo(classConstructor), methodBlockContent = constructorContent) - diffGraph.addEdge(methodNode, NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), EdgeTypes.AST) - 
MethodAst(Ast(methodNode), methodNode, Ast(methodNode)) - case _ => - val result = createFakeConstructor("constructor() {}", classExpr, methodBlockContent = constructorContent) - diffGraph.addEdge(result.methodNode, NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), EdgeTypes.AST) + val result = createMethodAstAndNode( + createBabelNodeInfo(ujson.read(fakeConstructorCode)), + methodBlockContent = methodBlockContent + ) + result.methodNode.code(code) result - } - - private def interfaceConstructor(typeName: String, tsInterface: BabelNodeInfo): NewMethod = - findClassConstructor(tsInterface) match { - case Some(interfaceConstructor) => createMethodDefinitionNode(createBabelNodeInfo(interfaceConstructor)) - case _ => createFakeConstructor(s"new: $typeName", tsInterface).methodNode - } - - private def astsForEnumMember(tsEnumMember: BabelNodeInfo): Seq[Ast] = { - val name = code(tsEnumMember.json("id")) - val memberNode_ = memberNode(tsEnumMember, name, tsEnumMember.code, typeFor(tsEnumMember)) - addModifier(memberNode_, tsEnumMember.json) - - if (hasKey(tsEnumMember.json, "initializer")) { - val lhsAst = astForNode(tsEnumMember.json("id")) - val rhsAst = astForNodeWithFunctionReference(tsEnumMember.json("initializer")) - val callNode_ = - callNode(tsEnumMember, tsEnumMember.code, Operators.assignment, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) - Seq(callAst(callNode_, argAsts), Ast(memberNode_)) - } else { - Seq(Ast(memberNode_)) - } - } - - private def astForClassMember(classElement: Value, typeDeclNode: NewTypeDecl): Ast = { - val nodeInfo = createBabelNodeInfo(classElement) - val typeFullName = typeFor(nodeInfo) - val memberNode_ = nodeInfo.node match { - case TSDeclareMethod | TSDeclareFunction => - val function = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) - 
addModifier(function, nodeInfo.json) - memberNode(nodeInfo, function.name, nodeInfo.code, typeFullName, Seq(function.fullName)) - case ClassMethod | ClassPrivateMethod => - val function = createMethodAstAndNode(nodeInfo).methodNode - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) - addModifier(function, nodeInfo.json) - memberNode(nodeInfo, function.name, nodeInfo.code, typeFullName, Seq(function.fullName)) - case ExpressionStatement if isInitializedMember(classElement) => - val memberNodeInfo = createBabelNodeInfo(nodeInfo.json("expression")("left")("property")) - val name = memberNodeInfo.code - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - case TSPropertySignature | ObjectProperty if hasKey(nodeInfo.json("key"), "name") => - val memberNodeInfo = createBabelNodeInfo(nodeInfo.json("key")) - val name = memberNodeInfo.json("name").str - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - case _ => - val name = nodeInfo.node match { - case ClassProperty => code(nodeInfo.json("key")) - case ClassPrivateProperty => code(nodeInfo.json("key")("id")) - // TODO: name field most likely needs adjustment for other Babel AST types - case _ => nodeInfo.code + end createFakeConstructor + + private def findClassConstructor(clazz: BabelNodeInfo): Option[Value] = + classMembers(clazz).find(isConstructor) + + private def createClassConstructor( + classExpr: BabelNodeInfo, + constructorContent: List[Ast] + ): MethodAst = + findClassConstructor(classExpr) match + case Some(classConstructor) if hasKey(classConstructor, "body") => + val result = + createMethodAstAndNode( + createBabelNodeInfo(classConstructor), + methodBlockContent = constructorContent + ) + diffGraph.addEdge( + result.methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + result + case Some(classConstructor) => + val methodNode = + 
createMethodDefinitionNode( + createBabelNodeInfo(classConstructor), + methodBlockContent = constructorContent + ) + diffGraph.addEdge( + methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + MethodAst(Ast(methodNode), methodNode, Ast(methodNode)) + case _ => + val result = createFakeConstructor( + "constructor() {}", + classExpr, + methodBlockContent = constructorContent + ) + diffGraph.addEdge( + result.methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + result + + private def interfaceConstructor(typeName: String, tsInterface: BabelNodeInfo): NewMethod = + findClassConstructor(tsInterface) match + case Some(interfaceConstructor) => + createMethodDefinitionNode(createBabelNodeInfo(interfaceConstructor)) + case _ => createFakeConstructor(s"new: $typeName", tsInterface).methodNode + + private def astsForEnumMember(tsEnumMember: BabelNodeInfo): Seq[Ast] = + val name = code(tsEnumMember.json("id")) + val memberNode_ = memberNode(tsEnumMember, name, tsEnumMember.code, typeFor(tsEnumMember)) + addModifier(memberNode_, tsEnumMember.json) + + if hasKey(tsEnumMember.json, "initializer") then + val lhsAst = astForNode(tsEnumMember.json("id")) + val rhsAst = astForNodeWithFunctionReference(tsEnumMember.json("initializer")) + val callNode_ = + callNode( + tsEnumMember, + tsEnumMember.code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val argAsts = List(lhsAst, rhsAst) + Seq(callAst(callNode_, argAsts), Ast(memberNode_)) + else + Seq(Ast(memberNode_)) + end astsForEnumMember + + private def astForClassMember(classElement: Value, typeDeclNode: NewTypeDecl): Ast = + val nodeInfo = createBabelNodeInfo(classElement) + val typeFullName = typeFor(nodeInfo) + val memberNode_ = nodeInfo.node match + case TSDeclareMethod | TSDeclareFunction => + val function = createMethodDefinitionNode(nodeInfo) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode, bindingNode, 
EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + addModifier(function, nodeInfo.json) + memberNode( + nodeInfo, + function.name, + nodeInfo.code, + typeFullName, + Seq(function.fullName) + ) + case ClassMethod | ClassPrivateMethod => + val function = createMethodAstAndNode(nodeInfo).methodNode + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + addModifier(function, nodeInfo.json) + memberNode( + nodeInfo, + function.name, + nodeInfo.code, + typeFullName, + Seq(function.fullName) + ) + case ExpressionStatement if isInitializedMember(classElement) => + val memberNodeInfo = + createBabelNodeInfo(nodeInfo.json("expression")("left")("property")) + val name = memberNodeInfo.code + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + case TSPropertySignature | ObjectProperty if hasKey(nodeInfo.json("key"), "name") => + val memberNodeInfo = createBabelNodeInfo(nodeInfo.json("key")) + val name = memberNodeInfo.json("name").str + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + case _ => + val name = nodeInfo.node match + case ClassProperty => code(nodeInfo.json("key")) + case ClassPrivateProperty => code(nodeInfo.json("key")("id")) + // TODO: name field most likely needs adjustment for other Babel AST types + case _ => nodeInfo.code + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + + addModifier(memberNode_, classElement) + diffGraph.addEdge(typeDeclNode, memberNode_, EdgeTypes.AST) + astsForDecorators(nodeInfo).foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(memberNode_, _, EdgeTypes.AST)) + } + + if hasKey(nodeInfo.json, "value") && !nodeInfo.json("value").isNull then + val lhsAst = astForNode(nodeInfo.json("key")) + val rhsAst = astForNodeWithFunctionReference(nodeInfo.json("value")) + val callNode_ = + callNode( + nodeInfo, + 
nodeInfo.code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val argAsts = List(lhsAst, rhsAst) + callAst(callNode_, argAsts) + else + Ast() + end astForClassMember + + protected def astForEnum(tsEnum: BabelNodeInfo): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(tsEnum) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val typeDeclNode_ = typeDeclNode( + tsEnum, + typeName, + typeFullName, + parserResult.filename, + s"enum $typeName", + astParentType, + astParentFullName + ) + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, tsEnum.json) + + val typeRefNode_ = typeRefNode(tsEnum, s"enum $typeName", typeFullName) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + typeRefIdStack.push(typeRefNode_) + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val memberAsts = tsEnum.json("members").arr.toList.flatMap(m => + astsForEnumMember(createBabelNodeInfo(m)) + ) + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + typeRefIdStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member), diffGraph) + else + val init = + staticInitMethodAst( + calls, + s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.Any + ) + Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member).withChild(init), diffGraph) + + diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) + Ast(typeRefNode_) + end astForEnum + + private def isStaticMember(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + val isStatic = safeBool(json, "static").contains(true) + nodeInfo != ClassMethod && nodeInfo != 
ClassPrivateMethod && isStatic + + private def isInitializedMember(json: Value): Boolean = + val hasInitializedValue = hasKey(json, "value") && !json("value").isNull + val isAssignment = createBabelNodeInfo(json) match + case node if node.node == ExpressionStatement => + val exprNode = createBabelNodeInfo(node.json("expression")) + exprNode.node == AssignmentExpression && + createBabelNodeInfo(exprNode.json("left")).node == MemberExpression && + code(exprNode.json("left")("object")) == "this" + case _ => false + hasInitializedValue || isAssignment + + private def isStaticInitBlock(json: Value): Boolean = + createBabelNodeInfo(json).node == StaticBlock + + private def isClassMethodOrUninitializedMember(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + !isStaticInitBlock(json) && + (nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember(json)) + + private def isClassMethodOrUninitializedMemberOrObjectProperty(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + !isStaticInitBlock(json) && + (nodeInfo == ObjectProperty || nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember( + json + )) + + protected def astForClass( + clazz: BabelNodeInfo, + shouldCreateAssignmentCall: Boolean = false + ): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(clazz) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val superClass = Try(createBabelNodeInfo(clazz.json("superClass")).code).toOption.toSeq + val implements = Try( + clazz.json("implements").arr.map(createBabelNodeInfo(_).code) + ).toOption.toSeq.flatten + val mixins = + Try(clazz.json("mixins").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten + + val typeDeclNode_ = typeDeclNode( + clazz, + typeName, + typeFullName, + parserResult.filename, + s"class $typeName", + 
astParentType, + astParentFullName, + inherits = superClass ++ implements ++ mixins + ) + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, clazz.json) + astsForDecorators(clazz).foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) } - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - } - - addModifier(memberNode_, classElement) - diffGraph.addEdge(typeDeclNode, memberNode_, EdgeTypes.AST) - astsForDecorators(nodeInfo).foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(memberNode_, _, EdgeTypes.AST)) - } - - if (hasKey(nodeInfo.json, "value") && !nodeInfo.json("value").isNull) { - val lhsAst = astForNode(nodeInfo.json("key")) - val rhsAst = astForNodeWithFunctionReference(nodeInfo.json("value")) - val callNode_ = - callNode(nodeInfo, nodeInfo.code, Operators.assignment, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) - callAst(callNode_, argAsts) - } else { - Ast() - } - } - - protected def astForEnum(tsEnum: BabelNodeInfo): Ast = { - val (typeName, typeFullName) = calcTypeNameAndFullName(tsEnum) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val typeDeclNode_ = typeDeclNode( - tsEnum, - typeName, - typeFullName, - parserResult.filename, - s"enum $typeName", - astParentType, - astParentFullName - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, tsEnum.json) - - val typeRefNode_ = typeRefNode(tsEnum, s"enum $typeName", typeFullName) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - typeRefIdStack.push(typeRefNode_) - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val memberAsts = tsEnum.json("members").arr.toList.flatMap(m => 
astsForEnumMember(createBabelNodeInfo(m))) - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - typeRefIdStack.pop() - scope.popScope() - - val (calls, member) = memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if (calls.isEmpty) { - Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member), diffGraph) - } else { - val init = - staticInitMethodAst( - calls, - s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.Any + + diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) + + val typeRefNode_ = typeRefNode(clazz, s"class $typeName", typeFullName) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + typeRefIdStack.push(typeRefNode_) + + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val allClassMembers = classMembers(clazz, withConstructor = false).toList + + // adding all other members and retrieving their initialization calls + val memberInitCalls = allClassMembers + .filter(m => !isStaticMember(m) && isInitializedMember(m)) + .map(m => astForClassMember(m, typeDeclNode_)) + + val constructor = createClassConstructor(clazz, memberInitCalls) + val constructorNode = constructor.methodNode + + // adding all class methods / functions and uninitialized, non-static members + allClassMembers + .filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member)) + .map(m => astForClassMember(m, typeDeclNode_)) + + // adding all static members and retrieving their initialization calls + val staticMemberInitCalls = + allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_)) + + // retrieving initialization calls from the static initialization block if any + val staticInitBlock = allClassMembers.find(isStaticInitBlock) + val staticInitBlockAsts = + staticInitBlock.map(block => + block("body").arr.toList.map(astForNodeWithFunctionReference) + ).getOrElse(List.empty) + 
+ methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + typeRefIdStack.pop() + scope.popScope() + + if staticMemberInitCalls.nonEmpty || staticInitBlockAsts.nonEmpty then + val init = staticInitMethodAst( + staticMemberInitCalls ++ staticInitBlockAsts, + s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.Any + ) + Ast.storeInDiffGraph(init, diffGraph) + diffGraph.addEdge(typeDeclNode_, init.nodes.head, EdgeTypes.AST) + + if shouldCreateAssignmentCall then + diffGraph.addEdge(localAstParentStack.head, typeRefNode_, EdgeTypes.AST) + + // return a synthetic assignment to enable tracing of the implicitly created identifier for + // the class definition assigned to its constructor + val classIdNode = identifierNode(clazz, typeName, Seq(constructorNode.fullName)) + val constructorRefNode = + methodRefNode( + clazz, + constructorNode.code, + constructorNode.fullName, + constructorNode.fullName + ) + + val idLocal = newLocalNode(typeName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) + scope.addVariable(typeName, idLocal, BlockScope) + scope.addVariableReference(typeName, classIdNode) + + createAssignmentCallAst( + classIdNode, + constructorRefNode, + s"$typeName = ${constructorNode.fullName}", + clazz.lineNumber, + clazz.columnNumber + ) + else + Ast(typeRefNode_) + end if + end astForClass + + protected def addModifier(node: NewNode, json: Value): Unit = + createBabelNodeInfo(json).node match + case ClassPrivateProperty => + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PRIVATE), + EdgeTypes.AST + ) + case _ => + if safeBool(json, "abstract").contains(true) then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.ABSTRACT), + EdgeTypes.AST + ) + if safeBool(json, "static").contains(true) then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.STATIC), + EdgeTypes.AST + ) + if safeStr(json, 
"accessibility").contains("public") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PUBLIC), + EdgeTypes.AST + ) + if safeStr(json, "accessibility").contains("private") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PRIVATE), + EdgeTypes.AST + ) + if safeStr(json, "accessibility").contains("protected") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PROTECTED), + EdgeTypes.AST + ) + + protected def astForModule(tsModuleDecl: BabelNodeInfo): Ast = + val (name, fullName) = calcTypeNameAndFullName(tsModuleDecl) + val namespaceNode = NewNamespaceBlock() + .code(tsModuleDecl.code) + .lineNumber(tsModuleDecl.lineNumber) + .columnNumber(tsModuleDecl.columnNumber) + .filename(parserResult.filename) + .name(name) + .fullName(fullName) + + methodAstParentStack.push(namespaceNode) + dynamicInstanceTypeStack.push(fullName) + + scope.pushNewMethodScope(fullName, name, namespaceNode, None) + + val blockAst = if hasKey(tsModuleDecl.json, "body") then + val nodeInfo = createBabelNodeInfo(tsModuleDecl.json("body")) + nodeInfo.node match + case TSModuleDeclaration => astForModule(nodeInfo) + case _ => astForBlockStatement(nodeInfo) + else + Ast() + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + scope.popScope() + + Ast(namespaceNode).withChild(blockAst) + end astForModule + + protected def astForInterface(tsInterface: BabelNodeInfo): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(tsInterface) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val extendz = Try( + tsInterface.json("extends").arr.map(createBabelNodeInfo(_).code) + ).toOption.toSeq.flatten + + val typeDeclNode_ = typeDeclNode( + tsInterface, + typeName, + typeFullName, + parserResult.filename, + s"interface $typeName", + astParentType, + astParentFullName, + inherits = 
extendz ) - Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member).withChild(init), diffGraph) - } - - diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) - Ast(typeRefNode_) - } - - private def isStaticMember(json: Value): Boolean = { - val nodeInfo = createBabelNodeInfo(json).node - val isStatic = safeBool(json, "static").contains(true) - nodeInfo != ClassMethod && nodeInfo != ClassPrivateMethod && isStatic - } - - private def isInitializedMember(json: Value): Boolean = { - val hasInitializedValue = hasKey(json, "value") && !json("value").isNull - val isAssignment = createBabelNodeInfo(json) match { - case node if node.node == ExpressionStatement => - val exprNode = createBabelNodeInfo(node.json("expression")) - exprNode.node == AssignmentExpression && - createBabelNodeInfo(exprNode.json("left")).node == MemberExpression && - code(exprNode.json("left")("object")) == "this" - case _ => false - } - hasInitializedValue || isAssignment - } - - private def isStaticInitBlock(json: Value): Boolean = createBabelNodeInfo(json).node == StaticBlock - - private def isClassMethodOrUninitializedMember(json: Value): Boolean = { - val nodeInfo = createBabelNodeInfo(json).node - !isStaticInitBlock(json) && - (nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember(json)) - } - - private def isClassMethodOrUninitializedMemberOrObjectProperty(json: Value): Boolean = { - val nodeInfo = createBabelNodeInfo(json).node - !isStaticInitBlock(json) && - (nodeInfo == ObjectProperty || nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember( - json - )) - } - - protected def astForClass(clazz: BabelNodeInfo, shouldCreateAssignmentCall: Boolean = false): Ast = { - val (typeName, typeFullName) = calcTypeNameAndFullName(clazz) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val 
superClass = Try(createBabelNodeInfo(clazz.json("superClass")).code).toOption.toSeq - val implements = Try(clazz.json("implements").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten - val mixins = Try(clazz.json("mixins").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten - - val typeDeclNode_ = typeDeclNode( - clazz, - typeName, - typeFullName, - parserResult.filename, - s"class $typeName", - astParentType, - astParentFullName, - inherits = superClass ++ implements ++ mixins - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, clazz.json) - astsForDecorators(clazz).foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) - } - - diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) - - val typeRefNode_ = typeRefNode(clazz, s"class $typeName", typeFullName) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - typeRefIdStack.push(typeRefNode_) - - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val allClassMembers = classMembers(clazz, withConstructor = false).toList - - // adding all other members and retrieving their initialization calls - val memberInitCalls = allClassMembers - .filter(m => !isStaticMember(m) && isInitializedMember(m)) - .map(m => astForClassMember(m, typeDeclNode_)) - - val constructor = createClassConstructor(clazz, memberInitCalls) - val constructorNode = constructor.methodNode - - // adding all class methods / functions and uninitialized, non-static members - allClassMembers - .filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member)) - .map(m => astForClassMember(m, typeDeclNode_)) - - // adding all static members and retrieving their initialization calls - val staticMemberInitCalls = - allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_)) - - // retrieving 
initialization calls from the static initialization block if any - val staticInitBlock = allClassMembers.find(isStaticInitBlock) - val staticInitBlockAsts = - staticInitBlock.map(block => block("body").arr.toList.map(astForNodeWithFunctionReference)).getOrElse(List.empty) - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - typeRefIdStack.pop() - scope.popScope() - - if (staticMemberInitCalls.nonEmpty || staticInitBlockAsts.nonEmpty) { - val init = staticInitMethodAst( - staticMemberInitCalls ++ staticInitBlockAsts, - s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.Any - ) - Ast.storeInDiffGraph(init, diffGraph) - diffGraph.addEdge(typeDeclNode_, init.nodes.head, EdgeTypes.AST) - } - - if (shouldCreateAssignmentCall) { - diffGraph.addEdge(localAstParentStack.head, typeRefNode_, EdgeTypes.AST) - - // return a synthetic assignment to enable tracing of the implicitly created identifier for - // the class definition assigned to its constructor - val classIdNode = identifierNode(clazz, typeName, Seq(constructorNode.fullName)) - val constructorRefNode = - methodRefNode(clazz, constructorNode.code, constructorNode.fullName, constructorNode.fullName) - - val idLocal = newLocalNode(typeName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) - scope.addVariable(typeName, idLocal, BlockScope) - scope.addVariableReference(typeName, classIdNode) - - createAssignmentCallAst( - classIdNode, - constructorRefNode, - s"$typeName = ${constructorNode.fullName}", - clazz.lineNumber, - clazz.columnNumber - ) - } else { - Ast(typeRefNode_) - } - } - - protected def addModifier(node: NewNode, json: Value): Unit = createBabelNodeInfo(json).node match { - case ClassPrivateProperty => - diffGraph.addEdge(node, NewModifier().modifierType(ModifierTypes.PRIVATE), EdgeTypes.AST) - case _ => - if (safeBool(json, "abstract").contains(true)) - diffGraph.addEdge(node, 
NewModifier().modifierType(ModifierTypes.ABSTRACT), EdgeTypes.AST) - if (safeBool(json, "static").contains(true)) - diffGraph.addEdge(node, NewModifier().modifierType(ModifierTypes.STATIC), EdgeTypes.AST) - if (safeStr(json, "accessibility").contains("public")) - diffGraph.addEdge(node, NewModifier().modifierType(ModifierTypes.PUBLIC), EdgeTypes.AST) - if (safeStr(json, "accessibility").contains("private")) - diffGraph.addEdge(node, NewModifier().modifierType(ModifierTypes.PRIVATE), EdgeTypes.AST) - if (safeStr(json, "accessibility").contains("protected")) - diffGraph.addEdge(node, NewModifier().modifierType(ModifierTypes.PROTECTED), EdgeTypes.AST) - } - - protected def astForModule(tsModuleDecl: BabelNodeInfo): Ast = { - val (name, fullName) = calcTypeNameAndFullName(tsModuleDecl) - val namespaceNode = NewNamespaceBlock() - .code(tsModuleDecl.code) - .lineNumber(tsModuleDecl.lineNumber) - .columnNumber(tsModuleDecl.columnNumber) - .filename(parserResult.filename) - .name(name) - .fullName(fullName) - - methodAstParentStack.push(namespaceNode) - dynamicInstanceTypeStack.push(fullName) - - scope.pushNewMethodScope(fullName, name, namespaceNode, None) - - val blockAst = if (hasKey(tsModuleDecl.json, "body")) { - val nodeInfo = createBabelNodeInfo(tsModuleDecl.json("body")) - nodeInfo.node match { - case TSModuleDeclaration => astForModule(nodeInfo) - case _ => astForBlockStatement(nodeInfo) - } - } else { - Ast() - } - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - scope.popScope() - - Ast(namespaceNode).withChild(blockAst) - } - - protected def astForInterface(tsInterface: BabelNodeInfo): Ast = { - val (typeName, typeFullName) = calcTypeNameAndFullName(tsInterface) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val extendz = 
Try(tsInterface.json("extends").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten - - val typeDeclNode_ = typeDeclNode( - tsInterface, - typeName, - typeFullName, - parserResult.filename, - s"interface $typeName", - astParentType, - astParentFullName, - inherits = extendz - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, tsInterface.json) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val constructorNode = interfaceConstructor(typeName, tsInterface) - diffGraph.addEdge(constructorNode, NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), EdgeTypes.AST) - - val constructorBindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, constructorBindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(constructorBindingNode, constructorNode, EdgeTypes.REF) - - val interfaceBodyElements = classMembers(tsInterface, withConstructor = false) - - interfaceBodyElements.foreach { classElement => - val nodeInfo = createBabelNodeInfo(classElement) - val typeFullName = typeFor(nodeInfo) - val memberNodes = nodeInfo.node match { - case TSCallSignatureDeclaration | TSMethodSignature => - val functionNode = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) - addModifier(functionNode, nodeInfo.json) - Seq(memberNode(nodeInfo, functionNode.name, nodeInfo.code, typeFullName, Seq(functionNode.fullName))) - case _ => - val names = nodeInfo.node match { - case TSPropertySignature | TSMethodSignature => - if (hasKey(nodeInfo.json("key"), "value")) { - Seq(safeStr(nodeInfo.json("key"), "value").getOrElse(code(nodeInfo.json("key")("value")))) - } else Seq(code(nodeInfo.json("key"))) - case TSIndexSignature => - nodeInfo.json("parameters").arr.toSeq.map(_("name").str) - // 
TODO: name field most likely needs adjustment for other Babel AST types - case _ => Seq(nodeInfo.code) - } - names.map { n => - val node = memberNode(nodeInfo, n, nodeInfo.code, typeFullName) - addModifier(node, nodeInfo.json) - node - } - } - memberNodes.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) - } - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - scope.popScope() - - Ast(typeDeclNode_) - } - -} + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, tsInterface.json) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val constructorNode = interfaceConstructor(typeName, tsInterface) + diffGraph.addEdge( + constructorNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + + val constructorBindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode_, constructorBindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(constructorBindingNode, constructorNode, EdgeTypes.REF) + + val interfaceBodyElements = classMembers(tsInterface, withConstructor = false) + + interfaceBodyElements.foreach { classElement => + val nodeInfo = createBabelNodeInfo(classElement) + val typeFullName = typeFor(nodeInfo) + val memberNodes = nodeInfo.node match + case TSCallSignatureDeclaration | TSMethodSignature => + val functionNode = createMethodDefinitionNode(nodeInfo) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) + addModifier(functionNode, nodeInfo.json) + Seq(memberNode( + nodeInfo, + functionNode.name, + nodeInfo.code, + typeFullName, + Seq(functionNode.fullName) + )) + case _ => + val names = nodeInfo.node match + case TSPropertySignature | TSMethodSignature => + if hasKey(nodeInfo.json("key"), "value") then + Seq(safeStr(nodeInfo.json("key"), "value").getOrElse( 
+ code(nodeInfo.json("key")("value")) + )) + else Seq(code(nodeInfo.json("key"))) + case TSIndexSignature => + nodeInfo.json("parameters").arr.toSeq.map(_("name").str) + // TODO: name field most likely needs adjustment for other Babel AST types + case _ => Seq(nodeInfo.code) + names.map { n => + val node = memberNode(nodeInfo, n, nodeInfo.code, typeFullName) + addModifier(node, nodeInfo.json) + node + } + memberNodes.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) + } + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + scope.popScope() + + Ast(typeDeclNode_) + end astForInterface +end AstForTypesCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala index b59ef241..050584a3 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala @@ -9,283 +9,290 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.Operators -trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def createMethodReturnNode(func: BabelNodeInfo): NewMethodReturn = { - newMethodReturnNode(typeFor(func), line = func.lineNumber, column = func.columnNumber) - } +trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode): + this: AstCreator => + protected def createMethodReturnNode(func: BabelNodeInfo): NewMethodReturn = + newMethodReturnNode(typeFor(func), line = func.lineNumber, column = func.columnNumber) - protected def setOrderExplicitly(ast: Ast, order: Int): Unit = { - ast.root.foreach { case expr: ExpressionNew => expr.order = order } - } + protected def setOrderExplicitly(ast: 
Ast, order: Int): Unit = + ast.root.foreach { case expr: ExpressionNew => expr.order = order } - protected def createJumpTarget(switchCase: BabelNodeInfo): NewJumpTarget = { - val (switchName, switchCode) = if (switchCase.json("test").isNull) { - ("default", "default:") - } else { - ("case", s"case ${code(switchCase.json("test"))}:") - } - NewJumpTarget() - .parserTypeName(switchCase.node.toString) - .name(switchName) - .code(switchCode) - .lineNumber(switchCase.lineNumber) - .columnNumber(switchCase.columnNumber) - } + protected def createJumpTarget(switchCase: BabelNodeInfo): NewJumpTarget = + val (switchName, switchCode) = if switchCase.json("test").isNull then + ("default", "default:") + else + ("case", s"case ${code(switchCase.json("test"))}:") + NewJumpTarget() + .parserTypeName(switchCase.node.toString) + .name(switchName) + .code(switchCode) + .lineNumber(switchCase.lineNumber) + .columnNumber(switchCase.columnNumber) - protected def createControlStructureNode(node: BabelNodeInfo, controlStructureType: String): NewControlStructure = { - val line = node.lineNumber - val column = node.columnNumber - val code = node.code - NewControlStructure() - .controlStructureType(controlStructureType) - .code(code) - .lineNumber(line) - .columnNumber(column) - } + protected def createControlStructureNode( + node: BabelNodeInfo, + controlStructureType: String + ): NewControlStructure = + val line = node.lineNumber + val column = node.columnNumber + val code = node.code + NewControlStructure() + .controlStructureType(controlStructureType) + .code(code) + .lineNumber(line) + .columnNumber(column) - protected def codeOf(node: NewNode): String = node match { - case astNodeNew: AstNodeNew => astNodeNew.code - case _ => "" - } + protected def codeOf(node: NewNode): String = node match + case astNodeNew: AstNodeNew => astNodeNew.code + case _ => "" - protected def createIndexAccessCallAst( - baseNode: NewNode, - partNode: NewNode, - line: Option[Integer], - column: 
Option[Integer] - ): Ast = { - val callNode = createCallNode( - s"${codeOf(baseNode)}[${codeOf(partNode)}]", - Operators.indexAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(Ast(baseNode), Ast(partNode)) - callAst(callNode, arguments) - } + protected def createIndexAccessCallAst( + baseNode: NewNode, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseNode)}[${codeOf(partNode)}]", + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(Ast(baseNode), Ast(partNode)) + callAst(callNode, arguments) - protected def createIndexAccessCallAst( - baseAst: Ast, - partAst: Ast, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val callNode = createCallNode( - s"${codeOf(baseAst.nodes.head)}[${codeOf(partAst.nodes.head)}]", - Operators.indexAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(baseAst, partAst) - callAst(callNode, arguments) - } + protected def createIndexAccessCallAst( + baseAst: Ast, + partAst: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseAst.nodes.head)}[${codeOf(partAst.nodes.head)}]", + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(baseAst, partAst) + callAst(callNode, arguments) - protected def createFieldAccessCallAst( - baseNode: NewNode, - partNode: NewNode, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val callNode = createCallNode( - s"${codeOf(baseNode)}.${codeOf(partNode)}", - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(Ast(baseNode), Ast(partNode)) - callAst(callNode, arguments) - } + protected def createFieldAccessCallAst( + baseNode: NewNode, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode 
= createCallNode( + s"${codeOf(baseNode)}.${codeOf(partNode)}", + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(Ast(baseNode), Ast(partNode)) + callAst(callNode, arguments) - protected def createFieldAccessCallAst( - baseAst: Ast, - partNode: NewNode, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val callNode = createCallNode( - s"${codeOf(baseAst.nodes.head)}.${codeOf(partNode)}", - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(baseAst, Ast(partNode)) - callAst(callNode, arguments) - } + protected def createFieldAccessCallAst( + baseAst: Ast, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseAst.nodes.head)}.${codeOf(partNode)}", + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(baseAst, Ast(partNode)) + callAst(callNode, arguments) - protected def createTernaryCallAst( - testAst: Ast, - trueAst: Ast, - falseAst: Ast, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val code = s"${codeOf(testAst.nodes.head)} ? ${codeOf(trueAst.nodes.head)} : ${codeOf(falseAst.nodes.head)}" - val callNode = createCallNode(code, Operators.conditional, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(testAst, trueAst, falseAst) - callAst(callNode, arguments) - } + protected def createTernaryCallAst( + testAst: Ast, + trueAst: Ast, + falseAst: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val code = + s"${codeOf(testAst.nodes.head)} ? 
${codeOf(trueAst.nodes.head)} : ${codeOf(falseAst.nodes.head)}" + val callNode = + createCallNode(code, Operators.conditional, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(testAst, trueAst, falseAst) + callAst(callNode, arguments) - def callNode(node: BabelNodeInfo, code: String, name: String, dispatchType: String): NewCall = { - val fullName = - if (dispatchType == DispatchTypes.STATIC_DISPATCH) name - else x2cpg.Defines.DynamicCallUnknownFullName - callNode(node, code, name, fullName, dispatchType, None, Some(Defines.Any)) - } + def callNode(node: BabelNodeInfo, code: String, name: String, dispatchType: String): NewCall = + val fullName = + if dispatchType == DispatchTypes.STATIC_DISPATCH then name + else x2cpg.Defines.DynamicCallUnknownFullName + callNode(node, code, name, fullName, dispatchType, None, Some(Defines.Any)) - private def createCallNode( - code: String, - callName: String, - dispatchType: String, - line: Option[Integer], - column: Option[Integer] - ): NewCall = NewCall() - .code(code) - .name(callName) - .methodFullName( - if (dispatchType == DispatchTypes.STATIC_DISPATCH) callName else x2cpg.Defines.DynamicCallUnknownFullName - ) - .dispatchType(dispatchType) - .lineNumber(line) - .columnNumber(column) - .typeFullName(Defines.Any) + private def createCallNode( + code: String, + callName: String, + dispatchType: String, + line: Option[Integer], + column: Option[Integer] + ): NewCall = NewCall() + .code(code) + .name(callName) + .methodFullName( + if dispatchType == DispatchTypes.STATIC_DISPATCH then callName + else x2cpg.Defines.DynamicCallUnknownFullName + ) + .dispatchType(dispatchType) + .lineNumber(line) + .columnNumber(column) + .typeFullName(Defines.Any) - protected def createVoidCallNode(line: Option[Integer], column: Option[Integer]): NewCall = - createCallNode("void 0", ".void", DispatchTypes.STATIC_DISPATCH, line, column) + protected def createVoidCallNode(line: Option[Integer], column: Option[Integer]): NewCall = + 
createCallNode("void 0", ".void", DispatchTypes.STATIC_DISPATCH, line, column) - protected def createFieldIdentifierNode( - name: String, - line: Option[Integer], - column: Option[Integer] - ): NewFieldIdentifier = { - val cleanedName = stripQuotes(name) - NewFieldIdentifier() - .code(cleanedName) - .canonicalName(cleanedName) - .lineNumber(line) - .columnNumber(column) - } + protected def createFieldIdentifierNode( + name: String, + line: Option[Integer], + column: Option[Integer] + ): NewFieldIdentifier = + val cleanedName = stripQuotes(name) + NewFieldIdentifier() + .code(cleanedName) + .canonicalName(cleanedName) + .lineNumber(line) + .columnNumber(column) - protected def literalNode(node: BabelNodeInfo, code: String, dynamicTypeOption: Option[String]): NewLiteral = { - val typeFullName = dynamicTypeOption match { - case Some(value) if Defines.JsTypes.contains(value) => value - case _ => Defines.Any - } - literalNode(node, code, typeFullName, dynamicTypeOption.toList) - } + protected def literalNode( + node: BabelNodeInfo, + code: String, + dynamicTypeOption: Option[String] + ): NewLiteral = + val typeFullName = dynamicTypeOption match + case Some(value) if Defines.JsTypes.contains(value) => value + case _ => Defines.Any + literalNode(node, code, typeFullName, dynamicTypeOption.toList) - protected def createEqualsCallAst(dest: Ast, source: Ast, line: Option[Integer], column: Option[Integer]): Ast = { - val code = s"${codeOf(dest.nodes.head)} === ${codeOf(source.nodes.head)}" - val callNode = createCallNode(code, Operators.equals, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(dest, source) - callAst(callNode, arguments) - } + protected def createEqualsCallAst( + dest: Ast, + source: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val code = s"${codeOf(dest.nodes.head)} === ${codeOf(source.nodes.head)}" + val callNode = + createCallNode(code, Operators.equals, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments 
= List(dest, source) + callAst(callNode, arguments) - protected def createAssignmentCallAst( - destId: NewNode, - sourceId: NewNode, - code: String, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val callNode = createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(Ast(destId), Ast(sourceId)) - callAst(callNode, arguments) - } + protected def createAssignmentCallAst( + destId: NewNode, + sourceId: NewNode, + code: String, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = + createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(Ast(destId), Ast(sourceId)) + callAst(callNode, arguments) - protected def createAssignmentCallAst( - dest: Ast, - source: Ast, - code: String, - line: Option[Integer], - column: Option[Integer] - ): Ast = { - val callNode = createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(dest, source) - callAst(callNode, arguments) - } + protected def createAssignmentCallAst( + dest: Ast, + source: Ast, + code: String, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = + createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(dest, source) + callAst(callNode, arguments) - protected def identifierNode(node: BabelNodeInfo, name: String): NewIdentifier = { - val dynamicInstanceTypeOption = name match { - case "this" => typeHintForThisExpression(Option(node)).headOption - case "console" => Option(Defines.Console) - case "Math" => Option(Defines.Math) - case _ => None - } - identifierNode(node, name, name, Defines.Any, dynamicInstanceTypeOption.toList) - } + protected def identifierNode(node: BabelNodeInfo, name: String): NewIdentifier = + val dynamicInstanceTypeOption = name match + case "this" => typeHintForThisExpression(Option(node)).headOption + case 
"console" => Option(Defines.Console) + case "Math" => Option(Defines.Math) + case _ => None + identifierNode(node, name, name, Defines.Any, dynamicInstanceTypeOption.toList) - protected def identifierNode(node: BabelNodeInfo, name: String, dynamicTypeHints: Seq[String]): NewIdentifier = { - identifierNode(node, name, name, Defines.Any, dynamicTypeHints) - } + protected def identifierNode( + node: BabelNodeInfo, + name: String, + dynamicTypeHints: Seq[String] + ): NewIdentifier = + identifierNode(node, name, name, Defines.Any, dynamicTypeHints) - protected def createStaticCallNode( - code: String, - callName: String, - fullName: String, - line: Option[Integer], - column: Option[Integer] - ): NewCall = NewCall() - .code(code) - .name(callName) - .methodFullName(fullName) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .signature("") - .lineNumber(line) - .columnNumber(column) - .typeFullName(Defines.Any) + protected def createStaticCallNode( + code: String, + callName: String, + fullName: String, + line: Option[Integer], + column: Option[Integer] + ): NewCall = NewCall() + .code(code) + .name(callName) + .methodFullName(fullName) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .signature("") + .lineNumber(line) + .columnNumber(column) + .typeFullName(Defines.Any) - protected def createTemplateDomNode( - name: String, - code: String, - line: Option[Integer], - column: Option[Integer] - ): NewTemplateDom = - NewTemplateDom() - .name(name) - .code(code) - .lineNumber(line) - .columnNumber(column) + protected def createTemplateDomNode( + name: String, + code: String, + line: Option[Integer], + column: Option[Integer] + ): NewTemplateDom = + NewTemplateDom() + .name(name) + .code(code) + .lineNumber(line) + .columnNumber(column) - protected def createBlockNode(node: BabelNodeInfo, customCode: Option[String] = None): NewBlock = - NewBlock() - .typeFullName(Defines.Any) - .code(customCode.getOrElse(node.code)) - .lineNumber(node.lineNumber) - 
.columnNumber(node.columnNumber) + protected def createBlockNode( + node: BabelNodeInfo, + customCode: Option[String] = None + ): NewBlock = + NewBlock() + .typeFullName(Defines.Any) + .code(customCode.getOrElse(node.code)) + .lineNumber(node.lineNumber) + .columnNumber(node.columnNumber) - protected def createFunctionTypeAndTypeDeclAst( - node: BabelNodeInfo, - methodNode: NewMethod, - parentNode: NewNode, - methodName: String, - methodFullName: String, - filename: String - ): Ast = { - registerType(methodName, methodFullName) + protected def createFunctionTypeAndTypeDeclAst( + node: BabelNodeInfo, + methodNode: NewMethod, + parentNode: NewNode, + methodName: String, + methodFullName: String, + filename: String + ): Ast = + registerType(methodName, methodFullName) - val astParentType = parentNode.label - val astParentFullName = parentNode.properties("FULL_NAME").toString - val functionTypeDeclNode = - typeDeclNode( - node, - methodName, - methodFullName, - filename, - methodName, - astParentType = astParentType, - astParentFullName = astParentFullName, - List(Defines.Any) - ) + val astParentType = parentNode.label + val astParentFullName = parentNode.properties("FULL_NAME").toString + val functionTypeDeclNode = + typeDeclNode( + node, + methodName, + methodFullName, + filename, + methodName, + astParentType = astParentType, + astParentFullName = astParentFullName, + List(Defines.Any) + ) - // Problem for https://github.com/ShiftLeftSecurity/codescience/issues/3626 here. - // As the type (thus, the signature) of the function node is unknown (i.e., ANY*) - // we can't generate the correct binding with signature. - val bindingNode = NewBinding().name("").signature("") - Ast(functionTypeDeclNode).withBindsEdge(functionTypeDeclNode, bindingNode).withRefEdge(bindingNode, methodNode) - } - -} + // Problem for https://github.com/ShiftLeftSecurity/codescience/issues/3626 here. 
+ // As the type (thus, the signature) of the function node is unknown (i.e., ANY*) + // we can't generate the correct binding with signature. + val bindingNode = NewBinding().name("").signature("") + Ast(functionTypeDeclNode).withBindsEdge(functionTypeDeclNode, bindingNode).withRefEdge( + bindingNode, + methodNode + ) + end createFunctionTypeAndTypeDeclAst +end AstNodeBuilder diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala index b62c977e..7f84e1f3 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala @@ -6,137 +6,138 @@ import io.appthreat.jssrc2cpg.parser.BabelAst.* import java.util.regex.Pattern -trait TypeHelper { this: AstCreator => - - private val TypeAnnotationKey = "typeAnnotation" - private val ReturnTypeKey = "returnType" - private val ImportMatcher = Pattern.compile("(typeof )?import\\([\"'](.*)[\"']\\)") - - private val ArrayReplacements = Map( - "any[]" -> s"${Defines.Any}[]", - "unknown[]" -> s"${Defines.Unknown}[]", - "number[]" -> s"${Defines.Number}[]", - "string[]" -> s"${Defines.String}[]", - "boolean[]" -> s"${Defines.Boolean}[]" - ) - - private val TypeReplacements = Map( - " any" -> s" ${Defines.Any}", - " unknown" -> s" ${Defines.Unknown}", - " number" -> s" ${Defines.Number}", - " null" -> s" ${Defines.Null}", - " string" -> s" ${Defines.String}", - " boolean" -> s" ${Defines.Boolean}", - " bigint" -> s" ${Defines.BigInt}", - "{}" -> Defines.Object, - "typeof " -> "" - ) - - protected def isPlainTypeAlias(alias: BabelNodeInfo): Boolean = if (hasKey(alias.json, "right")) { - createBabelNodeInfo(alias.json("right")).node.toString == TSTypeReference.toString - } else { - createBabelNodeInfo(alias.json("typeAnnotation")).node.toString 
== TSTypeReference.toString - } - - private def typeForFlowType(flowType: BabelNodeInfo): String = flowType.node match { - case BooleanTypeAnnotation => Defines.Boolean - case NumberTypeAnnotation => Defines.Number - case ObjectTypeAnnotation => Defines.Object - case StringTypeAnnotation => Defines.String - case SymbolTypeAnnotation => Defines.Symbol - case NumberLiteralTypeAnnotation => code(flowType.json) - case ArrayTypeAnnotation => code(flowType.json) - case BooleanLiteralTypeAnnotation => code(flowType.json) - case NullLiteralTypeAnnotation => code(flowType.json) - case StringLiteralTypeAnnotation => code(flowType.json) - case GenericTypeAnnotation => code(flowType.json("id")) - case ThisTypeAnnotation => typeHintForThisExpression(Option(flowType)).headOption.getOrElse(Defines.Any); - case NullableTypeAnnotation => typeForTypeAnnotation(createBabelNodeInfo(flowType.json(TypeAnnotationKey))) - case _ => Defines.Any - } - - private def typeForTsType(tsType: BabelNodeInfo): String = tsType.node match { - case TSBooleanKeyword => Defines.Boolean - case TSBigIntKeyword => Defines.Number - case TSNullKeyword => Defines.Null - case TSNumberKeyword => Defines.Number - case TSObjectKeyword => Defines.Object - case TSStringKeyword => Defines.String - case TSSymbolKeyword => Defines.Symbol - case TSUnknownKeyword => Defines.Unknown - case TSVoidKeyword => Defines.Void - case TSUndefinedKeyword => Defines.Undefined - case TSNeverKeyword => Defines.Never - case TSIntrinsicKeyword => code(tsType.json) - case TSTypeReference => code(tsType.json) - case TSArrayType => code(tsType.json) - case TSThisType => typeHintForThisExpression(Option(tsType)).headOption.getOrElse(Defines.Any) - case TSOptionalType => typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case TSRestType => typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case TSParenthesizedType => 
typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case _ => Defines.Any - } - - private def typeForTypeAnnotation(typeAnnotation: BabelNodeInfo): String = typeAnnotation.node match { - case TypeAnnotation => typeForFlowType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) - case TSTypeAnnotation => typeForTsType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) - case _: FlowType => typeForFlowType(createBabelNodeInfo(typeAnnotation.json)) - case _: TSType => typeForTsType(createBabelNodeInfo(typeAnnotation.json)) - case _ => Defines.Any - } - - private def isStringType(tpe: String): Boolean = - tpe.startsWith("\"") && tpe.endsWith("\"") - - private def isNumberType(tpe: String): Boolean = - tpe.toDoubleOption.isDefined - - private def typeFromTypeMap(node: BabelNodeInfo): String = - pos(node.json).flatMap(parserResult.typeMap.get) match { - case Some(value) if value.isEmpty => Defines.String - case Some(value) if value == "string" => Defines.String - case Some(value) if isStringType(value) => Defines.String - case Some(value) if value == "number" => Defines.Number - case Some(value) if isNumberType(value) => Defines.Number - case Some(value) if value == "null" => Defines.Null - case Some(value) if value == "boolean" => Defines.Boolean - case Some(value) if value == "any" => Defines.Any - case Some(value) if ImportMatcher.matcher(value).matches() => importToModule(value) - case Some(other) => - (TypeReplacements ++ ArrayReplacements).foldLeft(other) { case (typeStr, (m, r)) => - typeStr.replace(m, r) - } - case None => Defines.Any - } - - private def importToModule(value: String): String = { - val matcher = ImportMatcher.matcher(value) - this.rootTypeDecl.headOption match { - case Some(typeDecl) => typeDecl.fullName - case None if matcher.matches() => matcher.group(2).stripSuffix(".js").concat(".js::program") - case None => value - } - } - - protected def typeFor(node: BabelNodeInfo): String = { - val tpe = 
Seq(TypeAnnotationKey, ReturnTypeKey).find(hasKey(node.json, _)) match { - case Some(key) => typeForTypeAnnotation(createBabelNodeInfo(node.json(key))) - case None => typeFromTypeMap(node) - } - registerType(tpe, tpe) - tpe - } - - protected def typeHintForThisExpression(node: Option[BabelNodeInfo] = None): Seq[String] = { - dynamicInstanceTypeStack.headOption match { - case Some(tpe) => Seq(tpe) - case None if node.isDefined => - typeFor(node.get) match { - case t if t != Defines.Any && t != "this" => Seq(t) - case _ => rootTypeDecl.map(_.fullName).toSeq - } - case None => rootTypeDecl.map(_.fullName).toSeq - } - } - -} +trait TypeHelper: + this: AstCreator => + + private val TypeAnnotationKey = "typeAnnotation" + private val ReturnTypeKey = "returnType" + private val ImportMatcher = Pattern.compile("(typeof )?import\\([\"'](.*)[\"']\\)") + + private val ArrayReplacements = Map( + "any[]" -> s"${Defines.Any}[]", + "unknown[]" -> s"${Defines.Unknown}[]", + "number[]" -> s"${Defines.Number}[]", + "string[]" -> s"${Defines.String}[]", + "boolean[]" -> s"${Defines.Boolean}[]" + ) + + private val TypeReplacements = Map( + " any" -> s" ${Defines.Any}", + " unknown" -> s" ${Defines.Unknown}", + " number" -> s" ${Defines.Number}", + " null" -> s" ${Defines.Null}", + " string" -> s" ${Defines.String}", + " boolean" -> s" ${Defines.Boolean}", + " bigint" -> s" ${Defines.BigInt}", + "{}" -> Defines.Object, + "typeof " -> "" + ) + + protected def isPlainTypeAlias(alias: BabelNodeInfo): Boolean = + if hasKey(alias.json, "right") then + createBabelNodeInfo(alias.json("right")).node.toString == TSTypeReference.toString + else + createBabelNodeInfo( + alias.json("typeAnnotation") + ).node.toString == TSTypeReference.toString + + private def typeForFlowType(flowType: BabelNodeInfo): String = flowType.node match + case BooleanTypeAnnotation => Defines.Boolean + case NumberTypeAnnotation => Defines.Number + case ObjectTypeAnnotation => Defines.Object + case StringTypeAnnotation => 
Defines.String + case SymbolTypeAnnotation => Defines.Symbol + case NumberLiteralTypeAnnotation => code(flowType.json) + case ArrayTypeAnnotation => code(flowType.json) + case BooleanLiteralTypeAnnotation => code(flowType.json) + case NullLiteralTypeAnnotation => code(flowType.json) + case StringLiteralTypeAnnotation => code(flowType.json) + case GenericTypeAnnotation => code(flowType.json("id")) + case ThisTypeAnnotation => + typeHintForThisExpression(Option(flowType)).headOption.getOrElse(Defines.Any); + case NullableTypeAnnotation => + typeForTypeAnnotation(createBabelNodeInfo(flowType.json(TypeAnnotationKey))) + case _ => Defines.Any + + private def typeForTsType(tsType: BabelNodeInfo): String = tsType.node match + case TSBooleanKeyword => Defines.Boolean + case TSBigIntKeyword => Defines.Number + case TSNullKeyword => Defines.Null + case TSNumberKeyword => Defines.Number + case TSObjectKeyword => Defines.Object + case TSStringKeyword => Defines.String + case TSSymbolKeyword => Defines.Symbol + case TSUnknownKeyword => Defines.Unknown + case TSVoidKeyword => Defines.Void + case TSUndefinedKeyword => Defines.Undefined + case TSNeverKeyword => Defines.Never + case TSIntrinsicKeyword => code(tsType.json) + case TSTypeReference => code(tsType.json) + case TSArrayType => code(tsType.json) + case TSThisType => + typeHintForThisExpression(Option(tsType)).headOption.getOrElse(Defines.Any) + case TSOptionalType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case TSRestType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case TSParenthesizedType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case _ => Defines.Any + + private def typeForTypeAnnotation(typeAnnotation: BabelNodeInfo): String = + typeAnnotation.node match + case TypeAnnotation => + typeForFlowType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) + case TSTypeAnnotation => + 
typeForTsType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) + case _: FlowType => typeForFlowType(createBabelNodeInfo(typeAnnotation.json)) + case _: TSType => typeForTsType(createBabelNodeInfo(typeAnnotation.json)) + case _ => Defines.Any + + private def isStringType(tpe: String): Boolean = + tpe.startsWith("\"") && tpe.endsWith("\"") + + private def isNumberType(tpe: String): Boolean = + tpe.toDoubleOption.isDefined + + private def typeFromTypeMap(node: BabelNodeInfo): String = + pos(node.json).flatMap(parserResult.typeMap.get) match + case Some(value) if value.isEmpty => Defines.String + case Some(value) if value == "string" => Defines.String + case Some(value) if isStringType(value) => Defines.String + case Some(value) if value == "number" => Defines.Number + case Some(value) if isNumberType(value) => Defines.Number + case Some(value) if value == "null" => Defines.Null + case Some(value) if value == "boolean" => Defines.Boolean + case Some(value) if value == "any" => Defines.Any + case Some(value) if ImportMatcher.matcher(value).matches() => importToModule(value) + case Some(other) => + (TypeReplacements ++ ArrayReplacements).foldLeft(other) { case (typeStr, (m, r)) => + typeStr.replace(m, r) + } + case None => Defines.Any + + private def importToModule(value: String): String = + val matcher = ImportMatcher.matcher(value) + this.rootTypeDecl.headOption match + case Some(typeDecl) => typeDecl.fullName + case None if matcher.matches() => + matcher.group(2).stripSuffix(".js").concat(".js::program") + case None => value + + protected def typeFor(node: BabelNodeInfo): String = + val tpe = Seq(TypeAnnotationKey, ReturnTypeKey).find(hasKey(node.json, _)) match + case Some(key) => typeForTypeAnnotation(createBabelNodeInfo(node.json(key))) + case None => typeFromTypeMap(node) + registerType(tpe, tpe) + tpe + + protected def typeHintForThisExpression(node: Option[BabelNodeInfo] = None): Seq[String] = + dynamicInstanceTypeStack.headOption match + case 
Some(tpe) => Seq(tpe) + case None if node.isDefined => + typeFor(node.get) match + case t if t != Defines.Any && t != "this" => Seq(t) + case _ => rootTypeDecl.map(_.fullName).toSeq + case None => rootTypeDecl.map(_.fullName).toSeq +end TypeHelper diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala index 770d5265..5e85539d 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala @@ -2,20 +2,20 @@ package io.appthreat.jssrc2cpg.datastructures import io.shiftleft.codepropertygraph.generated.nodes.NewNode -case class PendingReference(variableName: String, referenceNode: NewNode, stack: Option[ScopeElement]) { +case class PendingReference( + variableName: String, + referenceNode: NewNode, + stack: Option[ScopeElement] +): - def tryResolve(): Option[ResolvedReference] = { - var foundVariableOption = Option.empty[NewNode] - val stackIterator = new ScopeElementIterator(stack) + def tryResolve(): Option[ResolvedReference] = + var foundVariableOption = Option.empty[NewNode] + val stackIterator = new ScopeElementIterator(stack) - while (stackIterator.hasNext && foundVariableOption.isEmpty) { - val scopeElement = stackIterator.next() - foundVariableOption = scopeElement.nameToVariableNode.get(variableName) - } + while stackIterator.hasNext && foundVariableOption.isEmpty do + val scopeElement = stackIterator.next() + foundVariableOption = scopeElement.nameToVariableNode.get(variableName) - foundVariableOption.map { variableNodeId => - ResolvedReference(variableNodeId, this) - } - } - -} + foundVariableOption.map { variableNodeId => + ResolvedReference(variableNodeId, this) + } diff --git 
a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala index a92e4f94..8b053649 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala @@ -6,103 +6,103 @@ import scala.collection.mutable /** Handles the scope stack for tracking identifier to variable relation. */ -class Scope { - - private val pendingReferences: mutable.Buffer[PendingReference] = - mutable.ListBuffer.empty[PendingReference] - - private var stack = Option.empty[ScopeElement] - - def getScopeHead: Option[ScopeElement] = stack - - def isEmpty: Boolean = stack.isEmpty - - def pushNewMethodScope( - methodFullName: String, - name: String, - scopeNode: NewNode, - capturingRefId: Option[NewNode] - ): Unit = - stack = Option(new MethodScopeElement(methodFullName, capturingRefId, name, scopeNode, surroundingScope = stack)) - - def pushNewBlockScope(scopeNode: NewNode): Unit = { - peek match { - case Some(stackTop) => - stack = Option(new BlockScopeElement(stackTop.subScopeCounter.toString, scopeNode, surroundingScope = stack)) - stackTop.subScopeCounter += 1 - case None => - stack = Option(new BlockScopeElement("0", scopeNode, surroundingScope = stack)) - } - } - - private def peek: Option[ScopeElement] = { - stack - } - - def popScope(): Unit = { - stack = stack.get.surroundingScope - } - - def addVariable(variableName: String, variableNode: NewNode, scopeType: ScopeType): Unit = { - addVariable(stack, variableName, variableNode, scopeType) - } - - def addVariableReference(variableName: String, referenceNode: NewNode): Unit = { - pendingReferences prepend PendingReference(variableName, referenceNode, stack) - } - - def resolve(unresolvedHandler: (NewNode, String) => (NewNode, ScopeType)): Iterator[ResolvedReference] = { - 
pendingReferences.iterator.map { pendingReference => - val resolvedReferenceOption = pendingReference.tryResolve() - - resolvedReferenceOption.getOrElse { - val methodScopeNode = Scope.getEnclosingMethodScopeNode(pendingReference.stack) - val (newVariableNode, scopeType) = - unresolvedHandler(methodScopeNode, pendingReference.variableName) - addVariable(pendingReference.stack, pendingReference.variableName, newVariableNode, scopeType) - pendingReference.tryResolve().get - } - } - - } - - private def addVariable( - stack: Option[ScopeElement], - variableName: String, - variableNode: NewNode, - scopeType: ScopeType - ): Unit = { - val scopeToAddTo = scopeType match { - case MethodScope => Scope.getEnclosingMethodScopeElement(stack) - case _ => stack.get - } - scopeToAddTo.addVariable(variableName, variableNode) - } - -} - -object Scope { - private def getEnclosingMethodScopeNode(scopeHead: Option[ScopeElement]): NewNode = - getEnclosingMethodScopeElement(scopeHead).scopeNode - - def getEnclosingMethodScopeElement(scopeHead: Option[ScopeElement]): MethodScopeElement = { - // There are no references outside of methods. Meaning we always find a MethodScope here. 
- new ScopeElementIterator(scopeHead) - .collectFirst { case methodScopeElement: MethodScopeElement => methodScopeElement } - .getOrElse(throw new RuntimeException("Cannot find method scope.")) - } -} - -class ScopeElementIterator(start: Option[ScopeElement]) extends Iterator[ScopeElement] { - private var currentScopeElement = start - - override def hasNext: Boolean = { - currentScopeElement.isDefined - } - - override def next(): ScopeElement = { - val result = currentScopeElement.get - currentScopeElement = result.surroundingScope - result - } -} +class Scope: + + private val pendingReferences: mutable.Buffer[PendingReference] = + mutable.ListBuffer.empty[PendingReference] + + private var stack = Option.empty[ScopeElement] + + def getScopeHead: Option[ScopeElement] = stack + + def isEmpty: Boolean = stack.isEmpty + + def pushNewMethodScope( + methodFullName: String, + name: String, + scopeNode: NewNode, + capturingRefId: Option[NewNode] + ): Unit = + stack = Option(new MethodScopeElement( + methodFullName, + capturingRefId, + name, + scopeNode, + surroundingScope = stack + )) + + def pushNewBlockScope(scopeNode: NewNode): Unit = + peek match + case Some(stackTop) => + stack = Option(new BlockScopeElement( + stackTop.subScopeCounter.toString, + scopeNode, + surroundingScope = stack + )) + stackTop.subScopeCounter += 1 + case None => + stack = Option(new BlockScopeElement("0", scopeNode, surroundingScope = stack)) + + private def peek: Option[ScopeElement] = + stack + + def popScope(): Unit = + stack = stack.get.surroundingScope + + def addVariable(variableName: String, variableNode: NewNode, scopeType: ScopeType): Unit = + addVariable(stack, variableName, variableNode, scopeType) + + def addVariableReference(variableName: String, referenceNode: NewNode): Unit = + pendingReferences prepend PendingReference(variableName, referenceNode, stack) + + def resolve(unresolvedHandler: (NewNode, String) => (NewNode, ScopeType)) + : Iterator[ResolvedReference] = + 
pendingReferences.iterator.map { pendingReference => + val resolvedReferenceOption = pendingReference.tryResolve() + + resolvedReferenceOption.getOrElse { + val methodScopeNode = Scope.getEnclosingMethodScopeNode(pendingReference.stack) + val (newVariableNode, scopeType) = + unresolvedHandler(methodScopeNode, pendingReference.variableName) + addVariable( + pendingReference.stack, + pendingReference.variableName, + newVariableNode, + scopeType + ) + pendingReference.tryResolve().get + } + } + + private def addVariable( + stack: Option[ScopeElement], + variableName: String, + variableNode: NewNode, + scopeType: ScopeType + ): Unit = + val scopeToAddTo = scopeType match + case MethodScope => Scope.getEnclosingMethodScopeElement(stack) + case _ => stack.get + scopeToAddTo.addVariable(variableName, variableNode) +end Scope + +object Scope: + private def getEnclosingMethodScopeNode(scopeHead: Option[ScopeElement]): NewNode = + getEnclosingMethodScopeElement(scopeHead).scopeNode + + def getEnclosingMethodScopeElement(scopeHead: Option[ScopeElement]): MethodScopeElement = + // There are no references outside of methods. Meaning we always find a MethodScope here. 
+ new ScopeElementIterator(scopeHead) + .collectFirst { case methodScopeElement: MethodScopeElement => methodScopeElement } + .getOrElse(throw new RuntimeException("Cannot find method scope.")) + +class ScopeElementIterator(start: Option[ScopeElement]) extends Iterator[ScopeElement]: + private var currentScopeElement = start + + override def hasNext: Boolean = + currentScopeElement.isDefined + + override def next(): ScopeElement = + val result = currentScopeElement.get + currentScopeElement = result.surroundingScope + result diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala index 79c9b0d0..e1542cb2 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala @@ -6,13 +6,16 @@ import scala.collection.mutable /** A single element of a scope stack. 
*/ -abstract class ScopeElement(val name: String, val scopeNode: NewNode, val surroundingScope: Option[ScopeElement]) { - var subScopeCounter: Int = 0 - val nameToVariableNode: mutable.Map[String, NewNode] = mutable.HashMap.empty +abstract class ScopeElement( + val name: String, + val scopeNode: NewNode, + val surroundingScope: Option[ScopeElement] +): + var subScopeCounter: Int = 0 + val nameToVariableNode: mutable.Map[String, NewNode] = mutable.HashMap.empty - def addVariable(variableName: String, variableNode: NewNode): Unit = - nameToVariableNode(variableName) = variableNode -} + def addVariable(variableName: String, variableNode: NewNode): Unit = + nameToVariableNode(variableName) = variableNode class MethodScopeElement( val methodFullName: String, diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala index a5815f29..c7303086 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala @@ -1,282 +1,279 @@ package io.appthreat.jssrc2cpg.parser -object BabelAst { +object BabelAst: - private val QualifiedClassName: String = BabelAst.getClass.getName + private val QualifiedClassName: String = BabelAst.getClass.getName - def fromString(nodeName: String): BabelNode = { - val clazz = Class.forName(s"$QualifiedClassName$nodeName$$") - clazz.getField("MODULE$").get(clazz).asInstanceOf[BabelNode] - } + def fromString(nodeName: String): BabelNode = + val clazz = Class.forName(s"$QualifiedClassName$nodeName$$") + clazz.getField("MODULE$").get(clazz).asInstanceOf[BabelNode] - // extracted from: - // https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts + // extracted from: + // 
https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts - sealed trait BabelNode { - override def toString: String = this.getClass.getSimpleName.stripSuffix("$") - } + sealed trait BabelNode: + override def toString: String = this.getClass.getSimpleName.stripSuffix("$") - sealed trait FlowType extends BabelNode + sealed trait FlowType extends BabelNode - sealed trait TSType extends BabelNode + sealed trait TSType extends BabelNode - sealed trait Expression extends BabelNode + sealed trait Expression extends BabelNode - sealed trait FunctionLike extends BabelNode + sealed trait FunctionLike extends BabelNode - object AnyTypeAnnotation extends FlowType - object ArgumentPlaceholder extends BabelNode - object ArrayExpression extends BabelNode - object ArrayPattern extends BabelNode - object ArrayTypeAnnotation extends FlowType - object ArrowFunctionExpression extends FunctionLike - object AssignmentExpression extends BabelNode - object AssignmentPattern extends BabelNode - object AwaitExpression extends BabelNode - object BigIntLiteral extends BabelNode - object BinaryExpression extends BabelNode - object BindExpression extends BabelNode - object BlockStatement extends BabelNode - object BooleanLiteral extends BabelNode - object BooleanLiteralTypeAnnotation extends FlowType - object BooleanTypeAnnotation extends FlowType - object BreakStatement extends BabelNode - object CallExpression extends Expression - object CatchClause extends BabelNode - object ClassAccessorProperty extends BabelNode - object ClassBody extends BabelNode - object ClassDeclaration extends BabelNode - object ClassExpression extends BabelNode - object ClassImplements extends BabelNode - object ClassMethod extends BabelNode - object ClassPrivateMethod extends BabelNode - object ClassPrivateProperty extends BabelNode - object ClassProperty extends BabelNode - object ConditionalExpression extends BabelNode - object ContinueStatement extends BabelNode - object 
DebuggerStatement extends BabelNode - object DecimalLiteral extends BabelNode - object DeclareClass extends BabelNode - object DeclareExportAllDeclaration extends BabelNode - object DeclareExportDeclaration extends BabelNode - object DeclareFunction extends BabelNode - object DeclareInterface extends BabelNode - object DeclareModule extends BabelNode - object DeclareModuleExports extends BabelNode - object DeclareOpaqueType extends BabelNode - object DeclareTypeAlias extends BabelNode - object DeclareVariable extends BabelNode - object DeclaredPredicate extends BabelNode - object Decorator extends BabelNode - object Directive extends BabelNode - object DirectiveLiteral extends BabelNode - object DoExpression extends BabelNode - object DoWhileStatement extends BabelNode - object EmptyStatement extends BabelNode - object EmptyTypeAnnotation extends FlowType - object EnumBooleanBody extends BabelNode - object EnumBooleanMember extends BabelNode - object EnumDeclaration extends BabelNode - object EnumDefaultedMember extends BabelNode - object EnumNumberBody extends BabelNode - object EnumNumberMember extends BabelNode - object EnumStringBody extends BabelNode - object EnumStringMember extends BabelNode - object EnumSymbolBody extends BabelNode - object ExistsTypeAnnotation extends FlowType - object ExportAllDeclaration extends BabelNode - object ExportDefaultDeclaration extends BabelNode - object ExportDefaultSpecifier extends BabelNode - object ExportNamedDeclaration extends BabelNode - object ExportNamespaceSpecifier extends BabelNode - object ExportSpecifier extends BabelNode - object ExpressionStatement extends BabelNode - object File extends BabelNode - object ForInStatement extends BabelNode - object ForOfStatement extends BabelNode - object ForStatement extends BabelNode - object FunctionDeclaration extends FunctionLike - object FunctionExpression extends FunctionLike - object FunctionTypeAnnotation extends FlowType - object FunctionTypeParam extends BabelNode - 
object GenericTypeAnnotation extends FlowType - object Identifier extends BabelNode - object IfStatement extends BabelNode - object Import extends BabelNode - object ImportAttribute extends BabelNode - object ImportDeclaration extends BabelNode - object ImportDefaultSpecifier extends BabelNode - object ImportNamespaceSpecifier extends BabelNode - object ImportSpecifier extends BabelNode - object IndexedAccessType extends FlowType - object InferredPredicate extends BabelNode - object InterfaceDeclaration extends BabelNode - object InterfaceExtends extends BabelNode - object InterfaceTypeAnnotation extends FlowType - object InterpreterDirective extends BabelNode - object IntersectionTypeAnnotation extends FlowType - object JSXAttribute extends BabelNode - object JSXClosingElement extends BabelNode - object JSXClosingFragment extends BabelNode - object JSXElement extends BabelNode - object JSXEmptyExpression extends BabelNode - object JSXExpressionContainer extends BabelNode - object JSXFragment extends BabelNode - object JSXIdentifier extends BabelNode - object JSXMemberExpression extends Expression - object JSXNamespacedName extends BabelNode - object JSXOpeningElement extends BabelNode - object JSXOpeningFragment extends BabelNode - object JSXSpreadAttribute extends BabelNode - object JSXSpreadChild extends BabelNode - object JSXText extends BabelNode - object LabeledStatement extends BabelNode - object LogicalExpression extends BabelNode - object MemberExpression extends Expression - object MetaProperty extends BabelNode - object MixedTypeAnnotation extends FlowType - object ModuleExpression extends BabelNode - object NewExpression extends Expression - object Noop extends BabelNode - object NullLiteral extends BabelNode - object NullLiteralTypeAnnotation extends FlowType - object NullableTypeAnnotation extends FlowType - object NumberLiteral extends BabelNode - object NumberLiteralTypeAnnotation extends FlowType - object NumberTypeAnnotation extends FlowType - 
object NumericLiteral extends BabelNode - object ObjectExpression extends BabelNode - object ObjectMethod extends BabelNode - object ObjectPattern extends BabelNode - object ObjectProperty extends BabelNode - object ObjectTypeAnnotation extends FlowType - object ObjectTypeCallProperty extends BabelNode - object ObjectTypeIndexer extends BabelNode - object ObjectTypeInternalSlot extends BabelNode - object ObjectTypeProperty extends BabelNode - object ObjectTypeSpreadProperty extends BabelNode - object OpaqueType extends BabelNode - object OptionalCallExpression extends Expression - object OptionalIndexedAccessType extends FlowType - object OptionalMemberExpression extends Expression - object ParenthesizedExpression extends BabelNode - object PipelineBareFunction extends BabelNode - object PipelinePrimaryTopicReference extends BabelNode - object PipelineTopicExpression extends Expression - object Placeholder extends BabelNode - object PrivateName extends BabelNode - object Program extends BabelNode - object QualifiedTypeIdentifier extends BabelNode - object RecordExpression extends Expression - object RegExpLiteral extends BabelNode - object RegexLiteral extends BabelNode - object RestElement extends BabelNode - object RestProperty extends BabelNode - object ReturnStatement extends BabelNode - object SequenceExpression extends BabelNode - object SpreadElement extends BabelNode - object SpreadProperty extends BabelNode - object StaticBlock extends BabelNode - object StringLiteral extends BabelNode - object StringLiteralTypeAnnotation extends FlowType - object StringTypeAnnotation extends FlowType - object Super extends BabelNode - object SwitchCase extends BabelNode - object SwitchStatement extends BabelNode - object SymbolTypeAnnotation extends FlowType - object TSAnyKeyword extends TSType - object TSArrayType extends TSType - object TSAsExpression extends Expression - object TSBigIntKeyword extends TSType - object TSBooleanKeyword extends TSType - object 
TSCallSignatureDeclaration extends BabelNode - object TSConditionalType extends TSType - object TSConstructSignatureDeclaration extends BabelNode - object TSConstructorType extends TSType - object TSDeclareFunction extends BabelNode - object TSDeclareMethod extends BabelNode - object TSEnumDeclaration extends BabelNode - object TSEnumMember extends BabelNode - object TSExportAssignment extends BabelNode - object TSExpressionWithTypeArguments extends TSType - object TSExternalModuleReference extends BabelNode - object TSFunctionType extends TSType - object TSImportEqualsDeclaration extends BabelNode - object TSImportType extends TSType - object TSIndexSignature extends BabelNode - object TSIndexedAccessType extends TSType - object TSInferType extends TSType - object TSInterfaceBody extends BabelNode - object TSInterfaceDeclaration extends BabelNode - object TSIntersectionType extends TSType - object TSIntrinsicKeyword extends TSType - object TSLiteralType extends TSType - object TSMappedType extends TSType - object TSMethodSignature extends BabelNode - object TSModuleBlock extends BabelNode - object TSModuleDeclaration extends BabelNode - object TSNamedTupleMember extends BabelNode - object TSNamespaceExportDeclaration extends BabelNode - object TSNeverKeyword extends TSType - object TSNonNullExpression extends Expression - object TSNullKeyword extends TSType - object TSNumberKeyword extends TSType - object TSObjectKeyword extends TSType - object TSOptionalType extends TSType - object TSParameterProperty extends Expression - object TSParenthesizedType extends TSType - object TSPropertySignature extends BabelNode - object TSQualifiedName extends BabelNode - object TSRestType extends TSType - object TSSatisfiesExpression extends Expression - object TSStringKeyword extends TSType - object TSSymbolKeyword extends TSType - object TSThisType extends TSType - object TSTupleType extends TSType - object TSTypeAliasDeclaration extends BabelNode - object TSTypeAnnotation 
extends FlowType - object TSTypeAssertion extends BabelNode - object TSTypeExpression extends TSType - object TSTypeLiteral extends TSType - object TSTypeOperator extends TSType - object TSTypeParameter extends TSType - object TSTypeParameterDeclaration extends BabelNode - object TSTypeParameterInstantiation extends BabelNode - object TSTypePredicate extends TSType - object TSTypeQuery extends TSType - object TSTypeReference extends TSType - object TSUndefinedKeyword extends TSType - object TSUnionType extends TSType - object TSUnknownKeyword extends TSType - object TSVoidKeyword extends TSType - object TaggedTemplateExpression extends BabelNode - object TemplateElement extends BabelNode - object TemplateLiteral extends BabelNode - object ThisExpression extends Expression - object ThisTypeAnnotation extends FlowType - object ThrowStatement extends BabelNode - object TopicReference extends BabelNode - object TryStatement extends BabelNode - object TupleExpression extends BabelNode - object TupleTypeAnnotation extends FlowType - object TypeAlias extends BabelNode - object TypeAnnotation extends FlowType - object TypeCastExpression extends BabelNode - object TSTypeCastExpression extends BabelNode - object TypeParameter extends BabelNode - object TypeParameterDeclaration extends BabelNode - object TypeParameterInstantiation extends BabelNode - object TypeofTypeAnnotation extends FlowType - object UnaryExpression extends BabelNode - object UnionTypeAnnotation extends FlowType - object UpdateExpression extends Expression - object V8IntrinsicIdentifier extends BabelNode - object VariableDeclaration extends BabelNode - object VariableDeclarator extends BabelNode - object Variance extends BabelNode - object VoidTypeAnnotation extends FlowType - object WhileStatement extends BabelNode - object WithStatement extends BabelNode - object YieldExpression extends BabelNode - -} + object AnyTypeAnnotation extends FlowType + object ArgumentPlaceholder extends BabelNode + object 
ArrayExpression extends BabelNode + object ArrayPattern extends BabelNode + object ArrayTypeAnnotation extends FlowType + object ArrowFunctionExpression extends FunctionLike + object AssignmentExpression extends BabelNode + object AssignmentPattern extends BabelNode + object AwaitExpression extends BabelNode + object BigIntLiteral extends BabelNode + object BinaryExpression extends BabelNode + object BindExpression extends BabelNode + object BlockStatement extends BabelNode + object BooleanLiteral extends BabelNode + object BooleanLiteralTypeAnnotation extends FlowType + object BooleanTypeAnnotation extends FlowType + object BreakStatement extends BabelNode + object CallExpression extends Expression + object CatchClause extends BabelNode + object ClassAccessorProperty extends BabelNode + object ClassBody extends BabelNode + object ClassDeclaration extends BabelNode + object ClassExpression extends BabelNode + object ClassImplements extends BabelNode + object ClassMethod extends BabelNode + object ClassPrivateMethod extends BabelNode + object ClassPrivateProperty extends BabelNode + object ClassProperty extends BabelNode + object ConditionalExpression extends BabelNode + object ContinueStatement extends BabelNode + object DebuggerStatement extends BabelNode + object DecimalLiteral extends BabelNode + object DeclareClass extends BabelNode + object DeclareExportAllDeclaration extends BabelNode + object DeclareExportDeclaration extends BabelNode + object DeclareFunction extends BabelNode + object DeclareInterface extends BabelNode + object DeclareModule extends BabelNode + object DeclareModuleExports extends BabelNode + object DeclareOpaqueType extends BabelNode + object DeclareTypeAlias extends BabelNode + object DeclareVariable extends BabelNode + object DeclaredPredicate extends BabelNode + object Decorator extends BabelNode + object Directive extends BabelNode + object DirectiveLiteral extends BabelNode + object DoExpression extends BabelNode + object 
DoWhileStatement extends BabelNode + object EmptyStatement extends BabelNode + object EmptyTypeAnnotation extends FlowType + object EnumBooleanBody extends BabelNode + object EnumBooleanMember extends BabelNode + object EnumDeclaration extends BabelNode + object EnumDefaultedMember extends BabelNode + object EnumNumberBody extends BabelNode + object EnumNumberMember extends BabelNode + object EnumStringBody extends BabelNode + object EnumStringMember extends BabelNode + object EnumSymbolBody extends BabelNode + object ExistsTypeAnnotation extends FlowType + object ExportAllDeclaration extends BabelNode + object ExportDefaultDeclaration extends BabelNode + object ExportDefaultSpecifier extends BabelNode + object ExportNamedDeclaration extends BabelNode + object ExportNamespaceSpecifier extends BabelNode + object ExportSpecifier extends BabelNode + object ExpressionStatement extends BabelNode + object File extends BabelNode + object ForInStatement extends BabelNode + object ForOfStatement extends BabelNode + object ForStatement extends BabelNode + object FunctionDeclaration extends FunctionLike + object FunctionExpression extends FunctionLike + object FunctionTypeAnnotation extends FlowType + object FunctionTypeParam extends BabelNode + object GenericTypeAnnotation extends FlowType + object Identifier extends BabelNode + object IfStatement extends BabelNode + object Import extends BabelNode + object ImportAttribute extends BabelNode + object ImportDeclaration extends BabelNode + object ImportDefaultSpecifier extends BabelNode + object ImportNamespaceSpecifier extends BabelNode + object ImportSpecifier extends BabelNode + object IndexedAccessType extends FlowType + object InferredPredicate extends BabelNode + object InterfaceDeclaration extends BabelNode + object InterfaceExtends extends BabelNode + object InterfaceTypeAnnotation extends FlowType + object InterpreterDirective extends BabelNode + object IntersectionTypeAnnotation extends FlowType + object JSXAttribute 
extends BabelNode + object JSXClosingElement extends BabelNode + object JSXClosingFragment extends BabelNode + object JSXElement extends BabelNode + object JSXEmptyExpression extends BabelNode + object JSXExpressionContainer extends BabelNode + object JSXFragment extends BabelNode + object JSXIdentifier extends BabelNode + object JSXMemberExpression extends Expression + object JSXNamespacedName extends BabelNode + object JSXOpeningElement extends BabelNode + object JSXOpeningFragment extends BabelNode + object JSXSpreadAttribute extends BabelNode + object JSXSpreadChild extends BabelNode + object JSXText extends BabelNode + object LabeledStatement extends BabelNode + object LogicalExpression extends BabelNode + object MemberExpression extends Expression + object MetaProperty extends BabelNode + object MixedTypeAnnotation extends FlowType + object ModuleExpression extends BabelNode + object NewExpression extends Expression + object Noop extends BabelNode + object NullLiteral extends BabelNode + object NullLiteralTypeAnnotation extends FlowType + object NullableTypeAnnotation extends FlowType + object NumberLiteral extends BabelNode + object NumberLiteralTypeAnnotation extends FlowType + object NumberTypeAnnotation extends FlowType + object NumericLiteral extends BabelNode + object ObjectExpression extends BabelNode + object ObjectMethod extends BabelNode + object ObjectPattern extends BabelNode + object ObjectProperty extends BabelNode + object ObjectTypeAnnotation extends FlowType + object ObjectTypeCallProperty extends BabelNode + object ObjectTypeIndexer extends BabelNode + object ObjectTypeInternalSlot extends BabelNode + object ObjectTypeProperty extends BabelNode + object ObjectTypeSpreadProperty extends BabelNode + object OpaqueType extends BabelNode + object OptionalCallExpression extends Expression + object OptionalIndexedAccessType extends FlowType + object OptionalMemberExpression extends Expression + object ParenthesizedExpression extends BabelNode + 
object PipelineBareFunction extends BabelNode + object PipelinePrimaryTopicReference extends BabelNode + object PipelineTopicExpression extends Expression + object Placeholder extends BabelNode + object PrivateName extends BabelNode + object Program extends BabelNode + object QualifiedTypeIdentifier extends BabelNode + object RecordExpression extends Expression + object RegExpLiteral extends BabelNode + object RegexLiteral extends BabelNode + object RestElement extends BabelNode + object RestProperty extends BabelNode + object ReturnStatement extends BabelNode + object SequenceExpression extends BabelNode + object SpreadElement extends BabelNode + object SpreadProperty extends BabelNode + object StaticBlock extends BabelNode + object StringLiteral extends BabelNode + object StringLiteralTypeAnnotation extends FlowType + object StringTypeAnnotation extends FlowType + object Super extends BabelNode + object SwitchCase extends BabelNode + object SwitchStatement extends BabelNode + object SymbolTypeAnnotation extends FlowType + object TSAnyKeyword extends TSType + object TSArrayType extends TSType + object TSAsExpression extends Expression + object TSBigIntKeyword extends TSType + object TSBooleanKeyword extends TSType + object TSCallSignatureDeclaration extends BabelNode + object TSConditionalType extends TSType + object TSConstructSignatureDeclaration extends BabelNode + object TSConstructorType extends TSType + object TSDeclareFunction extends BabelNode + object TSDeclareMethod extends BabelNode + object TSEnumDeclaration extends BabelNode + object TSEnumMember extends BabelNode + object TSExportAssignment extends BabelNode + object TSExpressionWithTypeArguments extends TSType + object TSExternalModuleReference extends BabelNode + object TSFunctionType extends TSType + object TSImportEqualsDeclaration extends BabelNode + object TSImportType extends TSType + object TSIndexSignature extends BabelNode + object TSIndexedAccessType extends TSType + object TSInferType 
extends TSType + object TSInterfaceBody extends BabelNode + object TSInterfaceDeclaration extends BabelNode + object TSIntersectionType extends TSType + object TSIntrinsicKeyword extends TSType + object TSLiteralType extends TSType + object TSMappedType extends TSType + object TSMethodSignature extends BabelNode + object TSModuleBlock extends BabelNode + object TSModuleDeclaration extends BabelNode + object TSNamedTupleMember extends BabelNode + object TSNamespaceExportDeclaration extends BabelNode + object TSNeverKeyword extends TSType + object TSNonNullExpression extends Expression + object TSNullKeyword extends TSType + object TSNumberKeyword extends TSType + object TSObjectKeyword extends TSType + object TSOptionalType extends TSType + object TSParameterProperty extends Expression + object TSParenthesizedType extends TSType + object TSPropertySignature extends BabelNode + object TSQualifiedName extends BabelNode + object TSRestType extends TSType + object TSSatisfiesExpression extends Expression + object TSStringKeyword extends TSType + object TSSymbolKeyword extends TSType + object TSThisType extends TSType + object TSTupleType extends TSType + object TSTypeAliasDeclaration extends BabelNode + object TSTypeAnnotation extends FlowType + object TSTypeAssertion extends BabelNode + object TSTypeExpression extends TSType + object TSTypeLiteral extends TSType + object TSTypeOperator extends TSType + object TSTypeParameter extends TSType + object TSTypeParameterDeclaration extends BabelNode + object TSTypeParameterInstantiation extends BabelNode + object TSTypePredicate extends TSType + object TSTypeQuery extends TSType + object TSTypeReference extends TSType + object TSUndefinedKeyword extends TSType + object TSUnionType extends TSType + object TSUnknownKeyword extends TSType + object TSVoidKeyword extends TSType + object TaggedTemplateExpression extends BabelNode + object TemplateElement extends BabelNode + object TemplateLiteral extends BabelNode + object 
ThisExpression extends Expression + object ThisTypeAnnotation extends FlowType + object ThrowStatement extends BabelNode + object TopicReference extends BabelNode + object TryStatement extends BabelNode + object TupleExpression extends BabelNode + object TupleTypeAnnotation extends FlowType + object TypeAlias extends BabelNode + object TypeAnnotation extends FlowType + object TypeCastExpression extends BabelNode + object TSTypeCastExpression extends BabelNode + object TypeParameter extends BabelNode + object TypeParameterDeclaration extends BabelNode + object TypeParameterInstantiation extends BabelNode + object TypeofTypeAnnotation extends FlowType + object UnaryExpression extends BabelNode + object UnionTypeAnnotation extends FlowType + object UpdateExpression extends Expression + object V8IntrinsicIdentifier extends BabelNode + object VariableDeclaration extends BabelNode + object VariableDeclarator extends BabelNode + object Variance extends BabelNode + object VoidTypeAnnotation extends FlowType + object WhileStatement extends BabelNode + object WithStatement extends BabelNode + object YieldExpression extends BabelNode +end BabelAst diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala index ec3117dd..5c25d36d 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala @@ -6,32 +6,29 @@ import ujson.Value.Value import java.nio.file.Path import java.nio.file.Paths -object BabelJsonParser { +object BabelJsonParser: - case class ParseResult( - filename: String, - fullPath: String, - json: Value, - fileContent: String, - typeMap: Map[Int, String] - ) + case class ParseResult( + filename: String, + fullPath: String, + json: Value, + fileContent: String, + typeMap: 
Map[Int, String] + ) - def readFile(rootPath: Path, file: Path): ParseResult = { - val typeMapPath = Paths.get(file.toString.replace(".json", ".typemap")) - val typeMap = if (typeMapPath.toFile.exists()) { - val typeMapJsonContent = IOUtils.readEntireFile(typeMapPath) - val typeMapJson = ujson.read(typeMapJsonContent) - typeMapJson.obj.map { case (k, v) => k.toInt -> v.str }.toMap - } else { - Map.empty[Int, String] - } + def readFile(rootPath: Path, file: Path): ParseResult = + val typeMapPath = Paths.get(file.toString.replace(".json", ".typemap")) + val typeMap = if typeMapPath.toFile.exists() then + val typeMapJsonContent = IOUtils.readEntireFile(typeMapPath) + val typeMapJson = ujson.read(typeMapJsonContent) + typeMapJson.obj.map { case (k, v) => k.toInt -> v.str }.toMap + else + Map.empty[Int, String] - val jsonContent = IOUtils.readEntireFile(file) - val json = ujson.read(jsonContent) - val filename = json("relativeName").str - val fullPath = Paths.get(rootPath.toString, filename) - val sourceFileContent = IOUtils.readEntireFile(fullPath) - ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap) - } - -} + val jsonContent = IOUtils.readEntireFile(file) + val json = ujson.read(jsonContent) + val filename = json("relativeName").str + val fullPath = Paths.get(rootPath.toString, filename) + val sourceFileContent = IOUtils.readEntireFile(fullPath) + ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap) +end BabelJsonParser diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala index c745b6f3..769127f7 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala @@ -16,47 +16,52 @@ import java.util.concurrent.ConcurrentHashMap 
import scala.jdk.CollectionConverters.EnumerationHasAsScala import scala.util.{Failure, Success, Try} -class AstCreationPass(cpg: Cpg, astGenRunnerResult: AstGenRunnerResult, config: Config, report: Report = new Report())( +class AstCreationPass( + cpg: Cpg, + astGenRunnerResult: AstGenRunnerResult, + config: Config, + report: Report = new Report() +)( implicit withSchemaValidation: ValidationMode -) extends ConcurrentWriterCpgPass[(String, String)](cpg) { - - private val logger: Logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - - private val usedTypes: ConcurrentHashMap[(String, String), Boolean] = new ConcurrentHashMap() - - override def generateParts(): Array[(String, String)] = astGenRunnerResult.parsedFiles.toArray - - def allUsedTypes(): List[(String, String)] = - usedTypes.keys().asScala.filterNot { case (typeName, _) => typeName == Defines.Any }.toList - - override def finish(): Unit = { - astGenRunnerResult.skippedFiles.foreach { skippedFile => - val (rootPath, fileName) = skippedFile - val filePath = Paths.get(rootPath, fileName) - val fileLOC = IOUtils.readLinesInFile(filePath).size - report.addReportInfo(fileName, fileLOC) - } - } - - override def runOnPart(diffGraph: DiffGraphBuilder, input: (String, String)): Unit = { - val (rootPath, jsonFilename) = input - val ((gotCpg, filename), duration) = TimeUtils.time { - val parseResult = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) - val fileLOC = IOUtils.readLinesInFile(Paths.get(parseResult.fullPath)).size - report.addReportInfo(parseResult.filename, fileLOC, parsed = true) - Try { - val localDiff = new AstCreator(config, parseResult, usedTypes).createAst() - diffGraph.absorb(localDiff) - } match { - case Failure(exception) => - logger.warn(s"Failed to generate a CPG for: '${parseResult.fullPath}'", exception) - (false, parseResult.filename) - case Success(_) => - logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") - (true, parseResult.filename) - } - } - 
report.updateReport(filename, cpg = gotCpg, duration) - } - -} +) extends ConcurrentWriterCpgPass[(String, String)](cpg): + + private val logger: Logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + + private val usedTypes: ConcurrentHashMap[(String, String), Boolean] = new ConcurrentHashMap() + + override def generateParts(): Array[(String, String)] = astGenRunnerResult.parsedFiles.toArray + + def allUsedTypes(): List[(String, String)] = + usedTypes.keys().asScala.filterNot { case (typeName, _) => typeName == Defines.Any }.toList + + override def finish(): Unit = + astGenRunnerResult.skippedFiles.foreach { skippedFile => + val (rootPath, fileName) = skippedFile + val filePath = Paths.get(rootPath, fileName) + val fileLOC = IOUtils.readLinesInFile(filePath).size + report.addReportInfo(fileName, fileLOC) + } + + override def runOnPart(diffGraph: DiffGraphBuilder, input: (String, String)): Unit = + val (rootPath, jsonFilename) = input + val ((gotCpg, filename), duration) = TimeUtils.time { + val parseResult = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) + val fileLOC = IOUtils.readLinesInFile(Paths.get(parseResult.fullPath)).size + report.addReportInfo(parseResult.filename, fileLOC, parsed = true) + Try { + val localDiff = new AstCreator(config, parseResult, usedTypes).createAst() + diffGraph.absorb(localDiff) + } match + case Failure(exception) => + logger.warn( + s"Failed to generate a CPG for: '${parseResult.fullPath}'", + exception + ) + (false, parseResult.filename) + case Success(_) => + logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") + (true, parseResult.filename) + } + report.updateReport(filename, cpg = gotCpg, duration) + end runOnPart +end AstCreationPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala index de9628ae..951ed10b 100644 --- 
a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala @@ -5,36 +5,35 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewNamespaceBlock, NewTyp import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.passes.CpgPass -class BuiltinTypesPass(cpg: Cpg) extends CpgPass(cpg) { - - override def run(diffGraph: DiffGraphBuilder): Unit = { - val namespaceBlock = NewNamespaceBlock() - .name(Defines.GlobalNamespace) - .fullName(Defines.GlobalNamespace) - .order(0) - .filename("builtintypes") - - diffGraph.addNode(namespaceBlock) - - Defines.JsTypes.zipWithIndex.map { case (typeName: String, index) => - val tpe = NewType() - .name(typeName) - .fullName(typeName) - .typeDeclFullName(typeName) - diffGraph.addNode(tpe) - - val typeDecl = NewTypeDecl() - .name(typeName) - .fullName(typeName) - .isExternal(true) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(Defines.GlobalNamespace) - .order(index + 1) - .filename("builtintypes") - - diffGraph.addNode(typeDecl) - diffGraph.addEdge(namespaceBlock, typeDecl, EdgeTypes.AST) - } - } - -} +class BuiltinTypesPass(cpg: Cpg) extends CpgPass(cpg): + + override def run(diffGraph: DiffGraphBuilder): Unit = + val namespaceBlock = NewNamespaceBlock() + .name(Defines.GlobalNamespace) + .fullName(Defines.GlobalNamespace) + .order(0) + .filename("builtintypes") + + diffGraph.addNode(namespaceBlock) + + Defines.JsTypes.zipWithIndex.map { case (typeName: String, index) => + val tpe = NewType() + .name(typeName) + .fullName(typeName) + .typeDeclFullName(typeName) + diffGraph.addNode(tpe) + + val typeDecl = NewTypeDecl() + .name(typeName) + .fullName(typeName) + .isExternal(true) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(Defines.GlobalNamespace) + .order(index + 1) + .filename("builtintypes") + + diffGraph.addNode(typeDecl) 
+ diffGraph.addEdge(namespaceBlock, typeDecl, EdgeTypes.AST) + } + end run +end BuiltinTypesPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala index 92744769..5a90a26f 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala @@ -10,40 +10,40 @@ import io.shiftleft.passes.ConcurrentWriterCpgPass import io.shiftleft.utils.IOUtils import org.slf4j.{Logger, LoggerFactory} -class ConfigPass(cpg: Cpg, config: Config, report: Report = new Report()) extends ConcurrentWriterCpgPass[File](cpg) { - - private val logger: Logger = LoggerFactory.getLogger(getClass) - - protected val allExtensions: Set[String] = Set(".json", ".js", ".vue", ".html", ".pug") - protected val selectedExtensions: Set[String] = Set(".json", ".config.js", ".conf.js", ".vue", ".html", ".pug") - - override def generateParts(): Array[File] = - configFiles(config, allExtensions).toArray - - protected def fileContent(file: File): Seq[String] = - IOUtils.readLinesInFile(file.path) - - protected def configFiles(config: Config, extensions: Set[String]): Seq[File] = - SourceFiles - .determine(config.inputPath, extensions) - .filterNot(_.contains(Defines.NodeModulesFolder)) - .filter(f => selectedExtensions.exists(f.endsWith)) - .map(File(_)) - - override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = { - val path = File(config.inputPath).path.toAbsolutePath.relativize(file.path).toString - logger.debug(s"Adding file '$path' as config.") - val (gotCpg, duration) = TimeUtils.time { - val localDiff = new DiffGraphBuilder - val content = fileContent(file) - val loc = content.size - val configNode = NewConfigFile().name(path).content(content.mkString("\n")) - report.addReportInfo(path, loc, parsed = true) - 
localDiff.addNode(configNode) - localDiff - } - diffGraph.absorb(gotCpg) - report.updateReport(path, cpg = true, duration) - } - -} +class ConfigPass(cpg: Cpg, config: Config, report: Report = new Report()) + extends ConcurrentWriterCpgPass[File](cpg): + + private val logger: Logger = LoggerFactory.getLogger(getClass) + + protected val allExtensions: Set[String] = Set(".json", ".js", ".vue", ".html", ".pug") + protected val selectedExtensions: Set[String] = + Set(".json", ".config.js", ".conf.js", ".vue", ".html", ".pug") + + override def generateParts(): Array[File] = + configFiles(config, allExtensions).toArray + + protected def fileContent(file: File): Seq[String] = + IOUtils.readLinesInFile(file.path) + + protected def configFiles(config: Config, extensions: Set[String]): Seq[File] = + SourceFiles + .determine(config.inputPath, extensions) + .filterNot(_.contains(Defines.NodeModulesFolder)) + .filter(f => selectedExtensions.exists(f.endsWith)) + .map(File(_)) + + override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = + val path = File(config.inputPath).path.toAbsolutePath.relativize(file.path).toString + logger.debug(s"Adding file '$path' as config.") + val (gotCpg, duration) = TimeUtils.time { + val localDiff = new DiffGraphBuilder + val content = fileContent(file) + val loc = content.size + val configNode = NewConfigFile().name(path).content(content.mkString("\n")) + report.addReportInfo(path, loc, parsed = true) + localDiff.addNode(configNode) + localDiff + } + diffGraph.absorb(gotCpg) + report.updateReport(path, cpg = true, duration) +end ConfigPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala index 55c8aa68..66d66332 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala +++ 
b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala @@ -4,68 +4,66 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Method, MethodRef} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -/** A pass that identifies assignments of closures to constants and updates `METHOD` nodes accordingly. +/** A pass that identifies assignments of closures to constants and updates `METHOD` nodes + * accordingly. */ -class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg) { +class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg): - // Keeps track of how many times an identifier has been on the LHS of an assignment, by name - private lazy val identifiersAssignedCount: Map[String, Int] = - cpg.assignment.target.collectAll[Identifier].name.groupCount + // Keeps track of how many times an identifier has been on the LHS of an assignment, by name + private lazy val identifiersAssignedCount: Map[String, Int] = + cpg.assignment.target.collectAll[Identifier].name.groupCount - override def run(diffGraph: DiffGraphBuilder): Unit = { - handleConstClosures(diffGraph) - handleClosuresDefinedAtExport(diffGraph) - handleClosuresAssignedToMutableVar(diffGraph) - } + override def run(diffGraph: DiffGraphBuilder): Unit = + handleConstClosures(diffGraph) + handleClosuresDefinedAtExport(diffGraph) + handleClosuresAssignedToMutableVar(diffGraph) - private def handleConstClosures(diffGraph: DiffGraphBuilder): Unit = - for { - assignment <- cpg.assignment - name <- assignment.filter(_.code.startsWith("const ")).target.isIdentifier.name - methodRef <- assignment.start.source.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - } { - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - } + private def 
handleConstClosures(diffGraph: DiffGraphBuilder): Unit = + for + assignment <- cpg.assignment + name <- assignment.filter(_.code.startsWith("const ")).target.isIdentifier.name + methodRef <- assignment.start.source.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def handleClosuresDefinedAtExport(diffGraph: DiffGraphBuilder): Unit = - for { - assignment <- cpg.assignment - name <- assignment.filter(_.code.startsWith("export")).target.isCall.argument.isFieldIdentifier.canonicalName.l - methodRef <- assignment.start.source.ast.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - } { - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - } + private def handleClosuresDefinedAtExport(diffGraph: DiffGraphBuilder): Unit = + for + assignment <- cpg.assignment + name <- assignment.filter( + _.code.startsWith("export") + ).target.isCall.argument.isFieldIdentifier.canonicalName.l + methodRef <- assignment.start.source.ast.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def handleClosuresAssignedToMutableVar(diffGraph: DiffGraphBuilder): Unit = - // Handle closures assigned to mutable variables - for { - assignment <- cpg.assignment - name <- assignment.start.code("^(var|let) .*").target.isIdentifier.name - methodRef <- assignment.start.source.ast.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - } { - // Conservatively update closures, i.e, if we only find 1 assignment where this variable is on the LHS - if (identifiersAssignedCount.getOrElse(name, -1) == 1) - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - } + private def 
handleClosuresAssignedToMutableVar(diffGraph: DiffGraphBuilder): Unit = + // Handle closures assigned to mutable variables + for + assignment <- cpg.assignment + name <- assignment.start.code("^(var|let) .*").target.isIdentifier.name + methodRef <- assignment.start.source.ast.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + // Conservatively update closures, i.e, if we only find 1 assignment where this variable is on the LHS + if identifiersAssignedCount.getOrElse(name, -1) == 1 then + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def updateClosures( - diffGraph: DiffGraphBuilder, - method: Method, - methodRef: MethodRef, - enclosingMethod: String, - name: String - ): Unit = { - val fullName = s"$enclosingMethod:$name" - diffGraph.setNodeProperty(methodRef, PropertyNames.METHOD_FULL_NAME, fullName) - diffGraph.setNodeProperty(method, PropertyNames.NAME, name) - diffGraph.setNodeProperty(method, PropertyNames.FULL_NAME, fullName) - } -} + private def updateClosures( + diffGraph: DiffGraphBuilder, + method: Method, + methodRef: MethodRef, + enclosingMethod: String, + name: String + ): Unit = + val fullName = s"$enclosingMethod:$name" + diffGraph.setNodeProperty(methodRef, PropertyNames.METHOD_FULL_NAME, fullName) + diffGraph.setNodeProperty(method, PropertyNames.NAME, name) + diffGraph.setNodeProperty(method, PropertyNames.FULL_NAME, fullName) +end ConstClosurePass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala index b5c10692..bd24e791 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala @@ -2,24 +2,39 @@ package io.appthreat.jssrc2cpg.passes import 
io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal -object Defines { - val Any: String = "ANY" - val Number: String = "__ecma.Number" - val String: String = "__ecma.String" - val Boolean: String = "__ecma.Boolean" - val Null: String = "__ecma.Null" - val Math: String = "__ecma.Math" - val Symbol: String = "__ecma.Symbol" - val Console: String = "__whatwg.console" - val Object: String = "object" - val BigInt: String = "bigint" - val Unknown: String = "unknown" - val Void: String = "void" - val Never: String = "never" - val Undefined: String = "undefined" - val NodeModulesFolder: String = "node_modules" - val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName +object Defines: + val Any: String = "ANY" + val Number: String = "__ecma.Number" + val String: String = "__ecma.String" + val Boolean: String = "__ecma.Boolean" + val Null: String = "__ecma.Null" + val Math: String = "__ecma.Math" + val Symbol: String = "__ecma.Symbol" + val Console: String = "__whatwg.console" + val Object: String = "object" + val BigInt: String = "bigint" + val Unknown: String = "unknown" + val Void: String = "void" + val Never: String = "never" + val Undefined: String = "undefined" + val NodeModulesFolder: String = "node_modules" + val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName - val JsTypes: List[String] = - List(Any, Number, String, Boolean, Null, Math, Symbol, Console, Object, BigInt, Unknown, Never, Void, Undefined) -} + val JsTypes: List[String] = + List( + Any, + Number, + String, + Boolean, + Null, + Math, + Symbol, + Console, + Object, + BigInt, + Unknown, + Never, + Void, + Undefined + ) +end Defines diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala index a52fd7c9..cd0eabb8 100644 --- 
a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala @@ -11,25 +11,26 @@ import java.nio.file.Paths /** Creation of DEPENDENCY nodes from "package.json" files. */ -class DependenciesPass(cpg: Cpg, config: Config) extends CpgPass(cpg) { +class DependenciesPass(cpg: Cpg, config: Config) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = { - val packagesJsons = SourceFiles - .determine(config.inputPath, Set(".json")) - .filterNot(_.contains(Defines.NodeModulesFolder)) - .filter(f => - f.endsWith(PackageJsonParser.PackageJsonFilename) || f.endsWith(PackageJsonParser.PackageJsonLockFilename) - ) + override def run(diffGraph: DiffGraphBuilder): Unit = + val packagesJsons = SourceFiles + .determine(config.inputPath, Set(".json")) + .filterNot(_.contains(Defines.NodeModulesFolder)) + .filter(f => + f.endsWith(PackageJsonParser.PackageJsonFilename) || f.endsWith( + PackageJsonParser.PackageJsonLockFilename + ) + ) - val dependencies: Map[String, String] = - packagesJsons.flatMap(p => PackageJsonParser.dependencies(Paths.get(p))).toMap + val dependencies: Map[String, String] = + packagesJsons.flatMap(p => PackageJsonParser.dependencies(Paths.get(p))).toMap - dependencies.foreach { case (name, version) => - val dep = NewDependency() - .name(name) - .version(version) - diffGraph.addNode(dep) - } - } - -} + dependencies.foreach { case (name, version) => + val dep = NewDependency() + .name(name) + .version(version) + diffGraph.addNode(dep) + } + end run +end DependenciesPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala index 644e65f5..7956fa95 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala +++ 
b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala @@ -1,5 +1,4 @@ package io.appthreat.jssrc2cpg.passes -object EcmaBuiltins { - val arrayFactory = "__ecma.Array.factory" -} +object EcmaBuiltins: + val arrayFactory = "__ecma.Array.factory" diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala index 58a3f21f..ace5c6e2 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala @@ -1,1094 +1,1093 @@ package io.appthreat.jssrc2cpg.passes -object GlobalBuiltins { +object GlobalBuiltins: - val builtins: Set[String] = Set( - "AggregateError", - "Array", - "ArrayBuffer", - "AsyncFunction", - "AsyncGenerator", - "AsyncGeneratorFunction", - "AsyncIterator", - "Atomics", - "BigInt", - "BigInt64Array", - "BigUint64Array", - "Boolean", - "Buffer.from", - "DataView", - "Date", - "Error", - "EvalError", - "FinalizationRegistry", - "Float32Array", - "Float64Array", - "Function", - "Generator", - "GeneratorFunction", - "HTMLImageElement", - "Iterator", - "Infinity", - "Int16Array", - "Int32Array", - "Int8Array", - "InternalError", - "Intl", - "Intl.Collator", - "Intl.DateTimeFormat", - "Intl.DisplayNames", - "Intl.DurationFormat", - "Intl.ListFormat", - "Intl.Locale", - "Intl.NumberFormat", - "Intl.PluralRules", - "Intl.RelativeTimeFormat", - "Intl.Segmenter", - "JSON", - "JSON.parse", - "JSON.stringify", - "Map", - "Math", - "NaN", - "Number", - "Number.isFinite", - "Number.isInteger", - "Number.isNaN", - "Number.isSafeInteger", - "Number.parseFloat", - "Number.parseInt", - "Number.prototype.toExponential", - "Number.prototype.toFixed", - "Number.prototype.toLocaleString", - "Number.prototype.toPrecision", - "Number.prototype.toSource", - 
"Number.prototype.toString", - "Number.prototype.valueOf", - "Object", - "Object.assign", - "Object.create", - "Object.defineProperties", - "Object.defineProperty", - "Object.entries", - "Object.freeze", - "Object.fromEntries", - "Object.getOwnPropertyDescriptor", - "Object.getOwnPropertyDescriptors", - "Object.getOwnPropertyNames", - "Object.getOwnPropertySymbols", - "Object.getPrototypeOf", - "Object.is", - "Object.isExtensible", - "Object.isFrozen", - "Object.isSealed", - "Object.keys", - "Object.preventExtensions", - "Object.prototype.__defineGetter__", - "Object.prototype.__defineSetter__", - "Object.prototype.__lookupGetter__", - "Object.prototype.__lookupSetter__", - "Object.prototype.hasOwnProperty", - "Object.prototype.isPrototypeOf", - "Object.prototype.propertyIsEnumerable", - "Object.prototype.toLocaleString", - "Object.prototype.toSource", - "Object.prototype.toString", - "Object.prototype.valueOf", - "Object.seal", - "Object.setPrototypeOf", - "Object.values", - "Promise", - "Promise.all", - "Promise.allSettled", - "Promise.any", - "Promise.race", - "Promise.reject", - "Promise.resolve", - "Proxy", - "RangeError", - "ReferenceError", - "Reflect", - "RegExp", - "Set", - "SharedArrayBuffer", - "String", - "Symbol", - "SyntaxError", - "TypeError", - "TypedArray", - "URIError", - "Uint16Array", - "Uint32Array", - "Uint8Array", - "Uint8ClampedArray", - "WeakMap", - "WeakRef", - "WeakSet", - "decodeURI", - "decodeURIComponent", - "encodeURI", - "encodeURIComponent", - "escape", - "eval", - "eval", - "fetch", - "globalThis", - "isFinite", - "isNaN", - "localStorage.setItem", - "parseFloat", - "parseInt", - "undefined", - "unescape", - "uneval", - "AbortController", - "AbortSignal", - "AbsoluteOrientationSensor", - "AbstractRange", - "Accelerometer", - "AesCbcParams", - "AesCtrParams", - "AesGcmParams", - "AesKeyGenParams", - "AmbientLightSensor", - "AnalyserNode", - "ANGLE_instanced_arrays", - "Animation", - "AnimationEffect", - "AnimationEvent", - 
"AnimationPlaybackEvent", - "AnimationTimeline", - "Attr", - "AudioBuffer", - "AudioBufferSourceNode", - "AudioContext", - "AudioData", - "AudioDecoder", - "AudioDestinationNode", - "AudioEncoder", - "AudioListener", - "AudioNode", - "AudioParam", - "AudioParamDescriptor", - "AudioParamMap", - "AudioProcessingEvent", - "AudioScheduledSourceNode", - "AudioSinkInfo", - "AudioTrack", - "AudioTrackList", - "AudioWorklet", - "AudioWorkletGlobalScope", - "AudioWorkletNode", - "AudioWorkletProcessor", - "AuthenticatorAssertionResponse", - "AuthenticatorAttestationResponse", - "AuthenticatorResponse", - "BackgroundFetchEvent", - "BackgroundFetchManager", - "BackgroundFetchRecord", - "BackgroundFetchRegistration", - "BackgroundFetchUpdateUIEvent", - "BarcodeDetector", - "BarProp", - "BaseAudioContext", - "BatteryManager", - "BeforeInstallPromptEvent", - "BeforeUnloadEvent", - "BiquadFilterNode", - "Blob", - "BlobEvent", - "Bluetooth", - "BluetoothCharacteristicProperties", - "BluetoothDevice", - "BluetoothRemoteGATTCharacteristic", - "BluetoothRemoteGATTDescriptor", - "BluetoothRemoteGATTServer", - "BluetoothRemoteGATTService", - "BluetoothUUID", - "BroadcastChannel", - "ByteLengthQueuingStrategy", - "Cache", - "CacheStorage", - "CanMakePaymentEvent", - "CanvasCaptureMediaStreamTrack", - "CanvasGradient", - "CanvasPattern", - "CanvasRenderingContext2D", - "CaptureController", - "CaretPosition", - "CDATASection", - "ChannelMergerNode", - "ChannelSplitterNode", - "CharacterData", - "Client", - "Clients", - "Clipboard", - "ClipboardEvent", - "ClipboardItem", - "CloseEvent", - "Comment", - "CompositionEvent", - "CompressionStream", - "console", - "ConstantSourceNode", - "ContactAddress", - "ContactsManager", - "ContentIndex", - "ContentIndexEvent", - "ContentVisibilityAutoStateChangeEvent", - "ConvolverNode", - "CookieChangeEvent", - "CookieStore", - "CookieStoreManager", - "CountQueuingStrategy", - "Credential", - "CredentialsContainer", - "Crypto", - "CryptoKey", - 
"CryptoKeyPair", - "CSPViolationReportBody", - "CSS", - "CSSAnimation", - "CSSConditionRule", - "CSSContainerRule", - "CSSCounterStyleRule", - "CSSFontFaceRule", - "CSSFontFeatureValuesRule", - "CSSFontPaletteValuesRule", - "CSSGroupingRule", - "CSSImageValue", - "CSSImportRule", - "CSSKeyframeRule", - "CSSKeyframesRule", - "CSSKeywordValue", - "CSSLayerBlockRule", - "CSSLayerStatementRule", - "CSSMathInvert", - "CSSMathMax", - "CSSMathMin", - "CSSMathNegate", - "CSSMathProduct", - "CSSMathSum", - "CSSMathValue", - "CSSMatrixComponent", - "CSSMediaRule", - "CSSNamespaceRule", - "CSSNumericArray", - "CSSNumericValue", - "CSSPageRule", - "CSSPerspective", - "CSSPositionValue", - "CSSPrimitiveValue", - "CSSPropertyRule", - "CSSPseudoElement", - "CSSRotate", - "CSSRule", - "CSSRuleList", - "CSSScale", - "CSSSkew", - "CSSSkewX", - "CSSSkewY", - "CSSStyleDeclaration", - "CSSStyleRule", - "CSSStyleSheet", - "CSSStyleValue", - "CSSSupportsRule", - "CSSTransformComponent", - "CSSTransformValue", - "CSSTransition", - "CSSTranslate", - "CSSUnitValue", - "CSSUnparsedValue", - "CSSValue", - "CSSValueList", - "CSSVariableReferenceValue", - "CustomElementRegistry", - "CustomEvent", - "CustomStateSet", - "DataTransfer", - "DataTransferItem", - "DataTransferItemList", - "DecompressionStream", - "DedicatedWorkerGlobalScope", - "DelayNode", - "DeprecationReportBody", - "DeviceMotionEvent", - "DeviceMotionEventAcceleration", - "DeviceMotionEventRotationRate", - "DeviceOrientationEvent", - "DirectoryEntrySync", - "DirectoryReaderSync", - "Document", - "DocumentFragment", - "DocumentTimeline", - "DocumentType", - "DOMError", - "DOMException", - "DOMHighResTimeStamp", - "DOMImplementation", - "DOMMatrix (WebKitCSSMatrix)", - "DOMMatrixReadOnly", - "DOMParser", - "DOMPoint", - "DOMPointReadOnly", - "DOMQuad", - "DOMRect", - "DOMRectReadOnly", - "DOMStringList", - "DOMStringMap", - "DOMTokenList", - "DragEvent", - "DynamicsCompressorNode", - "EcdhKeyDeriveParams", - "EcdsaParams", - 
"EcKeyGenParams", - "EcKeyImportParams", - "Element", - "ElementInternals", - "EncodedAudioChunk", - "EncodedVideoChunk", - "ErrorEvent", - "Event", - "EventCounts", - "EventSource", - "EventTarget", - "ExtendableCookieChangeEvent", - "ExtendableEvent", - "ExtendableMessageEvent", - "EyeDropper", - "FeaturePolicy", - "FederatedCredential", - "FetchEvent", - "File", - "FileEntrySync", - "FileList", - "FileReader", - "FileReaderSync", - "FileSystem", - "FileSystemDirectoryEntry", - "FileSystemDirectoryHandle", - "FileSystemDirectoryReader", - "FileSystemEntry", - "FileSystemFileEntry", - "FileSystemFileHandle", - "FileSystemHandle", - "FileSystemSync", - "FileSystemSyncAccessHandle", - "FileSystemWritableFileStream", - "FocusEvent", - "FontData", - "FontFace", - "FontFaceSet", - "FontFaceSetLoadEvent", - "FormData", - "FormDataEvent", - "FragmentDirective", - "GainNode", - "Gamepad", - "GamepadButton", - "GamepadEvent", - "GamepadHapticActuator", - "GamepadPose", - "Geolocation", - "GeolocationCoordinates", - "GeolocationPosition", - "GeolocationPositionError", - "GestureEvent", - "GPU", - "GPUAdapter", - "GPUAdapterInfo", - "GPUBindGroup", - "GPUBindGroupLayout", - "GPUBuffer", - "GPUCanvasContext", - "GPUCommandBuffer", - "GPUCommandEncoder", - "GPUCompilationInfo", - "GPUCompilationMessage", - "GPUComputePassEncoder", - "GPUComputePipeline", - "GPUDevice", - "GPUDeviceLostInfo", - "GPUError", - "GPUExternalTexture", - "GPUInternalError", - "GPUOutOfMemoryError", - "GPUPipelineError", - "GPUPipelineLayout", - "GPUQuerySet", - "GPUQueue", - "GPURenderBundle", - "GPURenderBundleEncoder", - "GPURenderPassEncoder", - "GPURenderPipeline", - "GPUSampler", - "GPUShaderModule", - "GPUSupportedFeatures", - "GPUSupportedLimits", - "GPUTexture", - "GPUTextureView", - "GPUUncapturedErrorEvent", - "GPUValidationError", - "GravitySensor", - "Gyroscope", - "HashChangeEvent", - "Headers", - "HID", - "HIDConnectionEvent", - "HIDDevice", - "HIDInputReportEvent", - "Highlight", - 
"HighlightRegistry", - "History", - "HkdfParams", - "HmacImportParams", - "HmacKeyGenParams", - "HMDVRDevice", - "HTMLAnchorElement", - "HTMLAreaElement", - "HTMLAudioElement", - "HTMLBaseElement", - "HTMLBodyElement", - "HTMLBRElement", - "HTMLButtonElement", - "HTMLCanvasElement", - "HTMLCollection", - "HTMLDataElement", - "HTMLDataListElement", - "HTMLDetailsElement", - "HTMLDialogElement", - "HTMLDivElement", - "HTMLDListElement", - "HTMLDocument", - "HTMLElement", - "HTMLEmbedElement", - "HTMLFieldSetElement", - "HTMLFontElement", - "HTMLFormControlsCollection", - "HTMLFormElement", - "HTMLFrameSetElement", - "HTMLHeadElement", - "HTMLHeadingElement", - "HTMLHRElement", - "HTMLHtmlElement", - "HTMLIFrameElement", - "HTMLImageElement", - "HTMLInputElement", - "HTMLLabelElement", - "HTMLLegendElement", - "HTMLLIElement", - "HTMLLinkElement", - "HTMLMapElement", - "HTMLMarqueeElement", - "HTMLMediaElement", - "HTMLMenuElement", - "HTMLMenuItemElement", - "HTMLMetaElement", - "HTMLMeterElement", - "HTMLModElement", - "HTMLObjectElement", - "HTMLOListElement", - "HTMLOptGroupElement", - "HTMLOptionElement", - "HTMLOptionsCollection", - "HTMLOutputElement", - "HTMLParagraphElement", - "HTMLParamElement", - "HTMLPictureElement", - "HTMLPreElement", - "HTMLProgressElement", - "HTMLQuoteElement", - "HTMLScriptElement", - "HTMLSelectElement", - "HTMLSlotElement", - "HTMLSourceElement", - "HTMLSpanElement", - "HTMLStyleElement", - "HTMLTableCaptionElement", - "HTMLTableCellElement", - "HTMLTableColElement", - "HTMLTableElement", - "HTMLTableRowElement", - "HTMLTableSectionElement", - "HTMLTemplateElement", - "HTMLTextAreaElement", - "HTMLTimeElement", - "HTMLTitleElement", - "HTMLTrackElement", - "HTMLUListElement", - "HTMLUnknownElement", - "HTMLVideoElement", - "IDBCursor", - "IDBCursorWithValue", - "IDBDatabase", - "IDBFactory", - "IDBIndex", - "IDBKeyRange", - "IDBLocaleAwareKeyRange", - "IDBObjectStore", - "IDBOpenDBRequest", - "IDBRequest", - "IDBTransaction", - 
"IDBVersionChangeEvent", - "IdentityCredential", - "IdleDeadline", - "IdleDetector", - "IIRFilterNode", - "ImageBitmap", - "ImageBitmapRenderingContext", - "ImageCapture", - "ImageData", - "ImageDecoder", - "ImageTrack", - "ImageTrackList", - "Ink", - "InkPresenter", - "InputDeviceCapabilities", - "InputDeviceInfo", - "InputEvent", - "InstallEvent", - "IntersectionObserver", - "IntersectionObserverEntry", - "InterventionReportBody", - "Keyboard", - "KeyboardEvent", - "KeyboardLayoutMap", - "KeyframeEffect", - "LargestContentfulPaint", - "LaunchParams", - "LaunchQueue", - "LayoutShift", - "LayoutShiftAttribution", - "LinearAccelerationSensor", - "Location", - "Lock", - "LockManager", - "Magnetometer", - "MathMLElement", - "MediaCapabilities", - "MediaDeviceInfo", - "MediaDevices", - "MediaElementAudioSourceNode", - "MediaEncryptedEvent", - "MediaError", - "MediaImage", - "MediaKeyMessageEvent", - "MediaKeys", - "MediaKeySession", - "MediaKeyStatusMap", - "MediaKeySystemAccess", - "MediaList", - "MediaMetadata", - "MediaQueryList", - "MediaQueryListEvent", - "MediaRecorder", - "MediaRecorderErrorEvent", - "MediaSession", - "MediaSource", - "MediaSourceHandle", - "MediaStream", - "MediaStreamAudioDestinationNode", - "MediaStreamAudioSourceNode", - "MediaStreamEvent", - "MediaStreamTrack", - "MediaStreamTrackAudioSourceNode", - "MediaStreamTrackEvent", - "MediaStreamTrackGenerator", - "MediaStreamTrackProcessor", - "MediaTrackConstraints", - "MediaTrackSettings", - "MediaTrackSupportedConstraints", - "MerchantValidationEvent", - "MessageChannel", - "MessageEvent", - "MessagePort", - "Metadata", - "MIDIAccess", - "MIDIConnectionEvent", - "MIDIInput", - "MIDIInputMap", - "MIDIMessageEvent", - "MIDIOutput", - "MIDIOutputMap", - "MIDIPort", - "MimeType", - "MimeTypeArray", - "MouseEvent", - "MouseScrollEvent", - "MutationEvent", - "MutationObserver", - "MutationRecord", - "NamedNodeMap", - "NavigateEvent", - "Navigation", - "NavigationCurrentEntryChangeEvent", - 
"NavigationDestination", - "NavigationHistoryEntry", - "NavigationPreloadManager", - "NavigationTransition", - "Navigator", - "NavigatorUAData", - "NDEFMessage", - "NDEFReader", - "NDEFReadingEvent", - "NDEFRecord", - "NetworkInformation", - "Node", - "NodeIterator", - "NodeList", - "Notification", - "NotificationEvent", - "NotifyAudioAvailableEvent", - "OES_draw_buffers_indexed", - "OfflineAudioCompletionEvent", - "OfflineAudioContext", - "OffscreenCanvas", - "OffscreenCanvasRenderingContext2D", - "OrientationSensor", - "OscillatorNode", - "OTPCredential", - "OverconstrainedError", - "PageTransitionEvent", - "PaintWorkletGlobalScope", - "PannerNode", - "PasswordCredential", - "Path2D", - "PaymentAddress", - "PaymentManager", - "PaymentMethodChangeEvent", - "PaymentRequest", - "PaymentRequestEvent", - "PaymentRequestUpdateEvent", - "PaymentResponse", - "Pbkdf2Params", - "Performance", - "PerformanceElementTiming", - "PerformanceEntry", - "PerformanceEventTiming", - "PerformanceLongTaskTiming", - "PerformanceMark", - "PerformanceMeasure", - "PerformanceNavigation", - "PerformanceNavigationTiming", - "PerformanceObserver", - "PerformanceObserverEntryList", - "PerformancePaintTiming", - "PerformanceResourceTiming", - "PerformanceServerTiming", - "PerformanceTiming", - "PeriodicSyncEvent", - "PeriodicSyncManager", - "PeriodicWave", - "Permissions", - "PermissionStatus", - "PictureInPictureEvent", - "PictureInPictureWindow", - "Plugin", - "PluginArray", - "Point", - "PointerEvent", - "PopStateEvent", - "PositionSensorVRDevice", - "Presentation", - "PresentationAvailability", - "PresentationConnection", - "PresentationConnectionAvailableEvent", - "PresentationConnectionCloseEvent", - "PresentationConnectionList", - "PresentationReceiver", - "PresentationRequest", - "ProcessingInstruction", - "ProgressEvent", - "PromiseRejectionEvent", - "PublicKeyCredential", - "PushEvent", - "PushManager", - "PushMessageData", - "PushSubscription", - "PushSubscriptionOptions", - 
"RadioNodeList", - "Range", - "ReadableByteStreamController", - "ReadableStream", - "ReadableStreamBYOBReader", - "ReadableStreamBYOBRequest", - "ReadableStreamDefaultController", - "ReadableStreamDefaultReader", - "RelativeOrientationSensor", - "RemotePlayback", - "Report", - "ReportBody", - "ReportingObserver", - "Request", - "ResizeObserver", - "ResizeObserverEntry", - "ResizeObserverSize", - "Response", - "RsaHashedImportParams", - "RsaHashedKeyGenParams", - "RsaOaepParams", - "RsaPssParams", - "RTCAudioSourceStats", - "RTCCertificate", - "RTCDataChannel", - "RTCDataChannelEvent", - "RTCDtlsTransport", - "RTCDTMFSender", - "RTCDTMFToneChangeEvent", - "RTCError", - "RTCErrorEvent", - "RTCIceCandidate", - "RTCIceCandidatePair", - "RTCIceCandidatePairStats", - "RTCIceCandidateStats", - "RTCIceParameters", - "RTCIceServer", - "RTCIceTransport", - "RTCIdentityAssertion", - "RTCInboundRtpStreamStats", - "RTCOutboundRtpStreamStats", - "RTCPeerConnection", - "RTCPeerConnectionIceErrorEvent", - "RTCPeerConnectionIceEvent", - "RTCPeerConnectionStats", - "RTCRemoteOutboundRtpStreamStats", - "RTCRtpCodecParameters", - "RTCRtpContributingSource", - "RTCRtpEncodingParameters", - "RTCRtpReceiver", - "RTCRtpSender", - "RTCRtpStreamStats", - "RTCRtpTransceiver", - "RTCSctpTransport", - "RTCSessionDescription", - "RTCStatsReport", - "RTCTrackEvent", - "Sanitizer", - "Scheduler", - "Screen", - "ScreenOrientation", - "ScriptProcessorNode", - "SecurityPolicyViolationEvent", - "Selection", - "Sensor", - "SensorErrorEvent", - "Serial", - "SerialPort", - "ServiceWorker", - "ServiceWorkerContainer", - "ServiceWorkerGlobalScope", - "ServiceWorkerRegistration", - "ShadowRoot", - "SharedWorker", - "SharedWorkerGlobalScope", - "SourceBuffer", - "SourceBufferList", - "SpeechGrammar", - "SpeechGrammarList", - "SpeechRecognition", - "SpeechRecognitionAlternative", - "SpeechRecognitionErrorEvent", - "SpeechRecognitionEvent", - "SpeechRecognitionResult", - "SpeechRecognitionResultList", - 
"SpeechSynthesis", - "SpeechSynthesisErrorEvent", - "SpeechSynthesisEvent", - "SpeechSynthesisUtterance", - "SpeechSynthesisVoice", - "StaticRange", - "StereoPannerNode", - "Storage", - "StorageEvent", - "StorageManager", - "StylePropertyMap", - "StylePropertyMapReadOnly", - "StyleSheet", - "StyleSheetList", - "SubmitEvent", - "SubtleCrypto", - "SVGAElement", - "SVGAngle", - "SVGAnimateColorElement", - "SVGAnimatedAngle", - "SVGAnimatedBoolean", - "SVGAnimatedEnumeration", - "SVGAnimatedInteger", - "SVGAnimatedLength", - "SVGAnimatedLengthList", - "SVGAnimatedNumber", - "SVGAnimatedNumberList", - "SVGAnimatedPreserveAspectRatio", - "SVGAnimatedRect", - "SVGAnimatedString", - "SVGAnimatedTransformList", - "SVGAnimateElement", - "SVGAnimateMotionElement", - "SVGAnimateTransformElement", - "SVGAnimationElement", - "SVGCircleElement", - "SVGClipPathElement", - "SVGComponentTransferFunctionElement", - "SVGCursorElement", - "SVGDefsElement", - "SVGDescElement", - "SVGElement", - "SVGEllipseElement", - "SVGEvent", - "SVGFEBlendElement", - "SVGFEColorMatrixElement", - "SVGFEComponentTransferElement", - "SVGFECompositeElement", - "SVGFEConvolveMatrixElement", - "SVGFEDiffuseLightingElement", - "SVGFEDisplacementMapElement", - "SVGFEDistantLightElement", - "SVGFEDropShadowElement", - "SVGFEFloodElement", - "SVGFEFuncAElement", - "SVGFEFuncBElement", - "SVGFEFuncGElement", - "SVGFEFuncRElement", - "SVGFEGaussianBlurElement", - "SVGFEImageElement", - "SVGFEMergeElement", - "SVGFEMergeNodeElement", - "SVGFEMorphologyElement", - "SVGFEOffsetElement", - "SVGFEPointLightElement", - "SVGFESpecularLightingElement", - "SVGFESpotLightElement", - "SVGFETileElement", - "SVGFETurbulenceElement", - "SVGFilterElement", - "SVGFontElement", - "SVGFontFaceElement", - "SVGFontFaceFormatElement", - "SVGFontFaceNameElement", - "SVGFontFaceSrcElement", - "SVGFontFaceUriElement", - "SVGForeignObjectElement", - "SVGGElement", - "SVGGeometryElement", - "SVGGlyphElement", - "SVGGlyphRefElement", - 
"SVGGradientElement", - "SVGGraphicsElement", - "SVGHKernElement", - "SVGImageElement", - "SVGLength", - "SVGLengthList", - "SVGLinearGradientElement", - "SVGLineElement", - "SVGMarkerElement", - "SVGMaskElement", - "SVGMetadataElement", - "SVGMissingGlyphElement", - "SVGMPathElement", - "SVGNumber", - "SVGNumberList", - "SVGPathElement", - "SVGPatternElement", - "SVGPoint", - "SVGPointList", - "SVGPolygonElement", - "SVGPolylineElement", - "SVGPreserveAspectRatio", - "SVGRadialGradientElement", - "SVGRect", - "SVGRectElement", - "SVGRenderingIntent", - "SVGScriptElement", - "SVGSetElement", - "SVGStopElement", - "SVGStringList", - "SVGStyleElement", - "SVGSVGElement", - "SVGSwitchElement", - "SVGSymbolElement", - "SVGTextContentElement", - "SVGTextElement", - "SVGTextPathElement", - "SVGTextPositioningElement", - "SVGTitleElement", - "SVGTransform", - "SVGTransformList", - "SVGTRefElement", - "SVGTSpanElement", - "SVGUnitTypes", - "SVGUseElement", - "SVGViewElement", - "SVGVKernElement", - "SyncEvent", - "SyncManager", - "TaskAttributionTiming", - "TaskController", - "TaskPriorityChangeEvent", - "TaskSignal", - "Text", - "TextDecoder", - "TextDecoderStream", - "TextEncoder", - "TextEncoderStream", - "TextMetrics", - "TextTrack", - "TextTrackCue", - "TextTrackCueList", - "TextTrackList", - "TimeEvent", - "TimeRanges", - "ToggleEvent", - "Touch", - "TouchEvent", - "TouchList", - "TrackEvent", - "TransformStream", - "TransformStreamDefaultController", - "TransitionEvent", - "TreeWalker", - "TrustedHTML", - "TrustedScript", - "TrustedScriptURL", - "TrustedTypePolicy", - "TrustedTypePolicyFactory", - "UIEvent", - "URL", - "URLPattern", - "URLSearchParams", - "USB", - "USBAlternateInterface", - "USBConfiguration", - "USBConnectionEvent", - "USBDevice", - "USBEndpoint", - "USBInterface", - "USBInTransferResult", - "USBIsochronousInTransferPacket", - "USBIsochronousInTransferResult", - "USBIsochronousOutTransferPacket", - "USBIsochronousOutTransferResult", - 
"USBOutTransferResult", - "UserActivation", - "ValidityState", - "VideoColorSpace", - "VideoDecoder", - "VideoEncoder", - "VideoFrame", - "VideoPlaybackQuality", - "VideoTrack", - "VideoTrackList", - "ViewTransition", - "VirtualKeyboard", - "VisualViewport", - "VRDisplay", - "VRDisplayCapabilities", - "VRDisplayEvent", - "VREyeParameters", - "VRFieldOfView", - "VRFrameData", - "VRLayerInit", - "VRPose", - "VRStageParameters", - "VTTCue", - "VTTRegion", - "WakeLock", - "WakeLockSentinel", - "WaveShaperNode", - "WebGL2RenderingContext", - "WebGLActiveInfo", - "WebGLBuffer", - "WebGLContextEvent", - "WebGLFramebuffer", - "WebGLObject", - "WebGLProgram", - "WebGLQuery", - "WebGLRenderbuffer", - "WebGLRenderingContext", - "WebGLSampler", - "WebGLShader", - "WebGLShaderPrecisionFormat", - "WebGLSync", - "WebGLTexture", - "WebGLTransformFeedback", - "WebGLUniformLocation", - "WebGLVertexArrayObject", - "WebSocket", - "WebTransport", - "WebTransportBidirectionalStream", - "WebTransportDatagramDuplexStream", - "WebTransportError", - "WebTransportReceiveStream", - "WheelEvent", - "Window", - "WindowClient", - "WindowControlsOverlay", - "WindowControlsOverlayGeometryChangeEvent", - "Worker", - "WorkerGlobalScope", - "WorkerLocation", - "WorkerNavigator", - "Worklet", - "WorkletGlobalScope", - "WritableStream", - "WritableStreamDefaultController", - "WritableStreamDefaultWriter", - "XMLDocument", - "XMLHttpRequest", - "XMLHttpRequestEventTarget", - "XMLHttpRequestUpload", - "XMLSerializer", - "XPathEvaluator", - "XPathException", - "XPathExpression", - "XPathNSResolver", - "XPathResult", - "XRAnchor", - "XRAnchorSet", - "XRBoundedReferenceSpace", - "XRCompositionLayer", - "XRCPUDepthInformation", - "XRCubeLayer", - "XRCylinderLayer", - "XRDepthInformation", - "XREquirectLayer", - "XRFrame", - "XRHand", - "XRHitTestResult", - "XRHitTestSource", - "XRInputSource", - "XRInputSourceArray", - "XRInputSourceEvent", - "XRInputSourcesChangeEvent", - "XRJointPose", - "XRJointSpace", - 
"XRLayer", - "XRLayerEvent", - "XRLightEstimate", - "XRLightProbe", - "XRMediaBinding", - "XRPose", - "XRProjectionLayer", - "XRQuadLayer", - "XRRay", - "XRReferenceSpace", - "XRReferenceSpaceEvent", - "XRRenderState", - "XRRigidTransform", - "XRSession", - "XRSessionEvent", - "XRSpace", - "XRSubImage", - "XRSystem", - "XRTransientInputHitTestResult", - "XRTransientInputHitTestSource", - "XRView", - "XRViewerPose", - "XRViewport", - "XRWebGLBinding", - "XRWebGLDepthInformation", - "XRWebGLLayer", - "XRWebGLSubImage", - "XSLTProcessor" - ) - -} + val builtins: Set[String] = Set( + "AggregateError", + "Array", + "ArrayBuffer", + "AsyncFunction", + "AsyncGenerator", + "AsyncGeneratorFunction", + "AsyncIterator", + "Atomics", + "BigInt", + "BigInt64Array", + "BigUint64Array", + "Boolean", + "Buffer.from", + "DataView", + "Date", + "Error", + "EvalError", + "FinalizationRegistry", + "Float32Array", + "Float64Array", + "Function", + "Generator", + "GeneratorFunction", + "HTMLImageElement", + "Iterator", + "Infinity", + "Int16Array", + "Int32Array", + "Int8Array", + "InternalError", + "Intl", + "Intl.Collator", + "Intl.DateTimeFormat", + "Intl.DisplayNames", + "Intl.DurationFormat", + "Intl.ListFormat", + "Intl.Locale", + "Intl.NumberFormat", + "Intl.PluralRules", + "Intl.RelativeTimeFormat", + "Intl.Segmenter", + "JSON", + "JSON.parse", + "JSON.stringify", + "Map", + "Math", + "NaN", + "Number", + "Number.isFinite", + "Number.isInteger", + "Number.isNaN", + "Number.isSafeInteger", + "Number.parseFloat", + "Number.parseInt", + "Number.prototype.toExponential", + "Number.prototype.toFixed", + "Number.prototype.toLocaleString", + "Number.prototype.toPrecision", + "Number.prototype.toSource", + "Number.prototype.toString", + "Number.prototype.valueOf", + "Object", + "Object.assign", + "Object.create", + "Object.defineProperties", + "Object.defineProperty", + "Object.entries", + "Object.freeze", + "Object.fromEntries", + "Object.getOwnPropertyDescriptor", + 
"Object.getOwnPropertyDescriptors", + "Object.getOwnPropertyNames", + "Object.getOwnPropertySymbols", + "Object.getPrototypeOf", + "Object.is", + "Object.isExtensible", + "Object.isFrozen", + "Object.isSealed", + "Object.keys", + "Object.preventExtensions", + "Object.prototype.__defineGetter__", + "Object.prototype.__defineSetter__", + "Object.prototype.__lookupGetter__", + "Object.prototype.__lookupSetter__", + "Object.prototype.hasOwnProperty", + "Object.prototype.isPrototypeOf", + "Object.prototype.propertyIsEnumerable", + "Object.prototype.toLocaleString", + "Object.prototype.toSource", + "Object.prototype.toString", + "Object.prototype.valueOf", + "Object.seal", + "Object.setPrototypeOf", + "Object.values", + "Promise", + "Promise.all", + "Promise.allSettled", + "Promise.any", + "Promise.race", + "Promise.reject", + "Promise.resolve", + "Proxy", + "RangeError", + "ReferenceError", + "Reflect", + "RegExp", + "Set", + "SharedArrayBuffer", + "String", + "Symbol", + "SyntaxError", + "TypeError", + "TypedArray", + "URIError", + "Uint16Array", + "Uint32Array", + "Uint8Array", + "Uint8ClampedArray", + "WeakMap", + "WeakRef", + "WeakSet", + "decodeURI", + "decodeURIComponent", + "encodeURI", + "encodeURIComponent", + "escape", + "eval", + "eval", + "fetch", + "globalThis", + "isFinite", + "isNaN", + "localStorage.setItem", + "parseFloat", + "parseInt", + "undefined", + "unescape", + "uneval", + "AbortController", + "AbortSignal", + "AbsoluteOrientationSensor", + "AbstractRange", + "Accelerometer", + "AesCbcParams", + "AesCtrParams", + "AesGcmParams", + "AesKeyGenParams", + "AmbientLightSensor", + "AnalyserNode", + "ANGLE_instanced_arrays", + "Animation", + "AnimationEffect", + "AnimationEvent", + "AnimationPlaybackEvent", + "AnimationTimeline", + "Attr", + "AudioBuffer", + "AudioBufferSourceNode", + "AudioContext", + "AudioData", + "AudioDecoder", + "AudioDestinationNode", + "AudioEncoder", + "AudioListener", + "AudioNode", + "AudioParam", + "AudioParamDescriptor", + 
"AudioParamMap", + "AudioProcessingEvent", + "AudioScheduledSourceNode", + "AudioSinkInfo", + "AudioTrack", + "AudioTrackList", + "AudioWorklet", + "AudioWorkletGlobalScope", + "AudioWorkletNode", + "AudioWorkletProcessor", + "AuthenticatorAssertionResponse", + "AuthenticatorAttestationResponse", + "AuthenticatorResponse", + "BackgroundFetchEvent", + "BackgroundFetchManager", + "BackgroundFetchRecord", + "BackgroundFetchRegistration", + "BackgroundFetchUpdateUIEvent", + "BarcodeDetector", + "BarProp", + "BaseAudioContext", + "BatteryManager", + "BeforeInstallPromptEvent", + "BeforeUnloadEvent", + "BiquadFilterNode", + "Blob", + "BlobEvent", + "Bluetooth", + "BluetoothCharacteristicProperties", + "BluetoothDevice", + "BluetoothRemoteGATTCharacteristic", + "BluetoothRemoteGATTDescriptor", + "BluetoothRemoteGATTServer", + "BluetoothRemoteGATTService", + "BluetoothUUID", + "BroadcastChannel", + "ByteLengthQueuingStrategy", + "Cache", + "CacheStorage", + "CanMakePaymentEvent", + "CanvasCaptureMediaStreamTrack", + "CanvasGradient", + "CanvasPattern", + "CanvasRenderingContext2D", + "CaptureController", + "CaretPosition", + "CDATASection", + "ChannelMergerNode", + "ChannelSplitterNode", + "CharacterData", + "Client", + "Clients", + "Clipboard", + "ClipboardEvent", + "ClipboardItem", + "CloseEvent", + "Comment", + "CompositionEvent", + "CompressionStream", + "console", + "ConstantSourceNode", + "ContactAddress", + "ContactsManager", + "ContentIndex", + "ContentIndexEvent", + "ContentVisibilityAutoStateChangeEvent", + "ConvolverNode", + "CookieChangeEvent", + "CookieStore", + "CookieStoreManager", + "CountQueuingStrategy", + "Credential", + "CredentialsContainer", + "Crypto", + "CryptoKey", + "CryptoKeyPair", + "CSPViolationReportBody", + "CSS", + "CSSAnimation", + "CSSConditionRule", + "CSSContainerRule", + "CSSCounterStyleRule", + "CSSFontFaceRule", + "CSSFontFeatureValuesRule", + "CSSFontPaletteValuesRule", + "CSSGroupingRule", + "CSSImageValue", + "CSSImportRule", + 
"CSSKeyframeRule", + "CSSKeyframesRule", + "CSSKeywordValue", + "CSSLayerBlockRule", + "CSSLayerStatementRule", + "CSSMathInvert", + "CSSMathMax", + "CSSMathMin", + "CSSMathNegate", + "CSSMathProduct", + "CSSMathSum", + "CSSMathValue", + "CSSMatrixComponent", + "CSSMediaRule", + "CSSNamespaceRule", + "CSSNumericArray", + "CSSNumericValue", + "CSSPageRule", + "CSSPerspective", + "CSSPositionValue", + "CSSPrimitiveValue", + "CSSPropertyRule", + "CSSPseudoElement", + "CSSRotate", + "CSSRule", + "CSSRuleList", + "CSSScale", + "CSSSkew", + "CSSSkewX", + "CSSSkewY", + "CSSStyleDeclaration", + "CSSStyleRule", + "CSSStyleSheet", + "CSSStyleValue", + "CSSSupportsRule", + "CSSTransformComponent", + "CSSTransformValue", + "CSSTransition", + "CSSTranslate", + "CSSUnitValue", + "CSSUnparsedValue", + "CSSValue", + "CSSValueList", + "CSSVariableReferenceValue", + "CustomElementRegistry", + "CustomEvent", + "CustomStateSet", + "DataTransfer", + "DataTransferItem", + "DataTransferItemList", + "DecompressionStream", + "DedicatedWorkerGlobalScope", + "DelayNode", + "DeprecationReportBody", + "DeviceMotionEvent", + "DeviceMotionEventAcceleration", + "DeviceMotionEventRotationRate", + "DeviceOrientationEvent", + "DirectoryEntrySync", + "DirectoryReaderSync", + "Document", + "DocumentFragment", + "DocumentTimeline", + "DocumentType", + "DOMError", + "DOMException", + "DOMHighResTimeStamp", + "DOMImplementation", + "DOMMatrix (WebKitCSSMatrix)", + "DOMMatrixReadOnly", + "DOMParser", + "DOMPoint", + "DOMPointReadOnly", + "DOMQuad", + "DOMRect", + "DOMRectReadOnly", + "DOMStringList", + "DOMStringMap", + "DOMTokenList", + "DragEvent", + "DynamicsCompressorNode", + "EcdhKeyDeriveParams", + "EcdsaParams", + "EcKeyGenParams", + "EcKeyImportParams", + "Element", + "ElementInternals", + "EncodedAudioChunk", + "EncodedVideoChunk", + "ErrorEvent", + "Event", + "EventCounts", + "EventSource", + "EventTarget", + "ExtendableCookieChangeEvent", + "ExtendableEvent", + "ExtendableMessageEvent", + 
"EyeDropper", + "FeaturePolicy", + "FederatedCredential", + "FetchEvent", + "File", + "FileEntrySync", + "FileList", + "FileReader", + "FileReaderSync", + "FileSystem", + "FileSystemDirectoryEntry", + "FileSystemDirectoryHandle", + "FileSystemDirectoryReader", + "FileSystemEntry", + "FileSystemFileEntry", + "FileSystemFileHandle", + "FileSystemHandle", + "FileSystemSync", + "FileSystemSyncAccessHandle", + "FileSystemWritableFileStream", + "FocusEvent", + "FontData", + "FontFace", + "FontFaceSet", + "FontFaceSetLoadEvent", + "FormData", + "FormDataEvent", + "FragmentDirective", + "GainNode", + "Gamepad", + "GamepadButton", + "GamepadEvent", + "GamepadHapticActuator", + "GamepadPose", + "Geolocation", + "GeolocationCoordinates", + "GeolocationPosition", + "GeolocationPositionError", + "GestureEvent", + "GPU", + "GPUAdapter", + "GPUAdapterInfo", + "GPUBindGroup", + "GPUBindGroupLayout", + "GPUBuffer", + "GPUCanvasContext", + "GPUCommandBuffer", + "GPUCommandEncoder", + "GPUCompilationInfo", + "GPUCompilationMessage", + "GPUComputePassEncoder", + "GPUComputePipeline", + "GPUDevice", + "GPUDeviceLostInfo", + "GPUError", + "GPUExternalTexture", + "GPUInternalError", + "GPUOutOfMemoryError", + "GPUPipelineError", + "GPUPipelineLayout", + "GPUQuerySet", + "GPUQueue", + "GPURenderBundle", + "GPURenderBundleEncoder", + "GPURenderPassEncoder", + "GPURenderPipeline", + "GPUSampler", + "GPUShaderModule", + "GPUSupportedFeatures", + "GPUSupportedLimits", + "GPUTexture", + "GPUTextureView", + "GPUUncapturedErrorEvent", + "GPUValidationError", + "GravitySensor", + "Gyroscope", + "HashChangeEvent", + "Headers", + "HID", + "HIDConnectionEvent", + "HIDDevice", + "HIDInputReportEvent", + "Highlight", + "HighlightRegistry", + "History", + "HkdfParams", + "HmacImportParams", + "HmacKeyGenParams", + "HMDVRDevice", + "HTMLAnchorElement", + "HTMLAreaElement", + "HTMLAudioElement", + "HTMLBaseElement", + "HTMLBodyElement", + "HTMLBRElement", + "HTMLButtonElement", + "HTMLCanvasElement", + 
"HTMLCollection", + "HTMLDataElement", + "HTMLDataListElement", + "HTMLDetailsElement", + "HTMLDialogElement", + "HTMLDivElement", + "HTMLDListElement", + "HTMLDocument", + "HTMLElement", + "HTMLEmbedElement", + "HTMLFieldSetElement", + "HTMLFontElement", + "HTMLFormControlsCollection", + "HTMLFormElement", + "HTMLFrameSetElement", + "HTMLHeadElement", + "HTMLHeadingElement", + "HTMLHRElement", + "HTMLHtmlElement", + "HTMLIFrameElement", + "HTMLImageElement", + "HTMLInputElement", + "HTMLLabelElement", + "HTMLLegendElement", + "HTMLLIElement", + "HTMLLinkElement", + "HTMLMapElement", + "HTMLMarqueeElement", + "HTMLMediaElement", + "HTMLMenuElement", + "HTMLMenuItemElement", + "HTMLMetaElement", + "HTMLMeterElement", + "HTMLModElement", + "HTMLObjectElement", + "HTMLOListElement", + "HTMLOptGroupElement", + "HTMLOptionElement", + "HTMLOptionsCollection", + "HTMLOutputElement", + "HTMLParagraphElement", + "HTMLParamElement", + "HTMLPictureElement", + "HTMLPreElement", + "HTMLProgressElement", + "HTMLQuoteElement", + "HTMLScriptElement", + "HTMLSelectElement", + "HTMLSlotElement", + "HTMLSourceElement", + "HTMLSpanElement", + "HTMLStyleElement", + "HTMLTableCaptionElement", + "HTMLTableCellElement", + "HTMLTableColElement", + "HTMLTableElement", + "HTMLTableRowElement", + "HTMLTableSectionElement", + "HTMLTemplateElement", + "HTMLTextAreaElement", + "HTMLTimeElement", + "HTMLTitleElement", + "HTMLTrackElement", + "HTMLUListElement", + "HTMLUnknownElement", + "HTMLVideoElement", + "IDBCursor", + "IDBCursorWithValue", + "IDBDatabase", + "IDBFactory", + "IDBIndex", + "IDBKeyRange", + "IDBLocaleAwareKeyRange", + "IDBObjectStore", + "IDBOpenDBRequest", + "IDBRequest", + "IDBTransaction", + "IDBVersionChangeEvent", + "IdentityCredential", + "IdleDeadline", + "IdleDetector", + "IIRFilterNode", + "ImageBitmap", + "ImageBitmapRenderingContext", + "ImageCapture", + "ImageData", + "ImageDecoder", + "ImageTrack", + "ImageTrackList", + "Ink", + "InkPresenter", + 
"InputDeviceCapabilities", + "InputDeviceInfo", + "InputEvent", + "InstallEvent", + "IntersectionObserver", + "IntersectionObserverEntry", + "InterventionReportBody", + "Keyboard", + "KeyboardEvent", + "KeyboardLayoutMap", + "KeyframeEffect", + "LargestContentfulPaint", + "LaunchParams", + "LaunchQueue", + "LayoutShift", + "LayoutShiftAttribution", + "LinearAccelerationSensor", + "Location", + "Lock", + "LockManager", + "Magnetometer", + "MathMLElement", + "MediaCapabilities", + "MediaDeviceInfo", + "MediaDevices", + "MediaElementAudioSourceNode", + "MediaEncryptedEvent", + "MediaError", + "MediaImage", + "MediaKeyMessageEvent", + "MediaKeys", + "MediaKeySession", + "MediaKeyStatusMap", + "MediaKeySystemAccess", + "MediaList", + "MediaMetadata", + "MediaQueryList", + "MediaQueryListEvent", + "MediaRecorder", + "MediaRecorderErrorEvent", + "MediaSession", + "MediaSource", + "MediaSourceHandle", + "MediaStream", + "MediaStreamAudioDestinationNode", + "MediaStreamAudioSourceNode", + "MediaStreamEvent", + "MediaStreamTrack", + "MediaStreamTrackAudioSourceNode", + "MediaStreamTrackEvent", + "MediaStreamTrackGenerator", + "MediaStreamTrackProcessor", + "MediaTrackConstraints", + "MediaTrackSettings", + "MediaTrackSupportedConstraints", + "MerchantValidationEvent", + "MessageChannel", + "MessageEvent", + "MessagePort", + "Metadata", + "MIDIAccess", + "MIDIConnectionEvent", + "MIDIInput", + "MIDIInputMap", + "MIDIMessageEvent", + "MIDIOutput", + "MIDIOutputMap", + "MIDIPort", + "MimeType", + "MimeTypeArray", + "MouseEvent", + "MouseScrollEvent", + "MutationEvent", + "MutationObserver", + "MutationRecord", + "NamedNodeMap", + "NavigateEvent", + "Navigation", + "NavigationCurrentEntryChangeEvent", + "NavigationDestination", + "NavigationHistoryEntry", + "NavigationPreloadManager", + "NavigationTransition", + "Navigator", + "NavigatorUAData", + "NDEFMessage", + "NDEFReader", + "NDEFReadingEvent", + "NDEFRecord", + "NetworkInformation", + "Node", + "NodeIterator", + 
"NodeList", + "Notification", + "NotificationEvent", + "NotifyAudioAvailableEvent", + "OES_draw_buffers_indexed", + "OfflineAudioCompletionEvent", + "OfflineAudioContext", + "OffscreenCanvas", + "OffscreenCanvasRenderingContext2D", + "OrientationSensor", + "OscillatorNode", + "OTPCredential", + "OverconstrainedError", + "PageTransitionEvent", + "PaintWorkletGlobalScope", + "PannerNode", + "PasswordCredential", + "Path2D", + "PaymentAddress", + "PaymentManager", + "PaymentMethodChangeEvent", + "PaymentRequest", + "PaymentRequestEvent", + "PaymentRequestUpdateEvent", + "PaymentResponse", + "Pbkdf2Params", + "Performance", + "PerformanceElementTiming", + "PerformanceEntry", + "PerformanceEventTiming", + "PerformanceLongTaskTiming", + "PerformanceMark", + "PerformanceMeasure", + "PerformanceNavigation", + "PerformanceNavigationTiming", + "PerformanceObserver", + "PerformanceObserverEntryList", + "PerformancePaintTiming", + "PerformanceResourceTiming", + "PerformanceServerTiming", + "PerformanceTiming", + "PeriodicSyncEvent", + "PeriodicSyncManager", + "PeriodicWave", + "Permissions", + "PermissionStatus", + "PictureInPictureEvent", + "PictureInPictureWindow", + "Plugin", + "PluginArray", + "Point", + "PointerEvent", + "PopStateEvent", + "PositionSensorVRDevice", + "Presentation", + "PresentationAvailability", + "PresentationConnection", + "PresentationConnectionAvailableEvent", + "PresentationConnectionCloseEvent", + "PresentationConnectionList", + "PresentationReceiver", + "PresentationRequest", + "ProcessingInstruction", + "ProgressEvent", + "PromiseRejectionEvent", + "PublicKeyCredential", + "PushEvent", + "PushManager", + "PushMessageData", + "PushSubscription", + "PushSubscriptionOptions", + "RadioNodeList", + "Range", + "ReadableByteStreamController", + "ReadableStream", + "ReadableStreamBYOBReader", + "ReadableStreamBYOBRequest", + "ReadableStreamDefaultController", + "ReadableStreamDefaultReader", + "RelativeOrientationSensor", + "RemotePlayback", + "Report", + 
"ReportBody", + "ReportingObserver", + "Request", + "ResizeObserver", + "ResizeObserverEntry", + "ResizeObserverSize", + "Response", + "RsaHashedImportParams", + "RsaHashedKeyGenParams", + "RsaOaepParams", + "RsaPssParams", + "RTCAudioSourceStats", + "RTCCertificate", + "RTCDataChannel", + "RTCDataChannelEvent", + "RTCDtlsTransport", + "RTCDTMFSender", + "RTCDTMFToneChangeEvent", + "RTCError", + "RTCErrorEvent", + "RTCIceCandidate", + "RTCIceCandidatePair", + "RTCIceCandidatePairStats", + "RTCIceCandidateStats", + "RTCIceParameters", + "RTCIceServer", + "RTCIceTransport", + "RTCIdentityAssertion", + "RTCInboundRtpStreamStats", + "RTCOutboundRtpStreamStats", + "RTCPeerConnection", + "RTCPeerConnectionIceErrorEvent", + "RTCPeerConnectionIceEvent", + "RTCPeerConnectionStats", + "RTCRemoteOutboundRtpStreamStats", + "RTCRtpCodecParameters", + "RTCRtpContributingSource", + "RTCRtpEncodingParameters", + "RTCRtpReceiver", + "RTCRtpSender", + "RTCRtpStreamStats", + "RTCRtpTransceiver", + "RTCSctpTransport", + "RTCSessionDescription", + "RTCStatsReport", + "RTCTrackEvent", + "Sanitizer", + "Scheduler", + "Screen", + "ScreenOrientation", + "ScriptProcessorNode", + "SecurityPolicyViolationEvent", + "Selection", + "Sensor", + "SensorErrorEvent", + "Serial", + "SerialPort", + "ServiceWorker", + "ServiceWorkerContainer", + "ServiceWorkerGlobalScope", + "ServiceWorkerRegistration", + "ShadowRoot", + "SharedWorker", + "SharedWorkerGlobalScope", + "SourceBuffer", + "SourceBufferList", + "SpeechGrammar", + "SpeechGrammarList", + "SpeechRecognition", + "SpeechRecognitionAlternative", + "SpeechRecognitionErrorEvent", + "SpeechRecognitionEvent", + "SpeechRecognitionResult", + "SpeechRecognitionResultList", + "SpeechSynthesis", + "SpeechSynthesisErrorEvent", + "SpeechSynthesisEvent", + "SpeechSynthesisUtterance", + "SpeechSynthesisVoice", + "StaticRange", + "StereoPannerNode", + "Storage", + "StorageEvent", + "StorageManager", + "StylePropertyMap", + "StylePropertyMapReadOnly", + 
"StyleSheet", + "StyleSheetList", + "SubmitEvent", + "SubtleCrypto", + "SVGAElement", + "SVGAngle", + "SVGAnimateColorElement", + "SVGAnimatedAngle", + "SVGAnimatedBoolean", + "SVGAnimatedEnumeration", + "SVGAnimatedInteger", + "SVGAnimatedLength", + "SVGAnimatedLengthList", + "SVGAnimatedNumber", + "SVGAnimatedNumberList", + "SVGAnimatedPreserveAspectRatio", + "SVGAnimatedRect", + "SVGAnimatedString", + "SVGAnimatedTransformList", + "SVGAnimateElement", + "SVGAnimateMotionElement", + "SVGAnimateTransformElement", + "SVGAnimationElement", + "SVGCircleElement", + "SVGClipPathElement", + "SVGComponentTransferFunctionElement", + "SVGCursorElement", + "SVGDefsElement", + "SVGDescElement", + "SVGElement", + "SVGEllipseElement", + "SVGEvent", + "SVGFEBlendElement", + "SVGFEColorMatrixElement", + "SVGFEComponentTransferElement", + "SVGFECompositeElement", + "SVGFEConvolveMatrixElement", + "SVGFEDiffuseLightingElement", + "SVGFEDisplacementMapElement", + "SVGFEDistantLightElement", + "SVGFEDropShadowElement", + "SVGFEFloodElement", + "SVGFEFuncAElement", + "SVGFEFuncBElement", + "SVGFEFuncGElement", + "SVGFEFuncRElement", + "SVGFEGaussianBlurElement", + "SVGFEImageElement", + "SVGFEMergeElement", + "SVGFEMergeNodeElement", + "SVGFEMorphologyElement", + "SVGFEOffsetElement", + "SVGFEPointLightElement", + "SVGFESpecularLightingElement", + "SVGFESpotLightElement", + "SVGFETileElement", + "SVGFETurbulenceElement", + "SVGFilterElement", + "SVGFontElement", + "SVGFontFaceElement", + "SVGFontFaceFormatElement", + "SVGFontFaceNameElement", + "SVGFontFaceSrcElement", + "SVGFontFaceUriElement", + "SVGForeignObjectElement", + "SVGGElement", + "SVGGeometryElement", + "SVGGlyphElement", + "SVGGlyphRefElement", + "SVGGradientElement", + "SVGGraphicsElement", + "SVGHKernElement", + "SVGImageElement", + "SVGLength", + "SVGLengthList", + "SVGLinearGradientElement", + "SVGLineElement", + "SVGMarkerElement", + "SVGMaskElement", + "SVGMetadataElement", + "SVGMissingGlyphElement", + 
"SVGMPathElement", + "SVGNumber", + "SVGNumberList", + "SVGPathElement", + "SVGPatternElement", + "SVGPoint", + "SVGPointList", + "SVGPolygonElement", + "SVGPolylineElement", + "SVGPreserveAspectRatio", + "SVGRadialGradientElement", + "SVGRect", + "SVGRectElement", + "SVGRenderingIntent", + "SVGScriptElement", + "SVGSetElement", + "SVGStopElement", + "SVGStringList", + "SVGStyleElement", + "SVGSVGElement", + "SVGSwitchElement", + "SVGSymbolElement", + "SVGTextContentElement", + "SVGTextElement", + "SVGTextPathElement", + "SVGTextPositioningElement", + "SVGTitleElement", + "SVGTransform", + "SVGTransformList", + "SVGTRefElement", + "SVGTSpanElement", + "SVGUnitTypes", + "SVGUseElement", + "SVGViewElement", + "SVGVKernElement", + "SyncEvent", + "SyncManager", + "TaskAttributionTiming", + "TaskController", + "TaskPriorityChangeEvent", + "TaskSignal", + "Text", + "TextDecoder", + "TextDecoderStream", + "TextEncoder", + "TextEncoderStream", + "TextMetrics", + "TextTrack", + "TextTrackCue", + "TextTrackCueList", + "TextTrackList", + "TimeEvent", + "TimeRanges", + "ToggleEvent", + "Touch", + "TouchEvent", + "TouchList", + "TrackEvent", + "TransformStream", + "TransformStreamDefaultController", + "TransitionEvent", + "TreeWalker", + "TrustedHTML", + "TrustedScript", + "TrustedScriptURL", + "TrustedTypePolicy", + "TrustedTypePolicyFactory", + "UIEvent", + "URL", + "URLPattern", + "URLSearchParams", + "USB", + "USBAlternateInterface", + "USBConfiguration", + "USBConnectionEvent", + "USBDevice", + "USBEndpoint", + "USBInterface", + "USBInTransferResult", + "USBIsochronousInTransferPacket", + "USBIsochronousInTransferResult", + "USBIsochronousOutTransferPacket", + "USBIsochronousOutTransferResult", + "USBOutTransferResult", + "UserActivation", + "ValidityState", + "VideoColorSpace", + "VideoDecoder", + "VideoEncoder", + "VideoFrame", + "VideoPlaybackQuality", + "VideoTrack", + "VideoTrackList", + "ViewTransition", + "VirtualKeyboard", + "VisualViewport", + "VRDisplay", + 
"VRDisplayCapabilities", + "VRDisplayEvent", + "VREyeParameters", + "VRFieldOfView", + "VRFrameData", + "VRLayerInit", + "VRPose", + "VRStageParameters", + "VTTCue", + "VTTRegion", + "WakeLock", + "WakeLockSentinel", + "WaveShaperNode", + "WebGL2RenderingContext", + "WebGLActiveInfo", + "WebGLBuffer", + "WebGLContextEvent", + "WebGLFramebuffer", + "WebGLObject", + "WebGLProgram", + "WebGLQuery", + "WebGLRenderbuffer", + "WebGLRenderingContext", + "WebGLSampler", + "WebGLShader", + "WebGLShaderPrecisionFormat", + "WebGLSync", + "WebGLTexture", + "WebGLTransformFeedback", + "WebGLUniformLocation", + "WebGLVertexArrayObject", + "WebSocket", + "WebTransport", + "WebTransportBidirectionalStream", + "WebTransportDatagramDuplexStream", + "WebTransportError", + "WebTransportReceiveStream", + "WheelEvent", + "Window", + "WindowClient", + "WindowControlsOverlay", + "WindowControlsOverlayGeometryChangeEvent", + "Worker", + "WorkerGlobalScope", + "WorkerLocation", + "WorkerNavigator", + "Worklet", + "WorkletGlobalScope", + "WritableStream", + "WritableStreamDefaultController", + "WritableStreamDefaultWriter", + "XMLDocument", + "XMLHttpRequest", + "XMLHttpRequestEventTarget", + "XMLHttpRequestUpload", + "XMLSerializer", + "XPathEvaluator", + "XPathException", + "XPathExpression", + "XPathNSResolver", + "XPathResult", + "XRAnchor", + "XRAnchorSet", + "XRBoundedReferenceSpace", + "XRCompositionLayer", + "XRCPUDepthInformation", + "XRCubeLayer", + "XRCylinderLayer", + "XRDepthInformation", + "XREquirectLayer", + "XRFrame", + "XRHand", + "XRHitTestResult", + "XRHitTestSource", + "XRInputSource", + "XRInputSourceArray", + "XRInputSourceEvent", + "XRInputSourcesChangeEvent", + "XRJointPose", + "XRJointSpace", + "XRLayer", + "XRLayerEvent", + "XRLightEstimate", + "XRLightProbe", + "XRMediaBinding", + "XRPose", + "XRProjectionLayer", + "XRQuadLayer", + "XRRay", + "XRReferenceSpace", + "XRReferenceSpaceEvent", + "XRRenderState", + "XRRigidTransform", + "XRSession", + "XRSessionEvent", 
+ "XRSpace", + "XRSubImage", + "XRSystem", + "XRTransientInputHitTestResult", + "XRTransientInputHitTestSource", + "XRView", + "XRViewerPose", + "XRViewport", + "XRWebGLBinding", + "XRWebGLDepthInformation", + "XRWebGLLayer", + "XRWebGLSubImage", + "XSLTProcessor" + ) +end GlobalBuiltins diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala index 81611512..6361dd86 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala @@ -1,125 +1,128 @@ package io.appthreat.jssrc2cpg.passes -import io.appthreat.x2cpg.passes.frontend.ImportsPass._ +import io.appthreat.x2cpg.passes.frontend.ImportsPass.* import io.appthreat.x2cpg.passes.frontend.XImportResolverPass -import io.appthreat.x2cpg.{Defines => XDefines} +import io.appthreat.x2cpg.{Defines as XDefines} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Method, MethodRef} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -import java.io.{File => JFile} +import java.io.{File as JFile} import java.util.regex.{Matcher, Pattern} import scala.util.{Failure, Success, Try} -class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { +class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg): - private val pathPattern = Pattern.compile("[\"']([\\w/.]+)[\"']") + private val pathPattern = Pattern.compile("[\"']([\\w/.]+)[\"']") - override protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit = { - val pathSep = ":" - val rawEntity = importedEntity.stripPrefix("./") - val alias = importedAs 
- val matcher = pathPattern.matcher(rawEntity) - val sep = Matcher.quoteReplacement(JFile.separator) - val root = s"$codeRoot${JFile.separator}" - val currentFile = s"$root$fileName" - // We want to know if the import is local since if an external name is used to match internal methods we may have - // false paths. - val isLocalImport = importedEntity.matches("^[.]+/?.*") - // TODO: At times there is an operation inside of a require, e.g. path.resolve(__dirname + "/../config/env/all.js") - // this tries to recover the string but does not perform string constant propagation - val entity = if (matcher.find()) matcher.group(1) else rawEntity - val resolvedPath = better.files - .File(currentFile.stripSuffix(currentFile.split(sep).last), entity.split(pathSep).head) - .pathAsString - .stripPrefix(root) + override protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit = + val pathSep = ":" + val rawEntity = importedEntity.stripPrefix("./") + val alias = importedAs + val matcher = pathPattern.matcher(rawEntity) + val sep = Matcher.quoteReplacement(JFile.separator) + val root = s"$codeRoot${JFile.separator}" + val currentFile = s"$root$fileName" + // We want to know if the import is local since if an external name is used to match internal methods we may have + // false paths. + val isLocalImport = importedEntity.matches("^[.]+/?.*") + // TODO: At times there is an operation inside of a require, e.g. 
path.resolve(__dirname + "/../config/env/all.js") + // this tries to recover the string but does not perform string constant propagation + val entity = if matcher.find() then matcher.group(1) else rawEntity + val resolvedPath = better.files + .File(currentFile.stripSuffix(currentFile.split(sep).last), entity.split(pathSep).head) + .pathAsString + .stripPrefix(root) - val isImportingModule = !entity.contains(pathSep) + val isImportingModule = !entity.contains(pathSep) - def targetModule = Try( - if (isLocalImport) - cpg - .file(s"${Pattern.quote(resolvedPath)}\\.?.*") - .method - .nameExact(":program") - else - Iterator.empty - ) match { - case Failure(_) => - logger.warn(s"Unable to resolve import due to irregular regex at '$importedEntity'") - Iterator.empty - case Success(modules) => modules - } + def targetModule = Try( + if isLocalImport then + cpg + .file(s"${Pattern.quote(resolvedPath)}\\.?.*") + .method + .nameExact(":program") + else + Iterator.empty + ) match + case Failure(_) => + logger.warn(s"Unable to resolve import due to irregular regex at '$importedEntity'") + Iterator.empty + case Success(modules) => modules - def targetAssignments = targetModule - .nameExact(":program") - .flatMap(_._callViaContainsOut) - .assignment + def targetAssignments = targetModule + .nameExact(":program") + .flatMap(_._callViaContainsOut) + .assignment - val matchingExports = if (isImportingModule) { - // If we are importing the whole module, we need to load all entities - targetAssignments - .code(s"\\_tmp\\_\\d+\\.\\w+ =.*", "(module\\.)?exports.*") - .dedup - .l - } else { - // If we are importing a specific entity, then we look for it here - targetAssignments - .code("^(module.)?exports.*") - .where(_.argument.codeExact(alias)) - .dedup - .l - } + val matchingExports = if isImportingModule then + // If we are importing the whole module, we need to load all entities + targetAssignments + .code(s"\\_tmp\\_\\d+\\.\\w+ =.*", "(module\\.)?exports.*") + .dedup + .l + else + 
// If we are importing a specific entity, then we look for it here + targetAssignments + .code("^(module.)?exports.*") + .where(_.argument.codeExact(alias)) + .dedup + .l - (if (matchingExports.nonEmpty) { - matchingExports.flatMap { exp => - exp.argument.l match { - case ::(expCall: Call, ::(b: Identifier, _)) - if expCall.code.matches("^(module.)?exports[.]?.*") && b.name == alias => - val moduleMethods = targetModule.repeat(_.astChildren.isMethod)(_.emit).l - lazy val methodMatches = moduleMethods.name(b.name).l - lazy val constructorMatches = - moduleMethods.fullName(s".*${b.name}$pathSep${XDefines.ConstructorMethodName}$$").l - lazy val moduleExportsThisVariable = moduleMethods.body.local - .where(_.nameExact(b.name)) - .nonEmpty - // Exported function with only the name of the function - val methodPaths = - if (methodMatches.nonEmpty) methodMatches.fullName.toSet - else constructorMatches.fullName.toSet - if (methodPaths.nonEmpty) { - methodPaths.flatMap(x => Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x))) - } else if (moduleExportsThisVariable) { - Set(ResolvedMember(targetModule.fullName.head, b.name)) - } else { - Set.empty - } - case ::(x: Call, ::(b: MethodRef, _)) => - // Exported function with a method ref of the function - val methodName = x.argumentOption(2).map(_.code).getOrElse(b.referencedMethod.name) - val (callName, receiver) = - if (methodName == "exports") (alias, Option("this")) else (methodName, Option(alias)) - b.referencedMethod.astParent.iterator - .collectAll[Method] - .fullName - .map(x => ResolvedTypeDecl(x)) - .toSet ++ Set(ResolvedMethod(b.methodFullName, callName, receiver)) - case ::(_, ::(y: Call, _)) => - // Exported closure with a method ref within the AST of the RHS - y.ast.isMethodRef.map(mRef => ResolvedMethod(mRef.methodFullName, alias, Option("this"))).toSet - case _ => - Set.empty[ResolvedImport] - } - }.toSet - } else { - Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity)) - 
}).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) - } - -} + (if matchingExports.nonEmpty then + matchingExports.flatMap { exp => + exp.argument.l match + case ::(expCall: Call, ::(b: Identifier, _)) + if expCall.code.matches("^(module.)?exports[.]?.*") && b.name == alias => + val moduleMethods = targetModule.repeat(_.astChildren.isMethod)(_.emit).l + lazy val methodMatches = moduleMethods.name(b.name).l + lazy val constructorMatches = + moduleMethods.fullName( + s".*${b.name}$pathSep${XDefines.ConstructorMethodName}$$" + ).l + lazy val moduleExportsThisVariable = moduleMethods.body.local + .where(_.nameExact(b.name)) + .nonEmpty + // Exported function with only the name of the function + val methodPaths = + if methodMatches.nonEmpty then methodMatches.fullName.toSet + else constructorMatches.fullName.toSet + if methodPaths.nonEmpty then + methodPaths.flatMap(x => + Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x)) + ) + else if moduleExportsThisVariable then + Set(ResolvedMember(targetModule.fullName.head, b.name)) + else + Set.empty + case ::(x: Call, ::(b: MethodRef, _)) => + // Exported function with a method ref of the function + val methodName = + x.argumentOption(2).map(_.code).getOrElse(b.referencedMethod.name) + val (callName, receiver) = + if methodName == "exports" then (alias, Option("this")) + else (methodName, Option(alias)) + b.referencedMethod.astParent.iterator + .collectAll[Method] + .fullName + .map(x => ResolvedTypeDecl(x)) + .toSet ++ Set(ResolvedMethod(b.methodFullName, callName, receiver)) + case ::(_, ::(y: Call, _)) => + // Exported closure with a method ref within the AST of the RHS + y.ast.isMethodRef.map(mRef => + ResolvedMethod(mRef.methodFullName, alias, Option("this")) + ).toSet + case _ => + Set.empty[ResolvedImport] + }.toSet + else + Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity)) + ).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) + end optionalResolveImport +end 
ImportResolverPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala index 9333b3d9..ed038eda 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala @@ -4,24 +4,24 @@ import io.appthreat.x2cpg.X2Cpg import io.appthreat.x2cpg.passes.frontend.XImportsPass import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment -/** This pass creates `IMPORT` nodes by looking for calls to `require`. `IMPORT` nodes are linked to existing dependency - * nodes, or, if no suitable dependency node exists, a dependency node is created. +/** This pass creates `IMPORT` nodes by looking for calls to `require`. `IMPORT` nodes are linked to + * existing dependency nodes, or, if no suitable dependency node exists, a dependency node is + * created. * - * TODO with this, we can have multiple IMPORT nodes that point to the same call: one created during AST creation, and - * one using this pass. + * TODO with this, we can have multiple IMPORT nodes that point to the same call: one created + * during AST creation, and one using this pass. * * TODO Dependency node creation is still missing. 
*/ -class ImportsPass(cpg: Cpg) extends XImportsPass(cpg) { +class ImportsPass(cpg: Cpg) extends XImportsPass(cpg): - override protected val importCallName: String = "require" + override protected val importCallName: String = "require" - override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = - x.inAssignment.codeNot("var .*").map(y => (x, y)) + override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = + x.inAssignment.codeNot("var .*").map(y => (x, y)) - override protected def importedEntityFromCall(call: Call): String = X2Cpg.stripQuotes(call.argument(1).code) - -} + override protected def importedEntityFromCall(call: Call): String = + X2Cpg.stripQuotes(call.argument(1).code) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala index 2f7cd82c..64f0ff1f 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala @@ -3,13 +3,11 @@ package io.appthreat.jssrc2cpg.passes import io.appthreat.x2cpg.passes.frontend.XInheritanceFullNamePass import io.shiftleft.codepropertygraph.Cpg -/** Using some basic heuristics, will try to resolve type full names from types found within the CPG. Requires - * ImportPass as a pre-requisite. +/** Using some basic heuristics, will try to resolve type full names from types found within the + * CPG. Requires ImportPass as a pre-requisite. 
*/ -class JavaScriptInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg) { +class JavaScriptInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg): - override val pathSep: Char = ':' - override val moduleName: String = ":program" - override val fileExt: String = ".js" - -} + override val pathSep: Char = ':' + override val moduleName: String = ":program" + override val fileExt: String = ".js" diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala index f2b12c52..afe37ac9 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala @@ -5,12 +5,10 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.semanticcpg.language.* -class JavaScriptTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { +class JavaScriptTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override protected val pathSep = ':' + override protected val pathSep = ':' - override protected def calls: Iterator[Call] = cpg.call - .or(_.nameNot(".*", ".*"), _.name(".new")) - .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) - -} + override protected def calls: Iterator[Call] = cpg.call + .or(_.nameNot(".*", ".*"), _.name(".new")) + .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala index db2b6226..18baf5d6 100644 --- 
a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala @@ -11,189 +11,214 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder class JavaScriptTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new JavaScriptTypeRecovery(cpg, state) -} - -private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { - - override def compilationUnit: Iterator[File] = cpg.file.iterator - - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForJavaScriptFile(cpg, unit, builder, state.copy(config = newConfig)) - } - -} - -private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { - - override protected val pathSep = ':' - - /** A heuristic method to determine if a call is a constructor or not. 
- */ - override protected def isConstructor(c: Call): Boolean = { - c.name.endsWith("factory") && c.inCall.astParent.headOption.exists(_.isInstanceOf[Block]) - } - - override protected def isConstructor(name: String): Boolean = - !name.isBlank && (name.charAt(0).isUpper || name.endsWith("factory")) - - override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match { - case x @ (_: Identifier | _: Local | _: MethodParameterIn) - if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => - val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) - val typeHints = symbolTable.get(LocalVar(x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any))) - typeFullName - lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet - val resolvedTypeHints = - if (typeHints.nonEmpty) symbolTable.put(x, typeHints) - else if (cpgTypeFullName.nonEmpty) symbolTable.put(x, cpgTypeFullName) - else symbolTable.put(x, x.getKnownTypes) - if (!resolvedTypeHints.contains(typeFullName) && resolvedTypeHints.sizeIs == 1) - builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) - case x @ (_: Identifier | _: Local | _: MethodParameterIn) => - symbolTable.put(x, x.getKnownTypes) - case x: Call => symbolTable.put(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) - case _ => - } - - override protected def prepopulateSymbolTable(): Unit = { - super.prepopulateSymbolTable() - cu.ast.isMethod.foreach(f => symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName))) - (cu.ast.isParameter.whereNot(_.nameExact("this")) ++ cu.ast.isMethod.methodReturn).filter(hasTypes).foreach { p => - val resolvedHints = p.getKnownTypes - .map { t => - t.split("\\.").headOption match { - case Some(base) if symbolTable.contains(LocalVar(base)) => - (t, symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}")) - case _ => (t, Set(t)) - } + extends XTypeRecoveryPass[File](cpg, config): + override 
protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = + new JavaScriptTypeRecovery(cpg, state) + +private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) + extends XTypeRecovery[File](cpg, state): + + override def compilationUnit: Iterator[File] = cpg.file.iterator + + override def generateRecoveryForCompilationUnitTask( + unit: File, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[File] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForJavaScriptFile(cpg, unit, builder, state.copy(config = newConfig)) + +private class RecoverForJavaScriptFile( + cpg: Cpg, + cu: File, + builder: DiffGraphBuilder, + state: XTypeRecoveryState +) extends RecoverForXCompilationUnit[File](cpg, cu, builder, state): + + override protected val pathSep = ':' + + /** A heuristic method to determine if a call is a constructor or not. + */ + override protected def isConstructor(c: Call): Boolean = + c.name.endsWith("factory") && c.inCall.astParent.headOption.exists(_.isInstanceOf[Block]) + + override protected def isConstructor(name: String): Boolean = + !name.isBlank && (name.charAt(0).isUpper || name.endsWith("factory")) + + override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match + case x @ (_: Identifier | _: Local | _: MethodParameterIn) + if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => + val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) + val typeHints = symbolTable.get(LocalVar(x.property( + PropertyNames.TYPE_FULL_NAME, + Defines.Any + ))) - typeFullName + lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet + val resolvedTypeHints = + if typeHints.nonEmpty then symbolTable.put(x, typeHints) + else if cpgTypeFullName.nonEmpty then symbolTable.put(x, cpgTypeFullName) + else symbolTable.put(x, x.getKnownTypes) + if 
!resolvedTypeHints.contains(typeFullName) && resolvedTypeHints.sizeIs == 1 then + builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) + case x @ (_: Identifier | _: Local | _: MethodParameterIn) => + symbolTable.put(x, x.getKnownTypes) + case x: Call => symbolTable.put(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case _ => + + override protected def prepopulateSymbolTable(): Unit = + super.prepopulateSymbolTable() + cu.ast.isMethod.foreach(f => + symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName)) + ) + (cu.ast.isParameter.whereNot(_.nameExact("this")) ++ cu.ast.isMethod.methodReturn).filter( + hasTypes + ).foreach { p => + val resolvedHints = p.getKnownTypes + .map { t => + t.split("\\.").headOption match + case Some(base) if symbolTable.contains(LocalVar(base)) => + ( + t, + symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}") + ) + case _ => (t, Set(t)) + } + .flatMap { + case (t, ts) if Set(t) == ts => Set(t) + case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) + } + p match + case _: MethodParameterIn => symbolTable.put(p, resolvedHints) + case _: MethodReturn if resolvedHints.sizeIs == 1 => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) + case _: MethodReturn => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) + builder.setNodeProperty( + p, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + resolvedHints + ) + case _ => } - .flatMap { - case (t, ts) if Set(t) == ts => Set(t) - case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) - } - p match { - case _: MethodParameterIn => symbolTable.put(p, resolvedHints) - case _: MethodReturn if resolvedHints.sizeIs == 1 => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) - case _: MethodReturn => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) - builder.setNodeProperty(p, 
PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, resolvedHints) - case _ => - } - } - } - - private lazy val exportedIdentifiers = cu.method - .nameExact(":program") - .flatMap(_._callViaContainsOut) - .nameExact(Operators.assignment) - .filter(_.code.startsWith("exports.*")) - .argument - .isIdentifier - .name - .toSet - - override protected def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate(i.id(), exportedIdentifiers.contains(i.name) || super.isField(i)) - - override protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { - val constructorPaths = if (c.methodFullName.endsWith(".alloc")) { - def newChildren = c.inAssignment.astSiblings.isCall.nameExact(".new").astChildren - val possibleImportIdentifier = newChildren.isIdentifier.headOption match { - case Some(i) if GlobalBuiltins.builtins.contains(i.name) => Set(s"__ecma.${i.name}") - case Some(i) => symbolTable.get(i) - case None => Set.empty[String] - } - lazy val possibleConstructorPointer = - newChildren.astChildren.isFieldIdentifier.map(f => CallAlias(f.canonicalName, Some("this"))).headOption match { - case Some(fi) => symbolTable.get(fi) - case None => Set.empty[String] - } - - if (possibleImportIdentifier.nonEmpty) possibleImportIdentifier - else if (possibleConstructorPointer.nonEmpty) possibleConstructorPointer - else Set.empty[String] - } else (symbolTable.get(c) + c.methodFullName).map(t => t.stripSuffix(".factory")) - associateTypes(i, constructorPaths) - } - - override protected def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { - operation match { - case ".new" => - c.astChildren.l match { - case ::(fa: Call, ::(i: Identifier, _)) if fa.name == Operators.fieldAccess => - symbolTable.append( - c, - visitIdentifierAssignedToFieldLoad(i, new FieldAccess(fa)).map(t => s"$t$pathSep$ConstructorMethodName") + end prepopulateSymbolTable + + private lazy val exportedIdentifiers = cu.method + 
.nameExact(":program") + .flatMap(_._callViaContainsOut) + .nameExact(Operators.assignment) + .filter(_.code.startsWith("exports.*")) + .argument + .isIdentifier + .name + .toSet + + override protected def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + exportedIdentifiers.contains(i.name) || super.isField(i) + ) + + override protected def visitIdentifierAssignedToConstructor( + i: Identifier, + c: Call + ): Set[String] = + val constructorPaths = if c.methodFullName.endsWith(".alloc") then + def newChildren = + c.inAssignment.astSiblings.isCall.nameExact(".new").astChildren + val possibleImportIdentifier = newChildren.isIdentifier.headOption match + case Some(i) if GlobalBuiltins.builtins.contains(i.name) => Set(s"__ecma.${i.name}") + case Some(i) => symbolTable.get(i) + case None => Set.empty[String] + lazy val possibleConstructorPointer = + newChildren.astChildren.isFieldIdentifier.map(f => + CallAlias(f.canonicalName, Some("this")) + ).headOption match + case Some(fi) => symbolTable.get(fi) + case None => Set.empty[String] + + if possibleImportIdentifier.nonEmpty then possibleImportIdentifier + else if possibleConstructorPointer.nonEmpty then possibleConstructorPointer + else Set.empty[String] + else (symbolTable.get(c) + c.methodFullName).map(t => t.stripSuffix(".factory")) + associateTypes(i, constructorPaths) + end visitIdentifierAssignedToConstructor + + override protected def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case ".new" => + c.astChildren.l match + case ::(fa: Call, ::(i: Identifier, _)) if fa.name == Operators.fieldAccess => + symbolTable.append( + c, + visitIdentifierAssignedToFieldLoad(i, new FieldAccess(fa)).map(t => + s"$t$pathSep$ConstructorMethodName" + ) + ) + case _ => Set.empty + case _ => super.visitIdentifierAssignedToOperator(i, c, operation) + + override protected def associateInterproceduralTypes( + i: Identifier, + 
fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] + ): Set[String] = + if symbolTable.contains(LocalVar(fieldName)) then + val fieldTypes = symbolTable.get(LocalVar(fieldName)) + symbolTable.append(i, fieldTypes) + else if symbolTable.contains(CallAlias(fieldName, Option("this"))) then + symbolTable.get(CallAlias(fieldName, Option("this"))) + else + super.associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] ) - case _ => Set.empty - } - case _ => super.visitIdentifierAssignedToOperator(i, c, operation) - } - } - - override protected def associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] - ): Set[String] = { - if (symbolTable.contains(LocalVar(fieldName))) { - val fieldTypes = symbolTable.get(LocalVar(fieldName)) - symbolTable.append(i, fieldTypes) - } else if (symbolTable.contains(CallAlias(fieldName, Option("this")))) { - symbolTable.get(CallAlias(fieldName, Option("this"))) - } else { - super.associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] - ) - } - } - - override protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = - if (c.name == "require") Set.empty - else super.visitIdentifierAssignedToCall(i, c) - - override protected def visitIdentifierAssignedToMethodRef( - i: Identifier, - m: MethodRef, - rec: Option[String] = None - ): Set[String] = - super.visitIdentifierAssignedToMethodRef(i, m, Option("this")) - - override protected def visitIdentifierAssignedToTypeRef( - i: Identifier, - t: TypeRef, - rec: Option[String] = None - ): Set[String] = - super.visitIdentifierAssignedToTypeRef(i, t, Option("this")) - - override protected def postSetTypeInformation(): Unit = { - // often there are "this" identifiers 
with type hints but this can be set to a type hint if they meet the criteria - cu.method - .flatMap(_._identifierViaContainsOut) - .nameExact("this") - .where(_.typeFullNameExact(Defines.Any)) - .filterNot(_.dynamicTypeHintFullName.isEmpty) - .foreach(setTypeFromTypeHints) - } - - protected override def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = - super.storeIdentifierTypeInfo(i, types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}"))) - - protected override def storeLocalTypeInfo(i: Local, types: Seq[String]): Unit = - super.storeLocalTypeInfo(i, types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}"))) - -} + + override protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + if c.name == "require" then Set.empty + else super.visitIdentifierAssignedToCall(i, c) + + override protected def visitIdentifierAssignedToMethodRef( + i: Identifier, + m: MethodRef, + rec: Option[String] = None + ): Set[String] = + super.visitIdentifierAssignedToMethodRef(i, m, Option("this")) + + override protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] = None + ): Set[String] = + super.visitIdentifierAssignedToTypeRef(i, t, Option("this")) + + override protected def postSetTypeInformation(): Unit = + // often there are "this" identifiers with type hints but this can be set to a type hint if they meet the criteria + cu.method + .flatMap(_._identifierViaContainsOut) + .nameExact("this") + .where(_.typeFullNameExact(Defines.Any)) + .filterNot(_.dynamicTypeHintFullName.isEmpty) + .foreach(setTypeFromTypeHints) + + protected override def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + super.storeIdentifierTypeInfo( + i, + types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) + ) + + protected override def storeLocalTypeInfo(i: Local, types: Seq[String]): Unit = + super.storeLocalTypeInfo( + i, + 
types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) + ) +end RecoverForJavaScriptFile diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala index 8810f266..8a850409 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala @@ -6,12 +6,11 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewMetaData import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.passes.CpgPass -class JsMetaDataPass(cpg: Cpg, hash: String, inputPath: String) extends CpgPass(cpg) { +class JsMetaDataPass(cpg: Cpg, hash: String, inputPath: String) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = { - val absolutePathToRoot = File(inputPath).path.toAbsolutePath.toString - val metaNode = NewMetaData().language(Languages.JSSRC).root(absolutePathToRoot).hash(hash).version("0.1") - diffGraph.addNode(metaNode) - } - -} + override def run(diffGraph: DiffGraphBuilder): Unit = + val absolutePathToRoot = File(inputPath).path.toAbsolutePath.toString + val metaNode = NewMetaData().language(Languages.JSSRC).root(absolutePathToRoot).hash( + hash + ).version("0.1") + diffGraph.addNode(metaNode) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala index 4aeda6a7..1e2a355b 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala @@ -9,19 +9,17 @@ import io.shiftleft.utils.IOUtils import scala.util.matching.Regex class 
PrivateKeyFilePass(cpg: Cpg, config: Config, report: Report = new Report()) - extends ConfigPass(cpg, config, report) { + extends ConfigPass(cpg, config, report): - private val PrivateKeyRegex: Regex = """.*RSA\sPRIVATE\sKEY.*""".r + private val PrivateKeyRegex: Regex = """.*RSA\sPRIVATE\sKEY.*""".r - override val allExtensions: Set[String] = Set(".key") - override val selectedExtensions: Set[String] = Set(".key") + override val allExtensions: Set[String] = Set(".key") + override val selectedExtensions: Set[String] = Set(".key") - override def fileContent(file: File): Seq[String] = - Seq("Content omitted for security reasons.") + override def fileContent(file: File): Seq[String] = + Seq("Content omitted for security reasons.") - override def generateParts(): Array[File] = - configFiles(config, selectedExtensions).toArray.filter(p => - IOUtils.readLinesInFile(p.path).exists(PrivateKeyRegex.matches) - ) - -} + override def generateParts(): Array[File] = + configFiles(config, selectedExtensions).toArray.filter(p => + IOUtils.readLinesInFile(p.path).exists(PrivateKeyRegex.matches) + ) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala index 37d095ba..d761278c 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala @@ -4,15 +4,13 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.NewType import io.shiftleft.passes.CpgPass -class TypeNodePass(usedTypes: List[(String, String)], cpg: Cpg) extends CpgPass(cpg, "types") { - override def run(diffGraph: DiffGraphBuilder): Unit = { - val filteredTypes = usedTypes.filterNot { case (name, _) => - name == Defines.Any || Defines.JsTypes.contains(name) - } +class TypeNodePass(usedTypes: 
List[(String, String)], cpg: Cpg) extends CpgPass(cpg, "types"): + override def run(diffGraph: DiffGraphBuilder): Unit = + val filteredTypes = usedTypes.filterNot { case (name, _) => + name == Defines.Any || Defines.JsTypes.contains(name) + } - filteredTypes.sortBy(_._2).foreach { case (name, fullName) => - val typeNode = NewType().name(name).fullName(fullName).typeDeclFullName(fullName) - diffGraph.addNode(typeNode) - } - } -} + filteredTypes.sortBy(_._2).foreach { case (name, fullName) => + val typeNode = NewType().name(name).fullName(fullName).typeDeclFullName(fullName) + diffGraph.addNode(typeNode) + } diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala index 11edeaac..a84bd7c7 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala @@ -2,76 +2,83 @@ package io.appthreat.jssrc2cpg.preprocessing import scala.collection.mutable -class EjsPreprocessor { +class EjsPreprocessor: - private val CommentTag = "<%#" - private val TagGroupsRegex = """(<%[=\-_#]?)([\s\S]*?)([-_#]?%>)""".r - private val ScriptGroupsRegex = """()""".r - private val Tags = List("<%#", "<%=", "<%-", "<%_", "-%>", "_%>", "#%>", "%>") + private val CommentTag = "<%#" + private val TagGroupsRegex = """(<%[=\-_#]?)([\s\S]*?)([-_#]?%>)""".r + private val ScriptGroupsRegex = """()""".r + private val Tags = List("<%#", "<%=", "<%-", "<%_", "-%>", "_%>", "#%>", "%>") - private def stripScriptTag(code: String): String = { - var x = code.replaceAll("", "%> ") - ScriptGroupsRegex.findAllIn(code).matchData.foreach { ma => - var scriptBlock = ma.group(2) - val matches = TagGroupsRegex.findAllIn(scriptBlock).matchData.toList - matches.foreach { - case mat if mat.group(1) == 
"<%" && mat.group(3) == "-%>" => - scriptBlock = scriptBlock.replace(mat.toString(), " " * mat.toString().replaceAll("[^\\s]", " ").length) - case _ => - } - Tags.foreach { tag => - scriptBlock = scriptBlock.replaceAll(tag, " " * tag.length) - } - x = x.replace(ma.group(2), scriptBlock) - } - x - } - - private def needsSemicolon(code: String): Boolean = - !code.trim.endsWith("{") && !code.trim.endsWith("}") && !code.trim.endsWith(";") + private def stripScriptTag(code: String): String = + var x = code.replaceAll("", "%> ") + ScriptGroupsRegex.findAllIn(code).matchData.foreach { ma => + var scriptBlock = ma.group(2) + val matches = TagGroupsRegex.findAllIn(scriptBlock).matchData.toList + matches.foreach { + case mat if mat.group(1) == "<%" && mat.group(3) == "-%>" => + scriptBlock = scriptBlock.replace( + mat.toString(), + " " * mat.toString().replaceAll("[^\\s]", " ").length + ) + case _ => + } + Tags.foreach { tag => + scriptBlock = scriptBlock.replaceAll(tag, " " * tag.length) + } + x = x.replace(ma.group(2), scriptBlock) + } + x + end stripScriptTag - def preprocess(code: String): String = { - val codeWithoutScriptTag = stripScriptTag(code) - val codeAsCharArray = codeWithoutScriptTag.toCharArray - val preprocessedCode = new mutable.StringBuilder(codeAsCharArray.length) - val matches = TagGroupsRegex.findAllIn(codeWithoutScriptTag).matchData.toList + private def needsSemicolon(code: String): Boolean = + !code.trim.endsWith("{") && !code.trim.endsWith("}") && !code.trim.endsWith(";") - val positions = matches.flatMap { - case ma if ma.group(1) == CommentTag => None // ignore comments - case ma if ma.group(2).trim.startsWith("include ") => None // ignore including other ejs templates - case ma => - val start = ma.start + ma.group(1).length - val end = ma.end - ma.group(3).length - Option((start, end)) - } + def preprocess(code: String): String = + val codeWithoutScriptTag = stripScriptTag(code) + val codeAsCharArray = codeWithoutScriptTag.toCharArray + val 
preprocessedCode = new mutable.StringBuilder(codeAsCharArray.length) + val matches = TagGroupsRegex.findAllIn(codeWithoutScriptTag).matchData.toList - codeAsCharArray.zipWithIndex.foreach { - case (currChar, _) if currChar == '\n' || currChar == '\r' => - preprocessedCode.append(currChar) - case (currChar, index) if positions.exists { case (start, end) => index >= start && index < end } => - preprocessedCode.append(currChar) - case _ => - preprocessedCode.append(" ") - } + val positions = matches.flatMap { + case ma if ma.group(1) == CommentTag => None // ignore comments + case ma if ma.group(2).trim.startsWith("include ") => + None // ignore including other ejs templates + case ma => + val start = ma.start + ma.group(1).length + val end = ma.end - ma.group(3).length + Option((start, end)) + } - var codeWithoutSemicolon = preprocessedCode.toString() - val alreadyReplaced = mutable.ArrayBuffer.empty[(Int, Int)] - matches.foreach { - case ma if ma.group(1) == CommentTag => // ignore comments - case ma if ma.group(2).trim.startsWith("include ") => // ignore including other ejs templates - case ma if needsSemicolon(ma.group(2)) => - val start = ma.start + ma.group(1).length - val end = ma.end - ma.group(3).length - if (!alreadyReplaced.contains((start, end))) { - val replacementCode = s"${ma.group(2)};" - codeWithoutSemicolon = - s"${codeWithoutSemicolon.substring(0, start)}$replacementCode${codeWithoutSemicolon.substring(end + 1, codeWithoutSemicolon.length)}" - alreadyReplaced.append((start, end)) + codeAsCharArray.zipWithIndex.foreach { + case (currChar, _) if currChar == '\n' || currChar == '\r' => + preprocessedCode.append(currChar) + case (currChar, index) if positions.exists { case (start, end) => + index >= start && index < end + } => + preprocessedCode.append(currChar) + case _ => + preprocessedCode.append(" ") } - case _ => // others are fine already - } - codeWithoutSemicolon - } + var codeWithoutSemicolon = preprocessedCode.toString() + val alreadyReplaced 
= mutable.ArrayBuffer.empty[(Int, Int)] + matches.foreach { + case ma if ma.group(1) == CommentTag => // ignore comments + case ma + if ma.group(2).trim.startsWith( + "include " + ) => // ignore including other ejs templates + case ma if needsSemicolon(ma.group(2)) => + val start = ma.start + ma.group(1).length + val end = ma.end - ma.group(3).length + if !alreadyReplaced.contains((start, end)) then + val replacementCode = s"${ma.group(2)};" + codeWithoutSemicolon = + s"${codeWithoutSemicolon.substring(0, start)}$replacementCode${codeWithoutSemicolon.substring(end + 1, codeWithoutSemicolon.length)}" + alreadyReplaced.append((start, end)) + case _ => // others are fine already + } -} + codeWithoutSemicolon + end preprocess +end EjsPreprocessor diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala index b91cfd30..41246597 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala @@ -16,280 +16,271 @@ import scala.util.Success import scala.util.matching.Regex import scala.util.Try -object AstGenRunner { +object AstGenRunner: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - private val LineLengthThreshold: Int = 10000 + private val LineLengthThreshold: Int = 10000 - private val TypeDefinitionFileExtensions = List(".t.ts", ".d.ts") + private val TypeDefinitionFileExtensions = List(".t.ts", ".d.ts") - private val MinifiedPathRegex: Regex = ".*([.-]min\\..*js|bundle\\.js)".r + private val MinifiedPathRegex: Regex = ".*([.-]min\\..*js|bundle\\.js)".r - private val IgnoredTestsRegex: Seq[Regex] = - List( - ".*[.-]spec\\.js".r, - ".*[.-]mock\\.js".r, - ".*[.-]e2e\\.js".r, - ".*[.-]test\\.js".r, - ".*cypress\\.json".r, - 
".*test.*\\.json".r + private val IgnoredTestsRegex: Seq[Regex] = + List( + ".*[.-]spec\\.js".r, + ".*[.-]mock\\.js".r, + ".*[.-]e2e\\.js".r, + ".*[.-]test\\.js".r, + ".*cypress\\.json".r, + ".*test.*\\.json".r + ) + + private val IgnoredFilesRegex: Seq[Regex] = List( + ".*jest\\.config.*".r, + ".*webpack\\..*\\.js".r, + ".*vue\\.config\\.js".r, + ".*babel\\.config\\.js".r, + ".*chunk-vendors.*\\.js".r, // commonly found in webpack / vue.js projects + ".*app~.*\\.js".r, // commonly found in webpack / vue.js projects + ".*\\.chunk\\.js".r, + ".*\\.babelrc.*".r, + ".*\\.eslint.*".r, + ".*\\.tslint.*".r, + ".*\\.stylelintrc\\.js".r, + ".*rollup\\.config.*".r, + ".*\\.types\\.js".r, + ".*\\.cjs\\.js".r, + ".*eslint-local-rules\\.js".r, + ".*\\.devcontainer\\.json".r, + ".*Gruntfile\\.js".r, + ".*i18n.*\\.json".r ) - private val IgnoredFilesRegex: Seq[Regex] = List( - ".*jest\\.config.*".r, - ".*webpack\\..*\\.js".r, - ".*vue\\.config\\.js".r, - ".*babel\\.config\\.js".r, - ".*chunk-vendors.*\\.js".r, // commonly found in webpack / vue.js projects - ".*app~.*\\.js".r, // commonly found in webpack / vue.js projects - ".*\\.chunk\\.js".r, - ".*\\.babelrc.*".r, - ".*\\.eslint.*".r, - ".*\\.tslint.*".r, - ".*\\.stylelintrc\\.js".r, - ".*rollup\\.config.*".r, - ".*\\.types\\.js".r, - ".*\\.cjs\\.js".r, - ".*eslint-local-rules\\.js".r, - ".*\\.devcontainer\\.json".r, - ".*Gruntfile\\.js".r, - ".*i18n.*\\.json".r - ) - - case class AstGenRunnerResult( - parsedFiles: List[(String, String)] = List.empty, - skippedFiles: List[(String, String)] = List.empty - ) - - lazy private val executableName = Environment.operatingSystem match { - case Environment.OperatingSystemType.Windows => "astgen-win.exe" - case Environment.OperatingSystemType.Linux => "astgen-linux" - case Environment.OperatingSystemType.Mac => - Environment.architecture match { - case Environment.ArchitectureType.X86 => "astgen-macos" - case Environment.ArchitectureType.ARM => "astgen-macos-arm" - } - case 
Environment.OperatingSystemType.Unknown => - logger.warn("Could not detect OS version! Defaulting to 'Linux'.") - "astgen-linux" - } - - lazy private val executableDir: String = { - val dir = getClass.getProtectionDomain.getCodeSource.getLocation.toString - val indexOfLib = dir.lastIndexOf("lib") - val fixedDir = if (indexOfLib != -1) { - new java.io.File(dir.substring("file:".length, indexOfLib)).toString - } else { - val indexOfTarget = dir.lastIndexOf("target") - if (indexOfTarget != -1) { - new java.io.File(dir.substring("file:".length, indexOfTarget)).toString - } else { - "." - } - } - Paths.get(fixedDir, "/bin/astgen").toAbsolutePath.toString - } - - private def hasCompatibleAstGenVersion(astGenVersion: String): Boolean = { - ExternalCommand.run("astgen --version", ".").toOption.map(_.mkString.strip()) match { - case Some(installedVersion) - if installedVersion != "unknown" && - Try(VersionHelper.compare(installedVersion, astGenVersion)).toOption.getOrElse(-1) >= 0 => - logger.debug(s"Using local astgen v$installedVersion from systems PATH") - true - case Some(installedVersion) => - logger.debug( - s"Found local astgen v$installedVersion in systems PATH but jssrc2cpg requires at least v$astGenVersion" - ) - false - case _ => false - } - } - - private lazy val astGenCommand = { - val conf = ConfigFactory.load - val astGenVersion = conf.getString("jssrc2cpg.astgen_version") - if (hasCompatibleAstGenVersion(astGenVersion)) { - "astgen" - } else { - s"$executableDir/$executableName" - } - } -} - -class AstGenRunner(config: Config) { - - import AstGenRunner._ - - private val executableArgs = if (!config.tsTypes) " --no-tsTypes" else "" - - private def skippedFiles(in: File, astGenOut: List[String]): List[String] = { - val skipped = astGenOut.collect { - case out if !out.startsWith("Converted") && !out.startsWith("Retrieving") => - val filename = out.substring(0, out.indexOf(" ")) - val reason = out.substring(out.indexOf(" ") + 1) - logger.warn(s"\t- failed to 
parse '${in / filename}': '$reason'") - Option(filename) - case out => - logger.debug(s"\t+ $out") - None - } - skipped.flatten - } - - private def isIgnoredByUserConfig(filePath: String): Boolean = { - lazy val isInIgnoredFiles = config.ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath - } - lazy val isInIgnoredFileRegex = config.ignoredFilesRegex.matches(filePath) - if (isInIgnoredFiles || isInIgnoredFileRegex) { - logger.debug(s"'$filePath' ignored by user configuration") - true - } else { - false - } - } - - private def isMinifiedFile(filePath: String): Boolean = filePath match { - case p if MinifiedPathRegex.matches(p) => true - case p if File(p).exists && p.endsWith(".js") => - val lines = IOUtils.readLinesInFile(File(filePath).path) - val linesOfCode = lines.size - val longestLineLength = if (lines.isEmpty) 0 else lines.map(_.length).max - if (longestLineLength >= LineLengthThreshold && linesOfCode <= 50) { - logger.debug(s"'$filePath' seems to be a minified file (contains a line with length $longestLineLength)") - true - } else false - case _ => false - } - - private def isIgnoredByDefault(filePath: String): Boolean = { - lazy val isIgnored = IgnoredFilesRegex.exists(_.matches(filePath)) - lazy val isIgnoredTest = IgnoredTestsRegex.exists(_.matches(filePath)) - lazy val isMinified = isMinifiedFile(filePath) - if (isIgnored || isIgnoredTest || isMinified) { - logger.debug(s"'$filePath' ignored by default") - true - } else { - false - } - } - - private def isTranspiledFile(filePath: String): Boolean = { - val file = File(filePath) - // We ignore files iff: - // - they are *.js files and - // - they contain a //sourceMappingURL comment or have an associated source map file and - // - a file with the same name is located directly next to them - lazy val isJsFile = file.exists && file.extension.contains(".js") - lazy val hasSourceMapComment = 
IOUtils.readLinesInFile(file.path).exists(_.contains("//sourceMappingURL")) - lazy val hasSourceMapFile = File(s"$filePath.map").exists - lazy val hasSourceMap = hasSourceMapComment || hasSourceMapFile - lazy val hasFileWithSameName = - file.siblings.exists(_.nameWithoutExtension(includeAll = false) == file.nameWithoutExtension) - if (isJsFile && hasSourceMap && hasFileWithSameName) { - logger.debug(s"'$filePath' ignored by default (seems to be the result of transpilation)") - true - } else { - false - } - - } - - private def filterFiles(files: List[String], out: File): List[String] = { - files.filter { file => - file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match { - // We are not interested in JS / TS type definition files at this stage. - // TODO: maybe we can enable that later on and use the type definitions there - // for enhancing the CPG with additional type information for functions - case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false - case filePath if isIgnoredByUserConfig(filePath) => false - case filePath if isIgnoredByDefault(filePath) => false - case filePath if isTranspiledFile(filePath) => false - case _ => true - } - } - } - - /** Changes the file-extension by renaming this file; if file does not have an extension, it adds the extension. If - * file does not exist (or is a directory) no change is done and the current file is returned. 
- */ - private def changeExtensionTo(file: File, extension: String): File = { - val newName = s"${file.nameWithoutExtension(includeAll = false)}.${extension.stripPrefix(".")}" - if (file.isRegularFile) file.renameTo(newName) else if (file.notExists) File(newName) else file - } - - private def processEjsFiles(in: File, out: File, ejsFiles: List[String]): Try[Seq[String]] = { - val tmpJsFiles = ejsFiles.map { ejsFilePath => - val ejsFile = File(ejsFilePath) - val maybeTranspiledFile = File(s"${ejsFilePath.stripSuffix(".ejs")}.js") - if (isTranspiledFile(maybeTranspiledFile.pathAsString)) { - maybeTranspiledFile - } else { - val sourceFileContent = IOUtils.readEntireFile(ejsFile.path) - val preprocessContent = new EjsPreprocessor().preprocess(sourceFileContent) - (out / in.relativize(ejsFile).toString).parent.createDirectoryIfNotExists(createParents = true) - val newEjsFile = ejsFile.copyTo(out / in.relativize(ejsFile).toString) - val jsFile = changeExtensionTo(newEjsFile, ".js").writeText(preprocessContent) - newEjsFile.createFile().writeText(sourceFileContent) - jsFile - } - } - - val result = ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", out.toString()) - - val jsons = SourceFiles.determine(out.toString(), Set(".json")) - jsons.foreach { jsonPath => - val jsonFile = File(jsonPath) - val jsonContent = IOUtils.readEntireFile(jsonFile.path) - val json = ujson.read(jsonContent) - val fileName = json("fullName").str - val newFileName = fileName.patch(fileName.lastIndexOf(".js"), ".ejs", 3) - json("relativeName") = newFileName - json("fullName") = newFileName - jsonFile.writeText(json.toString()) - } - - tmpJsFiles.foreach(_.delete()) - result - } - - private def ejsFiles(in: File, out: File): Try[Seq[String]] = { - val files = SourceFiles.determine(in.pathAsString, Set(".ejs")) - if (files.nonEmpty) processEjsFiles(in, out, files) - else Success(Seq.empty) - } - - private def vueFiles(in: File, out: File): Try[Seq[String]] = { - val files = 
SourceFiles.determine(in.pathAsString, Set(".vue")) - if (files.nonEmpty) - ExternalCommand.run(s"$astGenCommand$executableArgs -t vue -o $out", in.toString()) - else Success(Seq.empty) - } - - private def jsFiles(in: File, out: File): Try[Seq[String]] = - ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", in.toString()) - - private def runAstGenNative(in: File, out: File): Try[Seq[String]] = for { - ejsResult <- ejsFiles(in, out) - vueResult <- vueFiles(in, out) - jsResult <- jsFiles(in, out) - } yield jsResult ++ vueResult ++ ejsResult - - def execute(out: File): AstGenRunnerResult = { - val in = File(config.inputPath) - logger.debug(s"Running astgen in '$in' ...") - runAstGenNative(in, out) match { - case Success(result) => - val parsed = filterFiles(SourceFiles.determine(out.toString(), Set(".json")), out) - val skipped = skippedFiles(in, result.toList) - AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) - case Failure(f) => - logger.debug("\t- running astgen failed!", f) - AstGenRunnerResult() - } - } - -} + case class AstGenRunnerResult( + parsedFiles: List[(String, String)] = List.empty, + skippedFiles: List[(String, String)] = List.empty + ) + + lazy private val executableName = Environment.operatingSystem match + case Environment.OperatingSystemType.Windows => "astgen-win.exe" + case Environment.OperatingSystemType.Linux => "astgen-linux" + case Environment.OperatingSystemType.Mac => + Environment.architecture match + case Environment.ArchitectureType.X86 => "astgen-macos" + case Environment.ArchitectureType.ARM => "astgen-macos-arm" + case Environment.OperatingSystemType.Unknown => + logger.warn("Could not detect OS version! 
Defaulting to 'Linux'.") + "astgen-linux" + + lazy private val executableDir: String = + val dir = getClass.getProtectionDomain.getCodeSource.getLocation.toString + val indexOfLib = dir.lastIndexOf("lib") + val fixedDir = if indexOfLib != -1 then + new java.io.File(dir.substring("file:".length, indexOfLib)).toString + else + val indexOfTarget = dir.lastIndexOf("target") + if indexOfTarget != -1 then + new java.io.File(dir.substring("file:".length, indexOfTarget)).toString + else + "." + Paths.get(fixedDir, "/bin/astgen").toAbsolutePath.toString + + private def hasCompatibleAstGenVersion(astGenVersion: String): Boolean = + ExternalCommand.run("astgen --version", ".").toOption.map(_.mkString.strip()) match + case Some(installedVersion) + if installedVersion != "unknown" && + Try(VersionHelper.compare(installedVersion, astGenVersion)).toOption.getOrElse( + -1 + ) >= 0 => + logger.debug(s"Using local astgen v$installedVersion from systems PATH") + true + case Some(installedVersion) => + logger.debug( + s"Found local astgen v$installedVersion in systems PATH but jssrc2cpg requires at least v$astGenVersion" + ) + false + case _ => false + + private lazy val astGenCommand = + val conf = ConfigFactory.load + val astGenVersion = conf.getString("jssrc2cpg.astgen_version") + if hasCompatibleAstGenVersion(astGenVersion) then + "astgen" + else + s"$executableDir/$executableName" +end AstGenRunner + +class AstGenRunner(config: Config): + + import AstGenRunner.* + + private val executableArgs = if !config.tsTypes then " --no-tsTypes" else "" + + private def skippedFiles(in: File, astGenOut: List[String]): List[String] = + val skipped = astGenOut.collect { + case out if !out.startsWith("Converted") && !out.startsWith("Retrieving") => + val filename = out.substring(0, out.indexOf(" ")) + val reason = out.substring(out.indexOf(" ") + 1) + logger.warn(s"\t- failed to parse '${in / filename}': '$reason'") + Option(filename) + case out => + logger.debug(s"\t+ $out") + None + } + 
skipped.flatten + + private def isIgnoredByUserConfig(filePath: String): Boolean = + lazy val isInIgnoredFiles = config.ignoredFiles.exists { + case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) + case ignorePath => filePath == ignorePath + } + lazy val isInIgnoredFileRegex = config.ignoredFilesRegex.matches(filePath) + if isInIgnoredFiles || isInIgnoredFileRegex then + logger.debug(s"'$filePath' ignored by user configuration") + true + else + false + + private def isMinifiedFile(filePath: String): Boolean = filePath match + case p if MinifiedPathRegex.matches(p) => true + case p if File(p).exists && p.endsWith(".js") => + val lines = IOUtils.readLinesInFile(File(filePath).path) + val linesOfCode = lines.size + val longestLineLength = if lines.isEmpty then 0 else lines.map(_.length).max + if longestLineLength >= LineLengthThreshold && linesOfCode <= 50 then + logger.debug( + s"'$filePath' seems to be a minified file (contains a line with length $longestLineLength)" + ) + true + else false + case _ => false + + private def isIgnoredByDefault(filePath: String): Boolean = + lazy val isIgnored = IgnoredFilesRegex.exists(_.matches(filePath)) + lazy val isIgnoredTest = IgnoredTestsRegex.exists(_.matches(filePath)) + lazy val isMinified = isMinifiedFile(filePath) + if isIgnored || isIgnoredTest || isMinified then + logger.debug(s"'$filePath' ignored by default") + true + else + false + + private def isTranspiledFile(filePath: String): Boolean = + val file = File(filePath) + // We ignore files iff: + // - they are *.js files and + // - they contain a //sourceMappingURL comment or have an associated source map file and + // - a file with the same name is located directly next to them + lazy val isJsFile = file.exists && file.extension.contains(".js") + lazy val hasSourceMapComment = + IOUtils.readLinesInFile(file.path).exists(_.contains("//sourceMappingURL")) + lazy val hasSourceMapFile = File(s"$filePath.map").exists + lazy val hasSourceMap 
= hasSourceMapComment || hasSourceMapFile + lazy val hasFileWithSameName = + file.siblings.exists(_.nameWithoutExtension(includeAll = + false + ) == file.nameWithoutExtension) + if isJsFile && hasSourceMap && hasFileWithSameName then + logger.debug( + s"'$filePath' ignored by default (seems to be the result of transpilation)" + ) + true + else + false + end isTranspiledFile + + private def filterFiles(files: List[String], out: File): List[String] = + files.filter { file => + file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match + // We are not interested in JS / TS type definition files at this stage. + // TODO: maybe we can enable that later on and use the type definitions there + // for enhancing the CPG with additional type information for functions + case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false + case filePath if isIgnoredByUserConfig(filePath) => false + case filePath if isIgnoredByDefault(filePath) => false + case filePath if isTranspiledFile(filePath) => false + case _ => true + } + + /** Changes the file-extension by renaming this file; if file does not have an extension, it + * adds the extension. If file does not exist (or is a directory) no change is done and the + * current file is returned. 
+ */ + private def changeExtensionTo(file: File, extension: String): File = + val newName = + s"${file.nameWithoutExtension(includeAll = false)}.${extension.stripPrefix(".")}" + if file.isRegularFile then file.renameTo(newName) + else if file.notExists then File(newName) + else file + + private def processEjsFiles(in: File, out: File, ejsFiles: List[String]): Try[Seq[String]] = + val tmpJsFiles = ejsFiles.map { ejsFilePath => + val ejsFile = File(ejsFilePath) + val maybeTranspiledFile = File(s"${ejsFilePath.stripSuffix(".ejs")}.js") + if isTranspiledFile(maybeTranspiledFile.pathAsString) then + maybeTranspiledFile + else + val sourceFileContent = IOUtils.readEntireFile(ejsFile.path) + val preprocessContent = new EjsPreprocessor().preprocess(sourceFileContent) + (out / in.relativize(ejsFile).toString).parent.createDirectoryIfNotExists( + createParents = true + ) + val newEjsFile = ejsFile.copyTo(out / in.relativize(ejsFile).toString) + val jsFile = changeExtensionTo(newEjsFile, ".js").writeText(preprocessContent) + newEjsFile.createFile().writeText(sourceFileContent) + jsFile + } + + val result = + ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", out.toString()) + + val jsons = SourceFiles.determine(out.toString(), Set(".json")) + jsons.foreach { jsonPath => + val jsonFile = File(jsonPath) + val jsonContent = IOUtils.readEntireFile(jsonFile.path) + val json = ujson.read(jsonContent) + val fileName = json("fullName").str + val newFileName = fileName.patch(fileName.lastIndexOf(".js"), ".ejs", 3) + json("relativeName") = newFileName + json("fullName") = newFileName + jsonFile.writeText(json.toString()) + } + + tmpJsFiles.foreach(_.delete()) + result + end processEjsFiles + + private def ejsFiles(in: File, out: File): Try[Seq[String]] = + val files = SourceFiles.determine(in.pathAsString, Set(".ejs")) + if files.nonEmpty then processEjsFiles(in, out, files) + else Success(Seq.empty) + + private def vueFiles(in: File, out: File): Try[Seq[String]] = + 
val files = SourceFiles.determine(in.pathAsString, Set(".vue")) + if files.nonEmpty then + ExternalCommand.run(s"$astGenCommand$executableArgs -t vue -o $out", in.toString()) + else Success(Seq.empty) + + private def jsFiles(in: File, out: File): Try[Seq[String]] = + ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", in.toString()) + + private def runAstGenNative(in: File, out: File): Try[Seq[String]] = + for + ejsResult <- ejsFiles(in, out) + vueResult <- vueFiles(in, out) + jsResult <- jsFiles(in, out) + yield jsResult ++ vueResult ++ ejsResult + + def execute(out: File): AstGenRunnerResult = + val in = File(config.inputPath) + logger.debug(s"Running astgen in '$in' ...") + runAstGenNative(in, out) match + case Success(result) => + val parsed = filterFiles(SourceFiles.determine(out.toString(), Set(".json")), out) + val skipped = skippedFiles(in, result.toList) + AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) + case Failure(f) => + logger.debug("\t- running astgen failed!", f) + AstGenRunnerResult() +end AstGenRunner diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala index 632524b2..c6d83559 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala @@ -8,93 +8,86 @@ import org.apache.commons.lang.StringUtils import scala.collection.concurrent.TrieMap import scala.util.Try -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.Failure import scala.util.Success -object PackageJsonParser { - private val logger = LoggerFactory.getLogger(PackageJsonParser.getClass) +object PackageJsonParser: + private val logger = LoggerFactory.getLogger(PackageJsonParser.getClass) 
- val PackageJsonFilename = "package.json" - val PackageJsonLockFilename = "package-lock.json" + val PackageJsonFilename = "package.json" + val PackageJsonLockFilename = "package-lock.json" - private val ProjectDependencies = - Seq("dependencies", "devDependencies", "peerDependencies", "optionalDependencies") + private val ProjectDependencies = + Seq("dependencies", "devDependencies", "peerDependencies", "optionalDependencies") - private val cachedDependencies: TrieMap[Path, Map[String, String]] = TrieMap.empty + private val cachedDependencies: TrieMap[Path, Map[String, String]] = TrieMap.empty - def isValidProjectPackageJson(packageJsonPath: Path): Boolean = { - if (packageJsonPath.toString.endsWith(PackageJsonParser.PackageJsonFilename)) { - val isNotEmpty = Try(IOUtils.readLinesInFile(packageJsonPath)) match { - case Success(content) => - content.forall(l => StringUtils.isNotBlank(StringUtils.normalizeSpace(l))) - case Failure(_) => false - } - isNotEmpty && dependencies(packageJsonPath).nonEmpty - } else { - false - } - } + def isValidProjectPackageJson(packageJsonPath: Path): Boolean = + if packageJsonPath.toString.endsWith(PackageJsonParser.PackageJsonFilename) then + val isNotEmpty = Try(IOUtils.readLinesInFile(packageJsonPath)) match + case Success(content) => + content.forall(l => StringUtils.isNotBlank(StringUtils.normalizeSpace(l))) + case Failure(_) => false + isNotEmpty && dependencies(packageJsonPath).nonEmpty + else + false - def dependencies(packageJsonPath: Path): Map[String, String] = - cachedDependencies.getOrElseUpdate( - packageJsonPath, { - val depsPath = packageJsonPath - val lockDepsPath = packageJsonPath.resolveSibling(Paths.get(PackageJsonLockFilename)) + def dependencies(packageJsonPath: Path): Map[String, String] = + cachedDependencies.getOrElseUpdate( + packageJsonPath, { + val depsPath = packageJsonPath + val lockDepsPath = packageJsonPath.resolveSibling(Paths.get(PackageJsonLockFilename)) - val lockDeps = Try { - val content = 
IOUtils.readEntireFile(lockDepsPath) - val objectMapper = new ObjectMapper - val packageJson = objectMapper.readTree(content) + val lockDeps = Try { + val content = IOUtils.readEntireFile(lockDepsPath) + val objectMapper = new ObjectMapper + val packageJson = objectMapper.readTree(content) - var depToVersion = Map.empty[String, String] - val dependencyIt = Option(packageJson.get("dependencies")) - .map(_.fields().asScala) - .getOrElse(Iterator.empty) - dependencyIt.foreach { entry => - val depName = entry.getKey - val versionNode = entry.getValue.get("version") - if (versionNode != null) { - depToVersion = depToVersion.updated(depName, versionNode.asText()) - } - } - depToVersion - }.toOption + var depToVersion = Map.empty[String, String] + val dependencyIt = Option(packageJson.get("dependencies")) + .map(_.fields().asScala) + .getOrElse(Iterator.empty) + dependencyIt.foreach { entry => + val depName = entry.getKey + val versionNode = entry.getValue.get("version") + if versionNode != null then + depToVersion = depToVersion.updated(depName, versionNode.asText()) + } + depToVersion + }.toOption - // lazy val because we only evaluate this in case no package lock file is available. - lazy val deps = Try { - val content = IOUtils.readEntireFile(depsPath) - val objectMapper = new ObjectMapper - val packageJson = objectMapper.readTree(content) + // lazy val because we only evaluate this in case no package lock file is available. 
+ lazy val deps = Try { + val content = IOUtils.readEntireFile(depsPath) + val objectMapper = new ObjectMapper + val packageJson = objectMapper.readTree(content) - var depToVersion = Map.empty[String, String] - ProjectDependencies - .foreach { dependency => - val dependencyIt = Option(packageJson.get(dependency)) - .map(_.fields().asScala) - .getOrElse(Iterator.empty) - dependencyIt.foreach { entry => - depToVersion = depToVersion.updated(entry.getKey, entry.getValue.asText()) - } - } - depToVersion - }.toOption + var depToVersion = Map.empty[String, String] + ProjectDependencies + .foreach { dependency => + val dependencyIt = Option(packageJson.get(dependency)) + .map(_.fields().asScala) + .getOrElse(Iterator.empty) + dependencyIt.foreach { entry => + depToVersion = + depToVersion.updated(entry.getKey, entry.getValue.asText()) + } + } + depToVersion + }.toOption - if (lockDeps.isDefined && lockDeps.get.nonEmpty) { - logger.debug(s"Loaded dependencies from '$lockDepsPath'.") - lockDeps.get - } else { - if (deps.isDefined && deps.get.nonEmpty) { - logger.debug(s"Loaded dependencies from '$depsPath'.") - deps.get - } else { - logger.debug( - s"No project dependencies found in $PackageJsonFilename or $PackageJsonLockFilename at '${depsPath.getParent}'." - ) - Map.empty + if lockDeps.isDefined && lockDeps.get.nonEmpty then + logger.debug(s"Loaded dependencies from '$lockDepsPath'.") + lockDeps.get + else if deps.isDefined && deps.get.nonEmpty then + logger.debug(s"Loaded dependencies from '$depsPath'.") + deps.get + else + logger.debug( + s"No project dependencies found in $PackageJsonFilename or $PackageJsonLockFilename at '${depsPath.getParent}'." 
+ ) + Map.empty } - } - } - ) - -} + ) +end PackageJsonParser diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala index d97e9654..ef3000c9 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala @@ -1,9 +1,7 @@ package io.appthreat.pysrc2cpg -class AutoIncIndex(private var index: Int) { - def getAndInc: Int = { - val ret = index - index += 1 - ret - } -} +class AutoIncIndex(private var index: Int): + def getAndInc: Int = + val ret = index + index += 1 + ret diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala index 606130d4..b5d4ebad 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala @@ -7,32 +7,33 @@ import io.appthreat.pythonparser.PyParser import io.appthreat.x2cpg.ValidationMode import org.slf4j.LoggerFactory -class CodeToCpg(cpg: Cpg, inputProvider: Iterable[InputProvider], schemaValidationMode: ValidationMode) - extends ConcurrentWriterCpgPass[InputProvider](cpg) { - import CodeToCpg.logger +class CodeToCpg( + cpg: Cpg, + inputProvider: Iterable[InputProvider], + schemaValidationMode: ValidationMode +) extends ConcurrentWriterCpgPass[InputProvider](cpg): + import CodeToCpg.logger - override def generateParts(): Array[InputProvider] = inputProvider.toArray + override def generateParts(): Array[InputProvider] = inputProvider.toArray - override def runOnPart(diffGraph: DiffGraphBuilder, inputProvider: InputProvider): Unit = { - val inputPair = inputProvider() - try { - val parser = new PyParser() - val lineBreakCorrectedCode = 
inputPair.content.replace("\r\n", "\n").replace("\r", "\n") - val astRoot = parser.parse(lineBreakCorrectedCode) - val nodeToCode = new NodeToCode(lineBreakCorrectedCode) - val astVisitor = new PythonAstVisitor(inputPair.relFileName, nodeToCode, PythonV2AndV3)(schemaValidationMode) - astVisitor.convert(astRoot) + override def runOnPart(diffGraph: DiffGraphBuilder, inputProvider: InputProvider): Unit = + val inputPair = inputProvider() + try + val parser = new PyParser() + val lineBreakCorrectedCode = inputPair.content.replace("\r\n", "\n").replace("\r", "\n") + val astRoot = parser.parse(lineBreakCorrectedCode) + val nodeToCode = new NodeToCode(lineBreakCorrectedCode) + val astVisitor = new PythonAstVisitor(inputPair.relFileName, nodeToCode, PythonV2AndV3)( + schemaValidationMode + ) + astVisitor.convert(astRoot) - diffGraph.absorb(astVisitor.getDiffGraph) - } catch { - case exception: Throwable => - logger.debug(s"Failed to convert file ${inputPair.relFileName}", exception) - Iterator.empty - } - } + diffGraph.absorb(astVisitor.getDiffGraph) + catch + case exception: Throwable => + logger.debug(s"Failed to convert file ${inputPair.relFileName}", exception) + Iterator.empty +end CodeToCpg -} - -object CodeToCpg { - private val logger = LoggerFactory.getLogger(getClass) -} +object CodeToCpg: + private val logger = LoggerFactory.getLogger(getClass) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala index edddbbe3..8d2aec2c 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala @@ -5,22 +5,20 @@ import io.appthreat.x2cpg.passes.frontend.XConfigFileCreationPass import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg, requirementsTxt: 
String = "requirements.txt") - extends XConfigFileCreationPass(cpg) { + extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // TOML files - extensionFilter(".toml"), - // INI files - extensionFilter(".ini"), - // YAML files - extensionFilter(".yaml"), - extensionFilter(".lock"), - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json"), - pathEndFilter("setup.cfg"), - // Requirements.txt - pathEndFilter(requirementsTxt) - ) - -} + override val configFileFilters: List[File => Boolean] = List( + // TOML files + extensionFilter(".toml"), + // INI files + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".lock"), + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json"), + pathEndFilter("setup.cfg"), + // Requirements.txt + pathEndFilter(requirementsTxt) + ) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala index c62bd40b..d84fbf30 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala @@ -1,6 +1,5 @@ package io.appthreat.pysrc2cpg -object Constants { - val ANY = "ANY" - val GLOBAL_NAMESPACE = "" -} +object Constants: + val ANY = "ANY" + val GLOBAL_NAMESPACE = "" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala index 3c3f3119..b2668f19 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala @@ -9,424 +9,417 @@ import org.slf4j.LoggerFactory import scala.collection.mutable -object ContextStack 
{ - private val logger = LoggerFactory.getLogger(getClass) - - def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = { - src.lineNumber match { - // If there are multiple occurrences and the local is already set, ignore later updates - case Some(srcLineNo) if tgt.lineNumber.isEmpty || !tgt.lineNumber.exists(_ < srcLineNo) => - tgt.lineNumber(src.lineNumber) - tgt.columnNumber(src.columnNumber) - case _ => - } - } -} - -class ContextStack { - import ContextStack.logger - - private trait Context { - val astParent: nodes.NewNode - val order: AutoIncIndex - val variables: mutable.Map[String, nodes.NewNode] - var lambdaCounter: Int - } - - private class MethodContext( - val scopeName: Option[String], - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val isClassBodyMethod: Boolean = false, - val methodBlockNode: Option[nodes.NewBlock] = None, - val methodRefNode: Option[nodes.NewMethodRef] = None, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - val globalVariables: mutable.Set[String] = mutable.Set.empty, - val nonLocalVariables: mutable.Set[String] = mutable.Set.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - private class ClassContext( - val scopeName: Option[String], - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - // Used to represent comprehension variable and exception - // handler context. - // E.g.: [x for x in y] creates an extra context with a - // local variable x which is different from a possible x - // in the surrounding method context. 
The same applies - // to x in: - // try: - // pass - // except e as x: - // pass - private class SpecialBlockContext( - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - private case class VariableReference( - identifier: nodes.NewIdentifier, - memOp: MemoryOperation, - // Context stack as it was when VariableReference - // was created. Context objects are and need to - // shared between different VariableReference - // instances because the changes in the variable - // maps need to be in sync. - stack: List[Context] - ) - - private var stack = List[Context]() - private val variableReferences = mutable.ArrayBuffer.empty[VariableReference] - private var moduleMethodContext = Option.empty[MethodContext] - private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock] - private val fileNamespaceBlockOrder = new AutoIncIndex(1) - - private def push(context: Context): Unit = { - stack = context :: stack - } - - def pushMethod( - scopeName: Option[String], - methodNode: nodes.NewMethod, - methodBlockNode: nodes.NewBlock, - methodRefNode: Option[nodes.NewMethodRef] - ): Unit = { - val isClassBodyMethod = stack.headOption.exists(_.isInstanceOf[ClassContext]) - - val methodContext = - new MethodContext( - scopeName, - methodNode, - new AutoIncIndex(1), - isClassBodyMethod, - Some(methodBlockNode), - methodRefNode - ) - if (moduleMethodContext.isEmpty) { - moduleMethodContext = Some(methodContext) - } - push(methodContext) - } - - def pushClass(scopeName: Option[String], classNode: nodes.NewTypeDecl): Unit = { - push(new ClassContext(scopeName, classNode, new AutoIncIndex(1))) - } - - def pushSpecialContext(): Unit = { - val methodContext = findEnclosingMethodContext(stack) - push(new SpecialBlockContext(methodContext.astParent, methodContext.order)) - } - - def pop(): Unit = { - stack = stack.tail - } - - def 
setFileNamespaceBlock(namespaceBlock: nodes.NewNamespaceBlock): Unit = { - fileNamespaceBlock = Some(namespaceBlock) - } - - def addVariableReference(identifier: nodes.NewIdentifier, memOp: MemoryOperation): Unit = { - variableReferences.append(VariableReference(identifier, memOp, stack)) - } - - def getAndIncLambdaCounter(): Int = { - val result = stack.head.lambdaCounter - stack.head.lambdaCounter += 1 - result - } - - private def findEnclosingMethodContext(contextStack: List[Context]): MethodContext = { - contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext] - } - - def findEnclosingTypeDecl(): Option[NewNode] = { - stack.find(_.isInstanceOf[ClassContext]) match { - case Some(classContext: ClassContext) => - Some(classContext.astParent) - case _ => None - } - } - - def createIdentifierLinks( - createLocal: (String, Option[String]) => nodes.NewLocal, - createClosureBinding: (String, String) => nodes.NewClosureBinding, - createAstEdge: (nodes.NewNode, nodes.NewNode, Int) => Unit, - createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, - createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit - ): Unit = { - // Before we do any linking, we iterate over all variable references and - // create a variable in the module method context for each global variable - // with a store operation on it. - // This is necessary because there might be load/delete operations - // referencing the global variable which are syntactically before the store - // operations. 
- variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) => - val name = identifier.name - if ( - memOp == Store && - findEnclosingMethodContext(contextStack).globalVariables.contains(name) && - !moduleMethodContext.get.variables.contains(name) - ) { - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - createAstEdge(localNode, moduleMethodContext.get.methodBlockNode.get, moduleMethodContext.get.order.getAndInc) - moduleMethodContext.get.variables.put(name, localNode) - } - } - - // Variable references processing needs to be ordered by context depth in - // order to make sure that variables captured into deeper nested contexts - // are already created. - val sortedVariableRefs = variableReferences.sortBy(_.stack.size) - sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) => - val name = identifier.name - // Store and delete operations look up variable only in method scope. - // Load operations also look up captured or global variables. - // If a store and load/del happens in the same context, the store must - // come first. Otherwise it is not valid Python, which we assume here. 
- if (memOp == Load) { - linkLocalOrCapturing( - createLocal, - createClosureBinding, - createAstEdge, - createRefEdge, - createCaptureEdge, - identifier, - name, - contextStack - ) - } else { - val enclosingMethodContext = findEnclosingMethodContext(contextStack) - - if ( - enclosingMethodContext.globalVariables.contains(name) || - enclosingMethodContext.nonLocalVariables.contains(name) - ) { - linkLocalOrCapturing( - createLocal, - createClosureBinding, - createAstEdge, - createRefEdge, - createCaptureEdge, - identifier, - name, - contextStack - ) - } else if (memOp == Store) { - var variableNode = lookupVariableInMethod(name, contextStack) - if (variableNode.isEmpty) { - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - val enclosingMethodContext = findEnclosingMethodContext(contextStack) - createAstEdge(localNode, enclosingMethodContext.methodBlockNode.get, enclosingMethodContext.order.getAndInc) - enclosingMethodContext.variables.put(name, localNode) - variableNode = Some(localNode) - } - createRefEdge(variableNode.get, identifier) - } else if (memOp == Del) { - val variableNode = lookupVariableInMethod(name, contextStack) - variableNode match { - case Some(variableNode) => - createRefEdge(variableNode, identifier) - case None => - // When we could not find a matching variable we get here and create a local in - // the method context so that we can link something and fullfil the CPG - // format requirements. - // For example this happens when there are wildcard imports directly into the - // modules namespace. 
- val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - val methodContext = findEnclosingMethodContext(contextStack) - createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) - methodContext.variables.put(name, localNode) - createRefEdge(localNode, identifier) - } - } - } - } - } - - /** Assignments to variables on the module-level may be exported to other modules and behave as inter-procedurally - * global variables. - * @param lhs - * the LHS node of an assignment - */ - def considerAsGlobalVariable(lhs: NewNode): Unit = { - lhs match { - case n: NewIdentifier if findEnclosingMethodContext(stack).scopeName.contains("") => - addGlobalVariable(n.name) - case _ => - } - } - - /** For module-methods, the variables of this method can be imported into other modules which resembles behaviour much - * like fields/members. This inter-procedural accessibility should be marked via the module's type decl node. - */ - def createMemberLinks(moduleTypeDecl: NewTypeDecl, astEdgeLinker: (NewNode, NewNode, Int) => Unit): Unit = { - val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables - variableReferences - .map(_.identifier) - .filter(i => globalVarsForEnclMethod.contains(i.name)) - .sortBy(i => (i.lineNumber, i.columnNumber)) - .distinctBy(_.name) - .map(i => - NewMember() - .name(i.name) - .typeFullName(Constants.ANY) - .dynamicTypeHintFullName(i.dynamicTypeHintFullName) - .lineNumber(i.lineNumber) - .columnNumber(i.columnNumber) - .code(i.name) - ) - .zipWithIndex - .foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) } - } - - private def linkLocalOrCapturing( - createLocal: (String, Option[String]) => NewLocal, - createClosureBinding: (String, String) => NewClosureBinding, - createAstEdge: (NewNode, NewNode, Int) => Unit, - createRefEdge: (NewNode, NewNode) => Unit, - createCaptureEdge: (NewNode, NewNode) => Unit, - identifier: NewIdentifier, - name: String, - 
contextStack: List[Context] - ): Unit = { - var identifierOrClosureBindingToLink: nodes.NewNode = identifier - val stackIt = contextStack.iterator - var contextHasVariable = false - val startContext = contextStack.head - while (stackIt.hasNext && !contextHasVariable) { - val context = stackIt.next() - - context match { - case methodContext: MethodContext => - // Context is only relevant for linking if it is not a class body methods context - // or the identifier/reference itself is from the class body method context. - if (!methodContext.isClassBodyMethod || methodContext == startContext) { - contextHasVariable = context.variables.contains(name) - - val closureBindingId = - methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name - - if (!contextHasVariable) { - if (context != moduleMethodContext.get) { - val localNode = createLocal(name, Some(closureBindingId)) - transferLineColInfo(identifier, localNode) - createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) - methodContext.variables.put(name, localNode) - } else { - // When we could not even find a matching variable in the module context we get - // here and create a local so that we can link something and fullfil the CPG - // format requirements. - // For example this happens when there are wildcard imports directly into the - // modules namespace. 
+object ContextStack: + private val logger = LoggerFactory.getLogger(getClass) + + def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = + src.lineNumber match + // If there are multiple occurrences and the local is already set, ignore later updates + case Some(srcLineNo) + if tgt.lineNumber.isEmpty || !tgt.lineNumber.exists(_ < srcLineNo) => + tgt.lineNumber(src.lineNumber) + tgt.columnNumber(src.columnNumber) + case _ => + +class ContextStack: + import ContextStack.logger + + private trait Context: + val astParent: nodes.NewNode + val order: AutoIncIndex + val variables: mutable.Map[String, nodes.NewNode] + var lambdaCounter: Int + + private class MethodContext( + val scopeName: Option[String], + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val isClassBodyMethod: Boolean = false, + val methodBlockNode: Option[nodes.NewBlock] = None, + val methodRefNode: Option[nodes.NewMethodRef] = None, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + val globalVariables: mutable.Set[String] = mutable.Set.empty, + val nonLocalVariables: mutable.Set[String] = mutable.Set.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + private class ClassContext( + val scopeName: Option[String], + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + // Used to represent comprehension variable and exception + // handler context. + // E.g.: [x for x in y] creates an extra context with a + // local variable x which is different from a possible x + // in the surrounding method context. 
The same applies + // to x in: + // try: + // pass + // except e as x: + // pass + private class SpecialBlockContext( + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + private case class VariableReference( + identifier: nodes.NewIdentifier, + memOp: MemoryOperation, + // Context stack as it was when VariableReference + // was created. Context objects are and need to + // shared between different VariableReference + // instances because the changes in the variable + // maps need to be in sync. + stack: List[Context] + ) + + private var stack = List[Context]() + private val variableReferences = mutable.ArrayBuffer.empty[VariableReference] + private var moduleMethodContext = Option.empty[MethodContext] + private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock] + private val fileNamespaceBlockOrder = new AutoIncIndex(1) + + private def push(context: Context): Unit = + stack = context :: stack + + def pushMethod( + scopeName: Option[String], + methodNode: nodes.NewMethod, + methodBlockNode: nodes.NewBlock, + methodRefNode: Option[nodes.NewMethodRef] + ): Unit = + val isClassBodyMethod = stack.headOption.exists(_.isInstanceOf[ClassContext]) + + val methodContext = + new MethodContext( + scopeName, + methodNode, + new AutoIncIndex(1), + isClassBodyMethod, + Some(methodBlockNode), + methodRefNode + ) + if moduleMethodContext.isEmpty then + moduleMethodContext = Some(methodContext) + push(methodContext) + end pushMethod + + def pushClass(scopeName: Option[String], classNode: nodes.NewTypeDecl): Unit = + push(new ClassContext(scopeName, classNode, new AutoIncIndex(1))) + + def pushSpecialContext(): Unit = + val methodContext = findEnclosingMethodContext(stack) + push(new SpecialBlockContext(methodContext.astParent, methodContext.order)) + + def pop(): Unit = + stack = stack.tail + + def setFileNamespaceBlock(namespaceBlock: 
nodes.NewNamespaceBlock): Unit = + fileNamespaceBlock = Some(namespaceBlock) + + def addVariableReference(identifier: nodes.NewIdentifier, memOp: MemoryOperation): Unit = + variableReferences.append(VariableReference(identifier, memOp, stack)) + + def getAndIncLambdaCounter(): Int = + val result = stack.head.lambdaCounter + stack.head.lambdaCounter += 1 + result + + private def findEnclosingMethodContext(contextStack: List[Context]): MethodContext = + contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext] + + def findEnclosingTypeDecl(): Option[NewNode] = + stack.find(_.isInstanceOf[ClassContext]) match + case Some(classContext: ClassContext) => + Some(classContext.astParent) + case _ => None + + def createIdentifierLinks( + createLocal: (String, Option[String]) => nodes.NewLocal, + createClosureBinding: (String, String) => nodes.NewClosureBinding, + createAstEdge: (nodes.NewNode, nodes.NewNode, Int) => Unit, + createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, + createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit + ): Unit = + // Before we do any linking, we iterate over all variable references and + // create a variable in the module method context for each global variable + // with a store operation on it. + // This is necessary because there might be load/delete operations + // referencing the global variable which are syntactically before the store + // operations. 
+ variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) => + val name = identifier.name + if + memOp == Store && + findEnclosingMethodContext(contextStack).globalVariables.contains(name) && + !moduleMethodContext.get.variables.contains(name) + then val localNode = createLocal(name, None) transferLineColInfo(identifier, localNode) - createAstEdge(localNode, methodContext.methodBlockNode.get, methodContext.order.getAndInc) - methodContext.variables.put(name, localNode) - } - } - val localNodeInContext = methodContext.variables(name) - - createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + createAstEdge( + localNode, + moduleMethodContext.get.methodBlockNode.get, + moduleMethodContext.get.order.getAndInc + ) + moduleMethodContext.get.variables.put(name, localNode) + } - if (!contextHasVariable && context != moduleMethodContext.get) { - identifierOrClosureBindingToLink = createClosureBinding(closureBindingId, name) - createCaptureEdge(identifierOrClosureBindingToLink, methodContext.methodRefNode.get) + // Variable references processing needs to be ordered by context depth in + // order to make sure that variables captured into deeper nested contexts + // are already created. + val sortedVariableRefs = variableReferences.sortBy(_.stack.size) + sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) => + val name = identifier.name + // Store and delete operations look up variable only in method scope. + // Load operations also look up captured or global variables. + // If a store and load/del happens in the same context, the store must + // come first. Otherwise it is not valid Python, which we assume here. 
+ if memOp == Load then + linkLocalOrCapturing( + createLocal, + createClosureBinding, + createAstEdge, + createRefEdge, + createCaptureEdge, + identifier, + name, + contextStack + ) + else + val enclosingMethodContext = findEnclosingMethodContext(contextStack) + + if + enclosingMethodContext.globalVariables.contains(name) || + enclosingMethodContext.nonLocalVariables.contains(name) + then + linkLocalOrCapturing( + createLocal, + createClosureBinding, + createAstEdge, + createRefEdge, + createCaptureEdge, + identifier, + name, + contextStack + ) + else if memOp == Store then + var variableNode = lookupVariableInMethod(name, contextStack) + if variableNode.isEmpty then + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + val enclosingMethodContext = findEnclosingMethodContext(contextStack) + createAstEdge( + localNode, + enclosingMethodContext.methodBlockNode.get, + enclosingMethodContext.order.getAndInc + ) + enclosingMethodContext.variables.put(name, localNode) + variableNode = Some(localNode) + createRefEdge(variableNode.get, identifier) + else if memOp == Del then + val variableNode = lookupVariableInMethod(name, contextStack) + variableNode match + case Some(variableNode) => + createRefEdge(variableNode, identifier) + case None => + // When we could not find a matching variable we get here and create a local in + // the method context so that we can link something and fullfil the CPG + // format requirements. + // For example this happens when there are wildcard imports directly into the + // modules namespace. 
+ val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + val methodContext = findEnclosingMethodContext(contextStack) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + createRefEdge(localNode, identifier) + end if + end if + } + end createIdentifierLinks + + /** Assignments to variables on the module-level may be exported to other modules and behave as + * inter-procedurally global variables. + * @param lhs + * the LHS node of an assignment + */ + def considerAsGlobalVariable(lhs: NewNode): Unit = + lhs match + case n: NewIdentifier + if findEnclosingMethodContext(stack).scopeName.contains("") => + addGlobalVariable(n.name) + case _ => + + /** For module-methods, the variables of this method can be imported into other modules which + * resembles behaviour much like fields/members. This inter-procedural accessibility should be + * marked via the module's type decl node. 
+ */ + def createMemberLinks( + moduleTypeDecl: NewTypeDecl, + astEdgeLinker: (NewNode, NewNode, Int) => Unit + ): Unit = + val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables + variableReferences + .map(_.identifier) + .filter(i => globalVarsForEnclMethod.contains(i.name)) + .sortBy(i => (i.lineNumber, i.columnNumber)) + .distinctBy(_.name) + .map(i => + NewMember() + .name(i.name) + .typeFullName(Constants.ANY) + .dynamicTypeHintFullName(i.dynamicTypeHintFullName) + .lineNumber(i.lineNumber) + .columnNumber(i.columnNumber) + .code(i.name) + ) + .zipWithIndex + .foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) } + end createMemberLinks + + private def linkLocalOrCapturing( + createLocal: (String, Option[String]) => NewLocal, + createClosureBinding: (String, String) => NewClosureBinding, + createAstEdge: (NewNode, NewNode, Int) => Unit, + createRefEdge: (NewNode, NewNode) => Unit, + createCaptureEdge: (NewNode, NewNode) => Unit, + identifier: NewIdentifier, + name: String, + contextStack: List[Context] + ): Unit = + var identifierOrClosureBindingToLink: nodes.NewNode = identifier + val stackIt = contextStack.iterator + var contextHasVariable = false + val startContext = contextStack.head + while stackIt.hasNext && !contextHasVariable do + val context = stackIt.next() + + context match + case methodContext: MethodContext => + // Context is only relevant for linking if it is not a class body methods context + // or the identifier/reference itself is from the class body method context. 
+ if !methodContext.isClassBodyMethod || methodContext == startContext then + contextHasVariable = context.variables.contains(name) + + val closureBindingId = + methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name + + if !contextHasVariable then + if context != moduleMethodContext.get then + val localNode = createLocal(name, Some(closureBindingId)) + transferLineColInfo(identifier, localNode) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + else + // When we could not even find a matching variable in the module context we get + // here and create a local so that we can link something and fullfil the CPG + // format requirements. + // For example this happens when there are wildcard imports directly into the + // modules namespace. + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + end if + val localNodeInContext = methodContext.variables(name) + + createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + + if !contextHasVariable && context != moduleMethodContext.get then + identifierOrClosureBindingToLink = + createClosureBinding(closureBindingId, name) + createCaptureEdge( + identifierOrClosureBindingToLink, + methodContext.methodRefNode.get + ) + case specialBlockContext: SpecialBlockContext => + contextHasVariable = context.variables.contains(name) + if contextHasVariable then + val localNodeInContext = specialBlockContext.variables(name) + createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + + case _: ClassContext => + assert(context.variables.isEmpty) + // Class context is not relevant for variable linking. + // The context relevant for this is the class method body context. 
+ end match + end while + end linkLocalOrCapturing + + private def lookupVariableInMethod(name: String, stack: List[Context]): Option[nodes.NewNode] = + var variableNode = Option.empty[nodes.NewNode] + + val stackIt = stack.iterator + var lastContextWasMethod = false + while stackIt.hasNext && variableNode.isEmpty && !lastContextWasMethod do + val context = stackIt.next() + variableNode = context.variables.get(name) + lastContextWasMethod = context.isInstanceOf[MethodContext] + variableNode + + def addParameter(parameter: nodes.NewMethodParameterIn): Unit = + assert(stack.head.isInstanceOf[MethodContext]) + stack.head.variables.put(parameter.name, parameter) + + def addSpecialVariable(local: nodes.NewLocal): Unit = + assert(stack.head.isInstanceOf[SpecialBlockContext]) + stack.head.variables.put(local.name, local) + + def addGlobalVariable(name: String): Unit = + findEnclosingMethodContext(stack).globalVariables.add(name) + + def addNonLocalVariable(name: String): Unit = + findEnclosingMethodContext(stack).nonLocalVariables.add(name) + + // Together with the file name this is used to compute full names. + def qualName: String = + stack + .flatMap { + case methodContext: MethodContext => + methodContext.scopeName + case specialBlockContext: SpecialBlockContext => + None + case classContext: ClassContext => + classContext.scopeName } - } - case specialBlockContext: SpecialBlockContext => - contextHasVariable = context.variables.contains(name) - if (contextHasVariable) { - val localNodeInContext = specialBlockContext.variables(name) - createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) - } - - case _: ClassContext => - assert(context.variables.isEmpty) - // Class context is not relevant for variable linking. - // The context relevant for this is the class method body context. 
- } - } - } - - private def lookupVariableInMethod(name: String, stack: List[Context]): Option[nodes.NewNode] = { - var variableNode = Option.empty[nodes.NewNode] - - val stackIt = stack.iterator - var lastContextWasMethod = false - while (stackIt.hasNext && variableNode.isEmpty && !lastContextWasMethod) { - val context = stackIt.next() - variableNode = context.variables.get(name) - lastContextWasMethod = context.isInstanceOf[MethodContext] - } - variableNode - } - - def addParameter(parameter: nodes.NewMethodParameterIn): Unit = { - assert(stack.head.isInstanceOf[MethodContext]) - stack.head.variables.put(parameter.name, parameter) - } - - def addSpecialVariable(local: nodes.NewLocal): Unit = { - assert(stack.head.isInstanceOf[SpecialBlockContext]) - stack.head.variables.put(local.name, local) - } - - def addGlobalVariable(name: String): Unit = { - findEnclosingMethodContext(stack).globalVariables.add(name) - } - - def addNonLocalVariable(name: String): Unit = { - findEnclosingMethodContext(stack).nonLocalVariables.add(name) - } - - // Together with the file name this is used to compute full names. 
- def qualName: String = { - stack - .flatMap { - case methodContext: MethodContext => - methodContext.scopeName - case specialBlockContext: SpecialBlockContext => - None - case classContext: ClassContext => - classContext.scopeName - } - .reverse - .mkString(".") - } - - def astParent: nodes.NewNode = { - stack match { - case head :: _ => - head.astParent - case Nil => - fileNamespaceBlock.get - } - } - - def order: AutoIncIndex = { - stack match { - case head :: _ => - head.order - case Nil => - fileNamespaceBlockOrder - } - } - - def isClassContext: Boolean = { - stack.nonEmpty && (stack.head match { - case methodContext: MethodContext if methodContext.isClassBodyMethod => true - case _ => false - }) - } - -} + .reverse + .mkString(".") + + def astParent: nodes.NewNode = + stack match + case head :: _ => + head.astParent + case Nil => + fileNamespaceBlock.get + + def order: AutoIncIndex = + stack match + case head :: _ => + head.order + case Nil => + fileNamespaceBlockOrder + + def isClassContext: Boolean = + stack.nonEmpty && (stack.head match + case methodContext: MethodContext if methodContext.isClassBodyMethod => true + case _ => false + ) +end ContextStack diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala index 2703e8c5..89ac6863 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala @@ -2,7 +2,7 @@ package io.appthreat.pysrc2cpg import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.{NewDependency} import org.slf4j.{Logger, 
LoggerFactory} @@ -20,20 +20,19 @@ MarkupSafe==1.1.1 Werkzeug==1.0.1 ``` */ -class DependenciesFromRequirementsTxtPass(cpg: Cpg) extends CpgPass(cpg) { - private val logger: Logger = LoggerFactory.getLogger(classOf[DependenciesFromRequirementsTxtPass]) - override def run(dstGraph: DiffGraphBuilder): Unit = { - cpg.configFile.filter(_.name.endsWith("requirements.txt")).foreach { node => - val lines = node.content.split("\n") - lines.filter(_.matches("^[^=]+==[^=]+$")).foreach { line => - val keyValPattern: Regex = "^([^=]+)==([^=]+)$".r - for (patternMatch <- keyValPattern.findAllMatchIn(line)) { - val name = patternMatch.group(1) - val version = patternMatch.group(2) - val node = NewDependency().name(name).version(version).dependencyGroupId(name) - dstGraph.addNode(node) +class DependenciesFromRequirementsTxtPass(cpg: Cpg) extends CpgPass(cpg): + private val logger: Logger = + LoggerFactory.getLogger(classOf[DependenciesFromRequirementsTxtPass]) + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.configFile.filter(_.name.endsWith("requirements.txt")).foreach { node => + val lines = node.content.split("\n") + lines.filter(_.matches("^[^=]+==[^=]+$")).foreach { line => + val keyValPattern: Regex = "^([^=]+)==([^=]+)$".r + for patternMatch <- keyValPattern.findAllMatchIn(line) do + val name = patternMatch.group(1) + val version = patternMatch.group(2) + val node = NewDependency().name(name).version(version).dependencyGroupId(name) + dstGraph.addNode(node) + } } - } - } - } -} +end DependenciesFromRequirementsTxtPass diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala index b2b266b6..b0e848c0 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala +++ 
b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala @@ -3,99 +3,118 @@ package io.appthreat.pysrc2cpg import io.appthreat.x2cpg.passes.frontend.ImportStringHandling import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames -import io.shiftleft.codepropertygraph.generated.nodes.{CfgNode, MethodParameterIn, MethodReturn, StoredNode} +import io.shiftleft.codepropertygraph.generated.nodes.{ + CfgNode, + MethodParameterIn, + MethodReturn, + StoredNode +} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate import java.io.File import java.util.regex.{Matcher, Pattern} -/** The type hints we pick up via the parser are not full names. This pass fixes that by retrieving the import for each - * dynamic type hint and adjusting the dynamic type hint full name field accordingly. +/** The type hints we pick up via the parser are not full names. This pass fixes that by retrieving + * the import for each dynamic type hint and adjusting the dynamic type hint full name field + * accordingly. 
*/ -class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgNode](cpg) { +class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgNode](cpg): - private case class ImportScope(entity: Option[String], alias: Option[String]) + private case class ImportScope(entity: Option[String], alias: Option[String]) - private val fileToImports = cpg.imports.l - .flatMap(imp => imp.call.file.l.map { f => f.name -> imp }) - .groupBy(_._1) - .view - .mapValues(_.map { case (_, imp) => - ImportScope(imp.importedEntity, imp.importedAs) - }) + private val fileToImports = cpg.imports.l + .flatMap(imp => imp.call.file.l.map { f => f.name -> imp }) + .groupBy(_._1) + .view + .mapValues(_.map { case (_, imp) => + ImportScope(imp.importedEntity, imp.importedAs) + }) - override def generateParts(): Array[CfgNode] = - (cpg.methodReturn.filter(x => x.typeFullName != Constants.ANY) ++ cpg.parameter.filter(x => - x.typeFullName != Constants.ANY - )).toArray + override def generateParts(): Array[CfgNode] = + (cpg.methodReturn.filter(x => x.typeFullName != Constants.ANY) ++ cpg.parameter.filter(x => + x.typeFullName != Constants.ANY + )).toArray - override def runOnPart(builder: DiffGraphBuilder, part: CfgNode): Unit = - part match { - case x: MethodReturn => runOnMethodReturn(builder, x) - case x: MethodParameterIn => runOnMethodParameter(builder, x) - case _ => - } + override def runOnPart(builder: DiffGraphBuilder, part: CfgNode): Unit = + part match + case x: MethodReturn => runOnMethodReturn(builder, x) + case x: MethodParameterIn => runOnMethodParameter(builder, x) + case _ => - private def runOnMethodReturn(diffGraph: DiffGraphBuilder, methodReturn: MethodReturn): Unit = - methodReturn.file.foreach { file => - val typeHint = methodReturn.typeFullName - val imports = fileToImports.getOrElse(file.name, List.empty) ++ methodReturn.method.typeDecl - .map(td => ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name))) - .toList - 
imports - .filter { x => - // TODO: Handle * imports correctly - x.alias.exists { imported => typeHint.matches(Pattern.quote(imported) + "(\\..+)*") } - } - .flatMap(_.entity) - .foreach { importedEntity => - setTypeHints(diffGraph, methodReturn, typeHint, typeHint, importedEntity) + private def runOnMethodReturn(diffGraph: DiffGraphBuilder, methodReturn: MethodReturn): Unit = + methodReturn.file.foreach { file => + val typeHint = methodReturn.typeFullName + val imports = + fileToImports.getOrElse(file.name, List.empty) ++ methodReturn.method.typeDecl + .map(td => + ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name)) + ) + .toList + imports + .filter { x => + // TODO: Handle * imports correctly + x.alias.exists { imported => + typeHint.matches(Pattern.quote(imported) + "(\\..+)*") + } + } + .flatMap(_.entity) + .foreach { importedEntity => + setTypeHints(diffGraph, methodReturn, typeHint, typeHint, importedEntity) + } } - } - private def runOnMethodParameter(diffGraph: DiffGraphBuilder, param: MethodParameterIn): Unit = - param.file.foreach { file => - val typeHint = param.typeFullName - val imports = fileToImports.getOrElse(file.name, List.empty) ++ param.method.typeDecl - .map(td => ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name))) - .toList - imports - // TODO: Handle * imports correctly - .filter(_.alias.exists { imported => typeHint.matches(Pattern.quote(imported) + "(\\..+)*") }) - .foreach { - case ImportScope(Some(importedEntity), Some(importedAs)) => - setTypeHints(diffGraph, param, typeHint, importedAs, importedEntity) - case _ => + private def runOnMethodParameter(diffGraph: DiffGraphBuilder, param: MethodParameterIn): Unit = + param.file.foreach { file => + val typeHint = param.typeFullName + val imports = fileToImports.getOrElse(file.name, List.empty) ++ param.method.typeDecl + .map(td => + ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name)) + ) + .toList + imports + // TODO: 
Handle * imports correctly + .filter(_.alias.exists { imported => + typeHint.matches(Pattern.quote(imported) + "(\\..+)*") + }) + .foreach { + case ImportScope(Some(importedEntity), Some(importedAs)) => + setTypeHints(diffGraph, param, typeHint, importedAs, importedEntity) + case _ => + } } - } - private def pythonicTypeNameToImport(fullName: String): String = - fullName.replaceFirst("\\.py:", "").replaceAll(Pattern.quote(File.separator), ".") + private def pythonicTypeNameToImport(fullName: String): String = + fullName.replaceFirst("\\.py:", "").replaceAll(Pattern.quote(File.separator), ".") - private def setTypeHints( - diffGraph: BatchedUpdate.DiffGraphBuilder, - node: StoredNode, - typeHint: String, - alias: String, - importedEntity: String - ) = { - val importFullPath = ImportStringHandling.combinedPath(importedEntity, typeHint) - val typeHintFullName = typeHint.replaceFirst(Pattern.quote(alias), importedEntity) - val typeFilePath = typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator)) - val pythonicTypeFullName = importFullPath.split("\\.").lastOption match { - case Some(typeName) => - typeFilePath.stripSuffix(s"${File.separator}$typeName").concat(s".py:.$typeName") - case None => typeHintFullName - } - cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match { - case xs if xs.sizeIs == 1 => - diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, xs.fullName.head) - case xs if xs.nonEmpty => - diffGraph.setNodeProperty(node, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, xs.fullName.toSeq) - case _ => - diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, pythonicTypeFullName) - } - } -} + private def setTypeHints( + diffGraph: BatchedUpdate.DiffGraphBuilder, + node: StoredNode, + typeHint: String, + alias: String, + importedEntity: String + ) = + val importFullPath = ImportStringHandling.combinedPath(importedEntity, typeHint) + val typeHintFullName = typeHint.replaceFirst(Pattern.quote(alias), 
importedEntity) + val typeFilePath = + typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator)) + val pythonicTypeFullName = importFullPath.split("\\.").lastOption match + case Some(typeName) => + typeFilePath.stripSuffix(s"${File.separator}$typeName").concat( + s".py:.$typeName" + ) + case None => typeHintFullName + cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match + case xs if xs.sizeIs == 1 => + diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, xs.fullName.head) + case xs if xs.nonEmpty => + diffGraph.setNodeProperty( + node, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + xs.fullName.toSeq + ) + case _ => + diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, pythonicTypeFullName) + end setTypeHints +end DynamicTypeHintFullNamePass diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala index 535cac23..5d2a6b6e 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala @@ -1,118 +1,106 @@ package io.appthreat.pysrc2cpg import io.shiftleft.codepropertygraph.generated.nodes.{ - NewBlock, - NewCall, - NewControlStructure, - NewFieldIdentifier, - NewFile, - NewIdentifier, - NewJumpTarget, - NewLiteral, - NewLocal, - NewMember, - NewMethod, - NewMethodParameterIn, - NewMethodRef, - NewMethodReturn, - NewModifier, - NewNamespaceBlock, - NewReturn, - NewTypeDecl, - NewTypeRef, - NewUnknown + NewBlock, + NewCall, + NewControlStructure, + NewFieldIdentifier, + NewFile, + NewIdentifier, + NewJumpTarget, + NewLiteral, + NewLocal, + NewMember, + NewMethod, + NewMethodParameterIn, + NewMethodRef, + NewMethodReturn, + NewModifier, + NewNamespaceBlock, + NewReturn, + NewTypeDecl, + NewTypeRef, + NewUnknown } import 
io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import overflowdb.BatchedUpdate.DiffGraphBuilder -class EdgeBuilder(diffGraph: DiffGraphBuilder) { - def astEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, order: Int): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.AST) - addOrder(dstNode, order) - } +class EdgeBuilder(diffGraph: DiffGraphBuilder): + def astEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, order: Int): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.AST) + addOrder(dstNode, order) - def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argIndex: Int): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) - addArgumentIndex(dstNode, argIndex) - } + def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argIndex: Int): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) + addArgumentIndex(dstNode, argIndex) - def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argName: String): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) - // We need to fill something according to the CPG spec. But the spec also says that argument - // index is ignored if argument name is provided. So we just put -1. - addArgumentIndex(dstNode, -1) - addArgumentName(dstNode, argName) - } + def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argName: String): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) + // We need to fill something according to the CPG spec. But the spec also says that argument + // index is ignored if argument name is provided. So we just put -1. 
+ addArgumentIndex(dstNode, -1) + addArgumentName(dstNode, argName) - def receiverEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.RECEIVER) - } + def receiverEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.RECEIVER) - def conditionEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CONDITION) - } + def conditionEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CONDITION) - def refEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.REF) - } + def refEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.REF) - def captureEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CAPTURE) - } + def captureEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CAPTURE) - def bindsEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = { - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.BINDS) - } + def bindsEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.BINDS) - private def addOrder(node: nodes.NewNode, order: Int): Unit = node match { - case n: NewTypeDecl => n.order = order - case n: NewBlock => n.order = order - case n: NewCall => n.order = order - case n: NewFieldIdentifier => n.order = order - case n: NewFile => n.order = order - case n: NewIdentifier => n.order = order - case n: NewLocal => n.order = order - case n: NewMethod => n.order = order - case n: NewMethodParameterIn => n.order = order - case n: NewMethodRef => n.order = order - case n: NewNamespaceBlock => n.order = order - case n: NewTypeRef => n.order = order - case n: NewUnknown => 
n.order = order - case n: NewModifier => n.order = order - case n: NewMethodReturn => n.order = order - case n: NewMember => n.order = order - case n: NewControlStructure => n.order = order - case n: NewLiteral => n.order = order - case n: NewReturn => n.order = order - case n: NewJumpTarget => n.order = order - } + private def addOrder(node: nodes.NewNode, order: Int): Unit = node match + case n: NewTypeDecl => n.order = order + case n: NewBlock => n.order = order + case n: NewCall => n.order = order + case n: NewFieldIdentifier => n.order = order + case n: NewFile => n.order = order + case n: NewIdentifier => n.order = order + case n: NewLocal => n.order = order + case n: NewMethod => n.order = order + case n: NewMethodParameterIn => n.order = order + case n: NewMethodRef => n.order = order + case n: NewNamespaceBlock => n.order = order + case n: NewTypeRef => n.order = order + case n: NewUnknown => n.order = order + case n: NewModifier => n.order = order + case n: NewMethodReturn => n.order = order + case n: NewMember => n.order = order + case n: NewControlStructure => n.order = order + case n: NewLiteral => n.order = order + case n: NewReturn => n.order = order + case n: NewJumpTarget => n.order = order - private def addArgumentIndex(node: nodes.NewNode, argIndex: Int): Unit = node match { - case n: NewBlock => n.argumentIndex = argIndex - case n: NewCall => n.argumentIndex = argIndex - case n: NewFieldIdentifier => n.argumentIndex = argIndex - case n: NewIdentifier => n.argumentIndex = argIndex - case n: NewMethodRef => n.argumentIndex = argIndex - case n: NewTypeRef => n.argumentIndex = argIndex - case n: NewUnknown => n.argumentIndex = argIndex - case n: NewControlStructure => n.argumentIndex = argIndex - case n: NewLiteral => n.argumentIndex = argIndex - case n: NewReturn => n.argumentIndex = argIndex - } + private def addArgumentIndex(node: nodes.NewNode, argIndex: Int): Unit = node match + case n: NewBlock => n.argumentIndex = argIndex + case n: NewCall 
=> n.argumentIndex = argIndex + case n: NewFieldIdentifier => n.argumentIndex = argIndex + case n: NewIdentifier => n.argumentIndex = argIndex + case n: NewMethodRef => n.argumentIndex = argIndex + case n: NewTypeRef => n.argumentIndex = argIndex + case n: NewUnknown => n.argumentIndex = argIndex + case n: NewControlStructure => n.argumentIndex = argIndex + case n: NewLiteral => n.argumentIndex = argIndex + case n: NewReturn => n.argumentIndex = argIndex - private def addArgumentName(node: nodes.NewNode, argName: String): Unit = { - val someArgName = Some(argName) - node match { - case n: NewBlock => n.argumentName = someArgName - case n: NewCall => n.argumentName = someArgName - case n: NewFieldIdentifier => n.argumentName = someArgName - case n: NewIdentifier => n.argumentName = someArgName - case n: NewMethodRef => n.argumentName = someArgName - case n: NewTypeRef => n.argumentName = someArgName - case n: NewUnknown => n.argumentName = someArgName - case n: NewControlStructure => n.argumentName = someArgName - case n: NewLiteral => n.argumentName = someArgName - case n: NewReturn => n.argumentName = someArgName - } - } -} + private def addArgumentName(node: nodes.NewNode, argName: String): Unit = + val someArgName = Some(argName) + node match + case n: NewBlock => n.argumentName = someArgName + case n: NewCall => n.argumentName = someArgName + case n: NewFieldIdentifier => n.argumentName = someArgName + case n: NewIdentifier => n.argumentName = someArgName + case n: NewMethodRef => n.argumentName = someArgName + case n: NewTypeRef => n.argumentName = someArgName + case n: NewUnknown => n.argumentName = someArgName + case n: NewControlStructure => n.argumentName = someArgName + case n: NewLiteral => n.argumentName = someArgName + case n: NewReturn => n.argumentName = someArgName +end EdgeBuilder diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala 
b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala index 5bc8655a..f0268a15 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala @@ -10,139 +10,162 @@ import io.shiftleft.semanticcpg.language.* import java.io.File as JFile import java.util.regex.{Matcher, Pattern} -class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { - - private lazy val root = cpg.metaData.root.headOption.getOrElse("").stripSuffix(JFile.separator) - - override protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit = { - val (namespace, entityName) = if (importedEntity.contains(".")) { - val splitName = importedEntity.split('.').toSeq - val namespace = importedEntity.stripSuffix(s".${splitName.last}") - (relativizeNamespace(namespace, fileName), splitName.last) - } else { - val currDir = BFile(root) / fileName match - case x if x.isDirectory => x - case x => x.parent - - val relCurrDir = currDir.pathAsString.stripPrefix(root).stripPrefix(JFile.separator) - - (relCurrDir, importedEntity) - } - - resolveEntities(namespace, entityName, importedAs).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) - } - - private def relativizeNamespace(path: String, fileName: String): String = if (path.startsWith(".")) { - // TODO: pysrc2cpg does not link files to the correct namespace nodes - val sep = Matcher.quoteReplacement(JFile.separator) - // The below gives us the full path of the relative "." 
- val relativeNamespace = - if (fileName.contains(JFile.separator)) - fileName.substring(0, fileName.lastIndexOf(JFile.separator)).replaceAll(sep, ".") - else "" - (if (path.length > 1) relativeNamespace + path.replaceAll(sep, ".") - else relativeNamespace).stripPrefix(".") - } else path - - /** For an import - given by its module path and the name of the imported function or module - determine the possible - * callee names. - * - * @param path - * the module path. - * @param expEntity - * the name of the imported entity. This could be a function, module, or variable/field. - * @param alias - * how the imported entity is named. - * @return - * the possible callee names - */ - private def resolveEntities(path: String, expEntity: String, alias: String): Set[ResolvedImport] = { - - implicit class ResolvedNodeExt(val traversal: Seq[String]) { - def toResolvedImport(cpg: Cpg): Seq[ResolvedImport] = { - val resolvedEntities = - traversal.flatMap(x => cpg.typeDecl.fullNameExact(x) ++ cpg.method.fullNameExact(x)).collect { - case x: Method => ResolvedMethod(x.fullName, alias) - case x: TypeDecl => ResolvedTypeDecl(x.fullName) - } - if (resolvedEntities.isEmpty) { - traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x)) - } else { - resolvedEntities - } - } - } - - implicit class CalleeAsInitExt(val name: String) { - def asInit: String = if (name.contains("__init__.py")) name - else name.replace(".py", s"${JFile.separator}__init__.py") - - def withInit: Seq[String] = Seq(name, name.asInit) - } - - val pathSep = "." 
- val sep = Matcher.quoteReplacement(JFile.separator) - val isMaybeConstructor = expEntity.split("\\.").lastOption.exists(s => s.nonEmpty && s.charAt(0).isUpper) - - lazy val membersMatchingImports: List[(TypeDecl, Member)] = cpg.typeDecl - .fullName(s".*${Pattern.quote(path)}.*") - .flatMap(t => - t.member.nameExact(expEntity).headOption match { - case Some(member) => Option((t, member)) - case None => None - } - ) - .toList - - (path match { - case "" if expEntity.contains(".") => - // Case 1: Qualified path: import foo.bar => (bar.py or bar/__init__.py) - val splitFunc = expEntity.split("\\.") - val name = splitFunc.tail.mkString(".") - s"${splitFunc(0)}.py:$pathSep$name".withInit.toResolvedImport(cpg) - case "" => - // Case 2: import of a module: import foo => (foo.py or foo/__init__.py) - s"$expEntity.py:".withInit.toResolvedImport(cpg) - case _ if membersMatchingImports.nonEmpty => - // Case 3: import of a variable: from api import db => (api.py or foo.__init__.py) @ identifier(db) - membersMatchingImports.map { - case (t, m) if t.method.nameExact(m.name).nonEmpty => - ResolvedMethod(t.method.nameExact(m.name).fullName.head, alias) - case (t, m) if t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact(m.name).nonEmpty => - ResolvedTypeDecl( - t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact(m.name).fullName.head +class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg): + + private lazy val root = cpg.metaData.root.headOption.getOrElse("").stripSuffix(JFile.separator) + + override protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit = + val (namespace, entityName) = if importedEntity.contains(".") then + val splitName = importedEntity.split('.').toSeq + val namespace = importedEntity.stripSuffix(s".${splitName.last}") + (relativizeNamespace(namespace, fileName), splitName.last) + else + val 
currDir = BFile(root) / fileName match + case x if x.isDirectory => x + case x => x.parent + + val relCurrDir = currDir.pathAsString.stripPrefix(root).stripPrefix(JFile.separator) + + (relCurrDir, importedEntity) + + resolveEntities(namespace, entityName, importedAs).foreach(x => + resolvedImportToTag(x, importCall, diffGraph) + ) + end optionalResolveImport + + private def relativizeNamespace(path: String, fileName: String): String = + if path.startsWith(".") then + // TODO: pysrc2cpg does not link files to the correct namespace nodes + val sep = Matcher.quoteReplacement(JFile.separator) + // The below gives us the full path of the relative "." + val relativeNamespace = + if fileName.contains(JFile.separator) then + fileName.substring(0, fileName.lastIndexOf(JFile.separator)).replaceAll( + sep, + "." + ) + else "" + (if path.length > 1 then relativeNamespace + path.replaceAll(sep, ".") + else relativeNamespace).stripPrefix(".") + else path + + /** For an import - given by its module path and the name of the imported function or module - + * determine the possible callee names. + * + * @param path + * the module path. + * @param expEntity + * the name of the imported entity. This could be a function, module, or variable/field. + * @param alias + * how the imported entity is named. 
+ * @return + * the possible callee names + */ + private def resolveEntities( + path: String, + expEntity: String, + alias: String + ): Set[ResolvedImport] = + + implicit class ResolvedNodeExt(val traversal: Seq[String]): + def toResolvedImport(cpg: Cpg): Seq[ResolvedImport] = + val resolvedEntities = + traversal.flatMap(x => + cpg.typeDecl.fullNameExact(x) ++ cpg.method.fullNameExact(x) + ).collect { + case x: Method => ResolvedMethod(x.fullName, alias) + case x: TypeDecl => ResolvedTypeDecl(x.fullName) + } + if resolvedEntities.isEmpty then + traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x)) + else + resolvedEntities + + implicit class CalleeAsInitExt(val name: String): + def asInit: String = if name.contains("__init__.py") then name + else name.replace(".py", s"${JFile.separator}__init__.py") + + def withInit: Seq[String] = Seq(name, name.asInit) + + val pathSep = "." + val sep = Matcher.quoteReplacement(JFile.separator) + val isMaybeConstructor = + expEntity.split("\\.").lastOption.exists(s => s.nonEmpty && s.charAt(0).isUpper) + + lazy val membersMatchingImports: List[(TypeDecl, Member)] = cpg.typeDecl + .fullName(s".*${Pattern.quote(path)}.*") + .flatMap(t => + t.member.nameExact(expEntity).headOption match + case Some(member) => Option((t, member)) + case None => None ) - case (t, m) => ResolvedMember(t.fullName, m.name) - } - case _ => - // Case 4: Import from module using alias, e.g. 
import bar from foo as faz - val fileOrDir = BFile(codeRoot) / path - val pyFile = BFile(codeRoot) / s"$path.py" - fileOrDir match { - case f if f.isDirectory && !pyFile.exists => - Seq(s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:").toResolvedImport(cpg) - case f if f.isDirectory && (f / s"$expEntity.py").exists => - Seq(s"${(f / s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:").toResolvedImport(cpg) - case _ => - s"${path.replaceAll("\\.", sep)}.py:$pathSep$expEntity".withInit.toResolvedImport(cpg) - } - }).flatMap { - // If we import the constructor, we also import the type - case x: ResolvedMethod if isMaybeConstructor => - Seq(ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), ResolvedTypeDecl(x.fullName)) - // If we import the type, we also import the constructor - case x: ResolvedTypeDecl if isMaybeConstructor => - Seq(x, ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias)) - // If we can determine the import is a constructor, then it is likely not a member - case x: UnknownImport if isMaybeConstructor => - Seq(UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), UnknownTypeDecl(x.path)) - case x => Seq(x) - }.toSet - } -} + .toList + + (path match + case "" if expEntity.contains(".") => + // Case 1: Qualified path: import foo.bar => (bar.py or bar/__init__.py) + val splitFunc = expEntity.split("\\.") + val name = splitFunc.tail.mkString(".") + s"${splitFunc(0)}.py:$pathSep$name".withInit.toResolvedImport(cpg) + case "" => + // Case 2: import of a module: import foo => (foo.py or foo/__init__.py) + s"$expEntity.py:".withInit.toResolvedImport(cpg) + case _ if membersMatchingImports.nonEmpty => + // Case 3: import of a variable: from api import db => (api.py or foo.__init__.py) @ identifier(db) + membersMatchingImports.map { + case (t, m) if t.method.nameExact(m.name).nonEmpty => + ResolvedMethod(t.method.nameExact(m.name).fullName.head, alias) + case (t, m) + if 
t.astSiblings.isMethod.fullNameExact( + t.fullName + ).ast.isTypeDecl.nameExact(m.name).nonEmpty => + ResolvedTypeDecl( + t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact( + m.name + ).fullName.head + ) + case (t, m) => ResolvedMember(t.fullName, m.name) + } + case _ => + // Case 4: Import from module using alias, e.g. import bar from foo as faz + val fileOrDir = BFile(codeRoot) / path + val pyFile = BFile(codeRoot) / s"$path.py" + fileOrDir match + case f if f.isDirectory && !pyFile.exists => + Seq( + s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:" + ).toResolvedImport(cpg) + case f if f.isDirectory && (f / s"$expEntity.py").exists => + Seq( + s"${(f / s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:" + ).toResolvedImport(cpg) + case _ => + s"${path.replaceAll("\\.", sep)}.py:$pathSep$expEntity".withInit.toResolvedImport( + cpg + ) + ).flatMap { + // If we import the constructor, we also import the type + case x: ResolvedMethod if isMaybeConstructor => + Seq( + ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), + ResolvedTypeDecl(x.fullName) + ) + // If we import the type, we also import the constructor + case x: ResolvedTypeDecl if isMaybeConstructor => + Seq(x, ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias)) + // If we can determine the import is a constructor, then it is likely not a member + case x: UnknownImport if isMaybeConstructor => + Seq( + UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), + UnknownTypeDecl(x.path) + ) + case x => Seq(x) + }.toSet + end resolveEntities +end ImportResolverPass diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala index 26b63172..c4f791e8 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala +++ 
b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala @@ -2,24 +2,21 @@ package io.appthreat.pysrc2cpg import io.appthreat.x2cpg.passes.frontend.XImportsPass import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment -class ImportsPass(cpg: Cpg) extends XImportsPass(cpg) { +class ImportsPass(cpg: Cpg) extends XImportsPass(cpg): - override protected val importCallName: String = "import" + override protected val importCallName: String = "import" - override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = x.inAssignment.map(y => (x, y)) + override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = + x.inAssignment.map(y => (x, y)) - override def importedEntityFromCall(call: Call): String = { - call.argument.code.l match { - case List("", what) => what - case List(where, what) => s"$where.$what" - case List("", what, _) => what - case List(where, what, _) => s"$where.$what" - case _ => "" - } - } - -} + override def importedEntityFromCall(call: Call): String = + call.argument.code.l match + case List("", what) => what + case List(where, what) => s"$where.$what" + case List("", what, _) => what + case List(where, what, _) => s"$where.$what" + case _ => "" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala index 7fc3dfc0..f0291ba5 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala @@ -7,40 +7,40 @@ import scopt.OParser import java.nio.file.Paths -private object Frontend { - val cmdLineParser: 
OParser[Unit, Py2CpgOnFileSystemConfig] = { - val builder = OParser.builder[Py2CpgOnFileSystemConfig] - import builder._ - // Defaults for all command line options are specified in Py2CpgOFileSystemConfig - // because Scopt is a shit library. - OParser.sequence( - programName("pysrc2cpg"), - opt[String]("venvDir") - .text( - "Virtual environment directory. If not absolute it is interpreted relative to input-dir. Defaults to .venv." +private object Frontend: + val cmdLineParser: OParser[Unit, Py2CpgOnFileSystemConfig] = + val builder = OParser.builder[Py2CpgOnFileSystemConfig] + import builder.* + // Defaults for all command line options are specified in Py2CpgOFileSystemConfig + // because Scopt is a shit library. + OParser.sequence( + programName("pysrc2cpg"), + opt[String]("venvDir") + .text( + "Virtual environment directory. If not absolute it is interpreted relative to input-dir. Defaults to .venv." + ) + .action((dir, config) => config.withVenvDir(Paths.get(dir))), + opt[Boolean]("ignoreVenvDir") + .text("Specifies whether venv-dir is ignored. Default to true.") + .action(((value, config) => config.withIgnoreVenvDir(value))), + opt[Seq[String]]("ignore-paths") + .text( + "Ignores the specified path from analysis. If not absolute it is interpreted relative to input-dir." + ) + .action(((value, config) => config.withIgnorePaths(value.map(Paths.get(_))))), + opt[Seq[String]]("ignore-dir-names") + .text( + "Excludes all files where the relative path from input-dir contains at least one of names specified here." + ) + .action(((value, config) => config.withIgnoreDirNames(value))), + XTypeRecovery.parserOptions ) - .action((dir, config) => config.withVenvDir(Paths.get(dir))), - opt[Boolean]("ignoreVenvDir") - .text("Specifies whether venv-dir is ignored. Default to true.") - .action(((value, config) => config.withIgnoreVenvDir(value))), - opt[Seq[String]]("ignore-paths") - .text("Ignores the specified path from analysis. 
If not absolute it is interpreted relative to input-dir.") - .action(((value, config) => config.withIgnorePaths(value.map(Paths.get(_))))), - opt[Seq[String]]("ignore-dir-names") - .text( - "Excludes all files where the relative path from input-dir contains at least one of names specified here." - ) - .action(((value, config) => config.withIgnoreDirNames(value))), - XTypeRecovery.parserOptions - ) - } -} - -object NewMain extends X2CpgMain(cmdLineParser, new Py2CpgOnFileSystem())(new Py2CpgOnFileSystemConfig()) { - def run(config: Py2CpgOnFileSystemConfig, frontend: Py2CpgOnFileSystem): Unit = { - frontend.run(config) - } + end cmdLineParser +end Frontend - def getCmdLineParser = cmdLineParser +object NewMain + extends X2CpgMain(cmdLineParser, new Py2CpgOnFileSystem())(new Py2CpgOnFileSystemConfig()): + def run(config: Py2CpgOnFileSystemConfig, frontend: Py2CpgOnFileSystem): Unit = + frontend.run(config) -} + def getCmdLineParser = cmdLineParser diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala index e1f03271..60d97517 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala @@ -7,315 +7,320 @@ import io.appthreat.x2cpg.utils.NodeBuilders import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies, nodes} import overflowdb.BatchedUpdate.DiffGraphBuilder -class NodeBuilder(diffGraph: DiffGraphBuilder) { - - private def addNodeToDiff[T <: nodes.NewNode](node: T): T = { - diffGraph.addNode(node) - node - } - - def callNode(code: String, name: String, dispatchType: String, lineAndColumn: LineAndColumn): nodes.NewCall = { - val callNode = nodes - .NewCall() - .code(code) - .name(name) - .methodFullName(if (dispatchType == DispatchTypes.STATIC_DISPATCH) name else 
Defines.DynamicCallUnknownFullName) - .dispatchType(dispatchType) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(callNode) - } - - def typeNode(name: String, fullName: String): nodes.NewType = { - val typeNode = nodes - .NewType() - .name(name) - .fullName(fullName) - .typeDeclFullName(fullName) - addNodeToDiff(typeNode) - } - - def typeDeclNode( - name: String, - fullName: String, - fileName: String, - inheritsFromFullNames: collection.Seq[String], - lineAndColumn: LineAndColumn - ): nodes.NewTypeDecl = { - val typeDeclNode = nodes - .NewTypeDecl() - .name(name) - .fullName(fullName) - .isExternal(false) - .filename(fileName) - .inheritsFromTypeFullName(inheritsFromFullNames) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(typeDeclNode) - } - - def typeRefNode(code: String, typeFullName: String, lineAndColumn: LineAndColumn): nodes.NewTypeRef = { - val typeRefNode = nodes - .NewTypeRef() - .code(code) - .typeFullName(typeFullName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(typeRefNode) - } - - def memberNode(name: String): nodes.NewMember = { - val memberNode = nodes - .NewMember() - .code(name) - .name(name) - .typeFullName(Constants.ANY) - addNodeToDiff(memberNode) - } - - def memberNode(name: String, lineAndColumn: LineAndColumn): nodes.NewMember = { - memberNode(name) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - } - - def memberNode(name: String, dynamicTypeHintFullName: String): nodes.NewMember = - memberNode(name).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) - def memberNode(name: String, dynamicTypeHintFullName: String, lineAndColumn: LineAndColumn): nodes.NewMember = - memberNode(name, lineAndColumn).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) - - def bindingNode(): nodes.NewBinding = { - val bindingNode = nodes - .NewBinding() - .name("") - 
.signature("") - - addNodeToDiff(bindingNode) - } - - def methodNode(name: String, fullName: String, fileName: String, lineAndColumn: LineAndColumn): nodes.NewMethod = { - val methodNode = nodes - .NewMethod() - .name(name) - .fullName(fullName) - .filename(fileName) - .isExternal(false) - .lineNumber(lineAndColumn.line) - .lineNumberEnd(lineAndColumn.endLine) - .columnNumber(lineAndColumn.column) - .columnNumberEnd(lineAndColumn.endColumn) - addNodeToDiff(methodNode) - } - - def methodRefNode(name: String, fullName: String, lineAndColumn: LineAndColumn): nodes.NewMethodRef = { - val methodRefNode = nodes - .NewMethodRef() - .code(name) - .methodFullName(fullName) - .typeFullName(fullName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(methodRefNode) - } - - def closureBindingNode(closureBindingId: String, closureOriginalName: String): nodes.NewClosureBinding = { - val closureBindingNode = nodes - .NewClosureBinding() - .closureBindingId(Some(closureBindingId)) - .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) - .closureOriginalName(Some(closureOriginalName)) - addNodeToDiff(closureBindingNode) - } - - def methodParameterNode( - name: String, - isVariadic: Boolean, - lineAndColumn: LineAndColumn, - index: Option[Int] = None, - typeHint: Option[ast.iexpr] = None - ): nodes.NewMethodParameterIn = { - val methodParameterNode = nodes - .NewMethodParameterIn() - .name(name) - .code(name) - .evaluationStrategy(EvaluationStrategies.BY_SHARING) - .typeFullName(extractTypesFromHint(typeHint).getOrElse(Constants.ANY)) - .isVariadic(isVariadic) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - index.foreach(idx => methodParameterNode.index(idx)) - addNodeToDiff(methodParameterNode) - } - - def extractTypesFromHint(typeHint: Option[ast.iexpr] = None): Option[String] = { - typeHint match { - case Some(hint) => - val nameSequence = hint match { - case n: ast.Name => Option(n.id) - // TODO: Definitely a place 
for follow up handling of generics - currently only take the polymorphic type - // without type args. To see the type arguments, see ast.Subscript.slice - case attr: ast.Attribute => - extractTypesFromHint(Some(attr.value)).map { x => x + "." + attr.attr } - case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => Option(n.value.asInstanceOf[ast.Name].id) - case n: ast.Constant if n.value.isInstanceOf[ast.StringConstant] => - Option(n.value.asInstanceOf[ast.StringConstant].value) - case _ => None - } - nameSequence.map { typeName => - if (allBuiltinClasses.contains(typeName)) s"$builtinPrefix$typeName" - else if (typingClassesV3.contains(typeName)) s"$typingPrefix$typeName" - else typeName - } - case _ => None - } - } - - def methodReturnNode( - staticTypeHint: Option[String], - dynamicTypeHintFullName: Option[String], - lineAndColumn: LineAndColumn - ): nodes.NewMethodReturn = { - val methodReturnNode = NodeBuilders - .newMethodReturnNode( - staticTypeHint.getOrElse(Constants.ANY), - dynamicTypeHintFullName, - Some(lineAndColumn.line), - Some(lineAndColumn.column) - ) - .evaluationStrategy(EvaluationStrategies.BY_SHARING) - - addNodeToDiff(methodReturnNode) - } - - def returnNode(code: String, lineAndColumn: LineAndColumn): nodes.NewReturn = { - val returnNode = nodes - .NewReturn() - .code(code) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - - addNodeToDiff(returnNode) - } - - def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = { - val identifierNode = nodes - .NewIdentifier() - .code(name) - .name(name) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(identifierNode) - } - - def fieldIdentifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewFieldIdentifier = { - val fieldIdentifierNode = nodes - .NewFieldIdentifier() - .code(name) - .canonicalName(name) - .lineNumber(lineAndColumn.line) - 
.columnNumber(lineAndColumn.column) - addNodeToDiff(fieldIdentifierNode) - } - - def numberLiteralNode(number: Int, lineAndColumn: LineAndColumn): nodes.NewLiteral = { - numberLiteralNode(number.toString, lineAndColumn) - } - - def numberLiteralNode(number: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = { - val literalNode = nodes - .NewLiteral() - .code(number) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(literalNode) - } - - def stringLiteralNode(string: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = { - val literalNode = nodes - .NewLiteral() - .code(string) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(literalNode) - } - - def blockNode(code: String, lineAndColumn: LineAndColumn): nodes.NewBlock = { - val blockNode = nodes - .NewBlock() - .code(code) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(blockNode) - } - - def controlStructureNode( - code: String, - controlStructureName: String, - lineAndColumn: LineAndColumn - ): nodes.NewControlStructure = { - val controlStructureNode = nodes - .NewControlStructure() - .code(code) - .controlStructureType(controlStructureName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(controlStructureNode) - } - - def localNode(name: String, closureBindingId: Option[String] = None): nodes.NewLocal = { - val localNode = nodes - .NewLocal() - .code(name) - .name(name) - .closureBindingId(closureBindingId) - .typeFullName(Constants.ANY) - addNodeToDiff(localNode) - } - - def fileNode(fileName: String): nodes.NewFile = { - val fileNode = nodes - .NewFile() - .name(fileName) - addNodeToDiff(fileNode) - } - - def namespaceBlockNode(name: String, fullName: String, fileName: String): nodes.NewNamespaceBlock = { - val namespaceBlockNode = nodes - 
.NewNamespaceBlock() - .name(name) - .fullName(fullName) - .filename(fileName) - addNodeToDiff(namespaceBlockNode) - } - - def modifierNode(modifierType: String): nodes.NewModifier = { - val modifierNode = nodes - .NewModifier() - .modifierType(modifierType) - addNodeToDiff(modifierNode) - } - - def metaNode(language: String, version: String): nodes.NewMetaData = { - val metaNode = nodes - .NewMetaData() - .language(language) - .version(version) - addNodeToDiff(metaNode) - } - - def unknownNode(code: String, parserTypeName: String, lineAndColumn: LineAndColumn): nodes.NewUnknown = { - val unknownNode = nodes - .NewUnknown() - .code(code) - .parserTypeName(parserTypeName) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(unknownNode) - } -} +class NodeBuilder(diffGraph: DiffGraphBuilder): + + private def addNodeToDiff[T <: nodes.NewNode](node: T): T = + diffGraph.addNode(node) + node + + def callNode( + code: String, + name: String, + dispatchType: String, + lineAndColumn: LineAndColumn + ): nodes.NewCall = + val callNode = nodes + .NewCall() + .code(code) + .name(name) + .methodFullName(if dispatchType == DispatchTypes.STATIC_DISPATCH then name + else Defines.DynamicCallUnknownFullName) + .dispatchType(dispatchType) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(callNode) + + def typeNode(name: String, fullName: String): nodes.NewType = + val typeNode = nodes + .NewType() + .name(name) + .fullName(fullName) + .typeDeclFullName(fullName) + addNodeToDiff(typeNode) + + def typeDeclNode( + name: String, + fullName: String, + fileName: String, + inheritsFromFullNames: collection.Seq[String], + lineAndColumn: LineAndColumn + ): nodes.NewTypeDecl = + val typeDeclNode = nodes + .NewTypeDecl() + .name(name) + .fullName(fullName) + .isExternal(false) + .filename(fileName) + .inheritsFromTypeFullName(inheritsFromFullNames) + 
.lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(typeDeclNode) + + def typeRefNode( + code: String, + typeFullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewTypeRef = + val typeRefNode = nodes + .NewTypeRef() + .code(code) + .typeFullName(typeFullName) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(typeRefNode) + + def memberNode(name: String): nodes.NewMember = + val memberNode = nodes + .NewMember() + .code(name) + .name(name) + .typeFullName(Constants.ANY) + addNodeToDiff(memberNode) + + def memberNode(name: String, lineAndColumn: LineAndColumn): nodes.NewMember = + memberNode(name) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + + def memberNode(name: String, dynamicTypeHintFullName: String): nodes.NewMember = + memberNode(name).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) + def memberNode( + name: String, + dynamicTypeHintFullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMember = + memberNode(name, lineAndColumn).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) + + def bindingNode(): nodes.NewBinding = + val bindingNode = nodes + .NewBinding() + .name("") + .signature("") + + addNodeToDiff(bindingNode) + + def methodNode( + name: String, + fullName: String, + fileName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val methodNode = nodes + .NewMethod() + .name(name) + .fullName(fullName) + .filename(fileName) + .isExternal(false) + .lineNumber(lineAndColumn.line) + .lineNumberEnd(lineAndColumn.endLine) + .columnNumber(lineAndColumn.column) + .columnNumberEnd(lineAndColumn.endColumn) + addNodeToDiff(methodNode) + + def methodRefNode( + name: String, + fullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMethodRef = + val methodRefNode = nodes + .NewMethodRef() + .code(name) + .methodFullName(fullName) + .typeFullName(fullName) + .lineNumber(lineAndColumn.line) + 
.columnNumber(lineAndColumn.column) + addNodeToDiff(methodRefNode) + + def closureBindingNode( + closureBindingId: String, + closureOriginalName: String + ): nodes.NewClosureBinding = + val closureBindingNode = nodes + .NewClosureBinding() + .closureBindingId(Some(closureBindingId)) + .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) + .closureOriginalName(Some(closureOriginalName)) + addNodeToDiff(closureBindingNode) + + def methodParameterNode( + name: String, + isVariadic: Boolean, + lineAndColumn: LineAndColumn, + index: Option[Int] = None, + typeHint: Option[ast.iexpr] = None + ): nodes.NewMethodParameterIn = + val methodParameterNode = nodes + .NewMethodParameterIn() + .name(name) + .code(name) + .evaluationStrategy(EvaluationStrategies.BY_SHARING) + .typeFullName(extractTypesFromHint(typeHint).getOrElse(Constants.ANY)) + .isVariadic(isVariadic) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + index.foreach(idx => methodParameterNode.index(idx)) + addNodeToDiff(methodParameterNode) + + def extractTypesFromHint(typeHint: Option[ast.iexpr] = None): Option[String] = + typeHint match + case Some(hint) => + val nameSequence = hint match + case n: ast.Name => Option(n.id) + // TODO: Definitely a place for follow up handling of generics - currently only take the polymorphic type + // without type args. To see the type arguments, see ast.Subscript.slice + case attr: ast.Attribute => + extractTypesFromHint(Some(attr.value)).map { x => x + "." 
+ attr.attr } + case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => + Option(n.value.asInstanceOf[ast.Name].id) + case n: ast.Constant if n.value.isInstanceOf[ast.StringConstant] => + Option(n.value.asInstanceOf[ast.StringConstant].value) + case _ => None + nameSequence.map { typeName => + if allBuiltinClasses.contains(typeName) then s"$builtinPrefix$typeName" + else if typingClassesV3.contains(typeName) then s"$typingPrefix$typeName" + else typeName + } + case _ => None + + def methodReturnNode( + staticTypeHint: Option[String], + dynamicTypeHintFullName: Option[String], + lineAndColumn: LineAndColumn + ): nodes.NewMethodReturn = + val methodReturnNode = NodeBuilders + .newMethodReturnNode( + staticTypeHint.getOrElse(Constants.ANY), + dynamicTypeHintFullName, + Some(lineAndColumn.line), + Some(lineAndColumn.column) + ) + .evaluationStrategy(EvaluationStrategies.BY_SHARING) + + addNodeToDiff(methodReturnNode) + + def returnNode(code: String, lineAndColumn: LineAndColumn): nodes.NewReturn = + val returnNode = nodes + .NewReturn() + .code(code) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + + addNodeToDiff(returnNode) + + def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = + val identifierNode = nodes + .NewIdentifier() + .code(name) + .name(name) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(identifierNode) + + def fieldIdentifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewFieldIdentifier = + val fieldIdentifierNode = nodes + .NewFieldIdentifier() + .code(name) + .canonicalName(name) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(fieldIdentifierNode) + + def numberLiteralNode(number: Int, lineAndColumn: LineAndColumn): nodes.NewLiteral = + numberLiteralNode(number.toString, lineAndColumn) + + def numberLiteralNode(number: String, lineAndColumn: LineAndColumn): 
nodes.NewLiteral = + val literalNode = nodes + .NewLiteral() + .code(number) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(literalNode) + + def stringLiteralNode(string: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = + val literalNode = nodes + .NewLiteral() + .code(string) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(literalNode) + + def blockNode(code: String, lineAndColumn: LineAndColumn): nodes.NewBlock = + val blockNode = nodes + .NewBlock() + .code(code) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(blockNode) + + def controlStructureNode( + code: String, + controlStructureName: String, + lineAndColumn: LineAndColumn + ): nodes.NewControlStructure = + val controlStructureNode = nodes + .NewControlStructure() + .code(code) + .controlStructureType(controlStructureName) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(controlStructureNode) + + def localNode(name: String, closureBindingId: Option[String] = None): nodes.NewLocal = + val localNode = nodes + .NewLocal() + .code(name) + .name(name) + .closureBindingId(closureBindingId) + .typeFullName(Constants.ANY) + addNodeToDiff(localNode) + + def fileNode(fileName: String): nodes.NewFile = + val fileNode = nodes + .NewFile() + .name(fileName) + addNodeToDiff(fileNode) + + def namespaceBlockNode( + name: String, + fullName: String, + fileName: String + ): nodes.NewNamespaceBlock = + val namespaceBlockNode = nodes + .NewNamespaceBlock() + .name(name) + .fullName(fullName) + .filename(fileName) + addNodeToDiff(namespaceBlockNode) + + def modifierNode(modifierType: String): nodes.NewModifier = + val modifierNode = nodes + .NewModifier() + .modifierType(modifierType) + addNodeToDiff(modifierNode) + + def metaNode(language: String, version: String): 
nodes.NewMetaData = + val metaNode = nodes + .NewMetaData() + .language(language) + .version(version) + addNodeToDiff(metaNode) + + def unknownNode( + code: String, + parserTypeName: String, + lineAndColumn: LineAndColumn + ): nodes.NewUnknown = + val unknownNode = nodes + .NewUnknown() + .code(code) + .parserTypeName(parserTypeName) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(unknownNode) +end NodeBuilder diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala index c1ddd77b..e37bbd4a 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala @@ -2,8 +2,6 @@ package io.appthreat.pysrc2cpg import io.appthreat.pythonparser.ast -class NodeToCode(content: String) { - def getCode(node: ast.iattributes): String = { - content.substring(node.input_offset, node.end_input_offset) - } -} +class NodeToCode(content: String): + def getCode(node: ast.iattributes): String = + content.substring(node.input_offset, node.end_input_offset) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala index 936c4ac2..aea27100 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala @@ -6,15 +6,15 @@ import io.shiftleft.codepropertygraph.generated.Languages import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder -object Py2Cpg { - case class InputPair(content: String, relFileName: String) - type InputProvider = () => InputPair -} +object Py2Cpg: + case class InputPair(content: String, relFileName: String) + 
type InputProvider = () => InputPair /** Entry point for general cpg generation from python code. * * @param inputProviders - * Set of functions which provide InputPairs. The functions must be safe to call from different threads. + * Set of functions which provide InputPairs. The functions must be safe to call from different + * threads. * @param outputCpg * Empty target cpg which will be populated. * @param inputPath @@ -30,21 +30,33 @@ class Py2Cpg( inputPath: String, requirementsTxt: String = "requirements.txt", schemaValidationMode: ValidationMode -) { - private val diffGraph = new DiffGraphBuilder() - private val nodeBuilder = new NodeBuilder(diffGraph) - private val edgeBuilder = new EdgeBuilder(diffGraph) +): + private val diffGraph = new DiffGraphBuilder() + private val nodeBuilder = new NodeBuilder(diffGraph) + private val edgeBuilder = new EdgeBuilder(diffGraph) - def buildCpg(): Unit = { - nodeBuilder.metaNode(Languages.PYTHONSRC, version = "").root(inputPath + java.io.File.separator) - val globalNamespaceBlock = - nodeBuilder.namespaceBlockNode(Constants.GLOBAL_NAMESPACE, Constants.GLOBAL_NAMESPACE, "N/A") - nodeBuilder.typeNode(Constants.ANY, Constants.ANY) - val anyTypeDecl = nodeBuilder.typeDeclNode(Constants.ANY, Constants.ANY, "N/A", Nil, LineAndColumn(1, 1, 1, 1)) - edgeBuilder.astEdge(anyTypeDecl, globalNamespaceBlock, 0) - BatchedUpdate.applyDiff(outputCpg.graph, diffGraph) - new CodeToCpg(outputCpg, inputProviders, schemaValidationMode).createAndApply() - new ConfigFileCreationPass(outputCpg, requirementsTxt).createAndApply() - new DependenciesFromRequirementsTxtPass(outputCpg).createAndApply() - } -} + def buildCpg(): Unit = + nodeBuilder.metaNode(Languages.PYTHONSRC, version = "").root( + inputPath + java.io.File.separator + ) + val globalNamespaceBlock = + nodeBuilder.namespaceBlockNode( + Constants.GLOBAL_NAMESPACE, + Constants.GLOBAL_NAMESPACE, + "N/A" + ) + nodeBuilder.typeNode(Constants.ANY, Constants.ANY) + val anyTypeDecl = 
nodeBuilder.typeDeclNode( + Constants.ANY, + Constants.ANY, + "N/A", + Nil, + LineAndColumn(1, 1, 1, 1) + ) + edgeBuilder.astEdge(anyTypeDecl, globalNamespaceBlock, 0) + BatchedUpdate.applyDiff(outputCpg.graph, diffGraph) + new CodeToCpg(outputCpg, inputProviders, schemaValidationMode).createAndApply() + new ConfigFileCreationPass(outputCpg, requirementsTxt).createAndApply() + new DependenciesFromRequirementsTxtPass(outputCpg).createAndApply() + end buildCpg +end Py2Cpg diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala index 7470da85..a255b883 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala @@ -8,7 +8,7 @@ import org.slf4j.LoggerFactory import java.nio.file.* import scala.util.Try -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* case class Py2CpgOnFileSystemConfig( venvDir: Path = Paths.get(".venv"), @@ -17,89 +17,88 @@ case class Py2CpgOnFileSystemConfig( ignoreDirNames: Seq[String] = Nil, requirementsTxt: String = "requirements.txt" ) extends X2CpgConfig[Py2CpgOnFileSystemConfig] - with TypeRecoveryParserConfig[Py2CpgOnFileSystemConfig] { - def withVenvDir(venvDir: Path): Py2CpgOnFileSystemConfig = { - copy(venvDir = venvDir).withInheritedFields(this) - } - - def withIgnoreVenvDir(value: Boolean): Py2CpgOnFileSystemConfig = { - copy(ignoreVenvDir = value).withInheritedFields(this) - } - - def withIgnorePaths(value: Seq[Path]): Py2CpgOnFileSystemConfig = { - copy(ignorePaths = value).withInheritedFields(this) - } - - def withIgnoreDirNames(value: Seq[String]): Py2CpgOnFileSystemConfig = { - copy(ignoreDirNames = value).withInheritedFields(this) - } - - def withRequirementsTxt(text: String): Py2CpgOnFileSystemConfig = { - 
copy(requirementsTxt = text).withInheritedFields(this) - } -} - -class Py2CpgOnFileSystem extends X2CpgFrontend[Py2CpgOnFileSystemConfig] { - private val logger = LoggerFactory.getLogger(getClass) - - /** Entry point for files system based cpg generation from python code. - * @param config - * Configuration for cpg generation. - */ - override def createCpg(config: Py2CpgOnFileSystemConfig): Try[Cpg] = { - logConfiguration(config) - - X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => - val venvIgnorePath = - if (config.ignoreVenvDir) { - config.venvDir :: Nil - } else { - Nil + with TypeRecoveryParserConfig[Py2CpgOnFileSystemConfig]: + def withVenvDir(venvDir: Path): Py2CpgOnFileSystemConfig = + copy(venvDir = venvDir).withInheritedFields(this) + + def withIgnoreVenvDir(value: Boolean): Py2CpgOnFileSystemConfig = + copy(ignoreVenvDir = value).withInheritedFields(this) + + def withIgnorePaths(value: Seq[Path]): Py2CpgOnFileSystemConfig = + copy(ignorePaths = value).withInheritedFields(this) + + def withIgnoreDirNames(value: Seq[String]): Py2CpgOnFileSystemConfig = + copy(ignoreDirNames = value).withInheritedFields(this) + + def withRequirementsTxt(text: String): Py2CpgOnFileSystemConfig = + copy(requirementsTxt = text).withInheritedFields(this) +end Py2CpgOnFileSystemConfig + +class Py2CpgOnFileSystem extends X2CpgFrontend[Py2CpgOnFileSystemConfig]: + private val logger = LoggerFactory.getLogger(getClass) + + /** Entry point for files system based cpg generation from python code. + * @param config + * Configuration for cpg generation. 
+ */ + override def createCpg(config: Py2CpgOnFileSystemConfig): Try[Cpg] = + logConfiguration(config) + + X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => + val venvIgnorePath = + if config.ignoreVenvDir then + config.venvDir :: Nil + else + Nil + val inputPath = Path.of(config.inputPath) + val ignoreDirNamesSet = config.ignoreDirNames.toSet + val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => + inputPath.resolve(path) + } + + val inputFiles = SourceFiles + .determine(config.inputPath, Set(".py"), config) + .map(x => Path.of(x)) + .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } + .filter { file => + !absoluteIgnorePaths.exists(ignorePath => file.startsWith(ignorePath)) + } + + val inputProviders = inputFiles.map { inputFile => () => + val content = IOUtils.readLinesInFile(inputFile).mkString("\n") + Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) + } + val py2Cpg = new Py2Cpg( + inputProviders, + cpg, + config.inputPath, + config.requirementsTxt, + config.schemaValidation + ) + py2Cpg.buildCpg() } - val inputPath = Path.of(config.inputPath) - val ignoreDirNamesSet = config.ignoreDirNames.toSet - val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => - inputPath.resolve(path) - } - - val inputFiles = SourceFiles - .determine(config.inputPath, Set(".py"), config) - .map(x => Path.of(x)) - .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } - .filter { file => - !absoluteIgnorePaths.exists(ignorePath => file.startsWith(ignorePath)) - } - - val inputProviders = inputFiles.map { inputFile => () => - { - val content = IOUtils.readLinesInFile(inputFile).mkString("\n") - Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) - } - } - val py2Cpg = new Py2Cpg(inputProviders, cpg, config.inputPath, config.requirementsTxt, config.schemaValidation) - py2Cpg.buildCpg() - } - } - - private def filterIgnoreDirNames(file: Path, 
inputPath: Path, ignoreDirNamesSet: Set[String]): Boolean = { - var parts = inputPath.relativize(file).iterator().asScala.toList - - if (!Files.isDirectory(file)) { - // we're only interested in the directories - drop the file part - parts = parts.dropRight(1) - } - - val aPartIsInIgnoreSet = parts.exists(part => ignoreDirNamesSet.contains(part.toString)) - !aPartIsInIgnoreSet - } - - private def logConfiguration(config: Py2CpgOnFileSystemConfig): Unit = { - logger.debug(s"Output file: ${config.outputPath}") - logger.debug(s"Input directory: ${config.inputPath}") - logger.debug(s"Venv directory: ${config.venvDir}") - logger.debug(s"IgnoreVenvDir: ${config.ignoreVenvDir}") - logger.debug(s"IgnorePaths: ${config.ignorePaths.mkString(", ")}") - logger.debug(s"IgnoreDirNames: ${config.ignoreDirNames.mkString(", ")}") - logger.debug(s"No dummy types: ${config.disableDummyTypes}") - } -} + end createCpg + + private def filterIgnoreDirNames( + file: Path, + inputPath: Path, + ignoreDirNamesSet: Set[String] + ): Boolean = + var parts = inputPath.relativize(file).iterator().asScala.toList + + if !Files.isDirectory(file) then + // we're only interested in the directories - drop the file part + parts = parts.dropRight(1) + + val aPartIsInIgnoreSet = parts.exists(part => ignoreDirNamesSet.contains(part.toString)) + !aPartIsInIgnoreSet + + private def logConfiguration(config: Py2CpgOnFileSystemConfig): Unit = + logger.debug(s"Output file: ${config.outputPath}") + logger.debug(s"Input directory: ${config.inputPath}") + logger.debug(s"Venv directory: ${config.venvDir}") + logger.debug(s"IgnoreVenvDir: ${config.ignoreVenvDir}") + logger.debug(s"IgnorePaths: ${config.ignorePaths.mkString(", ")}") + logger.debug(s"IgnoreDirNames: ${config.ignoreDirNames.mkString(", ")}") + logger.debug(s"No dummy types: ${config.disableDummyTypes}") +end Py2CpgOnFileSystem diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala 
b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala index 32a19aea..a69b7f6d 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala @@ -1,6 +1,12 @@ package io.appthreat.pysrc2cpg -import io.appthreat.pysrc2cpg.memop.{AstNodeToMemoryOperationMap, Del, Load, MemoryOperationCalculator, Store} +import io.appthreat.pysrc2cpg.memop.{ + AstNodeToMemoryOperationMap, + Del, + Load, + MemoryOperationCalculator, + Store +} import PythonAstVisitor.{builtinPrefix, metaClassSuffix} import io.appthreat.pythonparser.ast import io.appthreat.pysrc2cpg.memop.* @@ -11,2255 +17,2401 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder import scala.collection.mutable -object MethodParameters { - def empty(): MethodParameters = { - new MethodParameters(0, Nil) - } -} -case class MethodParameters(posStartIndex: Int, positionalParams: Iterable[nodes.NewMethodParameterIn]) +object MethodParameters: + def empty(): MethodParameters = + new MethodParameters(0, Nil) +case class MethodParameters( + posStartIndex: Int, + positionalParams: Iterable[nodes.NewMethodParameterIn] +) sealed trait PythonVersion object PythonV2 extends PythonVersion object PythonV3 extends PythonVersion object PythonV2AndV3 extends PythonVersion -class PythonAstVisitor(relFileName: String, protected val nodeToCode: NodeToCode, version: PythonVersion)(implicit - withSchemaValidation: ValidationMode -) extends PythonAstVisitorHelpers { +class PythonAstVisitor( + relFileName: String, + protected val nodeToCode: NodeToCode, + version: PythonVersion +)(implicit withSchemaValidation: ValidationMode) extends PythonAstVisitorHelpers: - private val diffGraph = new DiffGraphBuilder() - protected val nodeBuilder = new NodeBuilder(diffGraph) - protected val edgeBuilder = new EdgeBuilder(diffGraph) + private val diffGraph = new DiffGraphBuilder() + 
protected val nodeBuilder = new NodeBuilder(diffGraph) + protected val edgeBuilder = new EdgeBuilder(diffGraph) - protected val contextStack = new ContextStack() + protected val contextStack = new ContextStack() - private var memOpMap: AstNodeToMemoryOperationMap = _ + private var memOpMap: AstNodeToMemoryOperationMap = _ - private val members = mutable.Map.empty[NewTypeDecl, List[String]] + private val members = mutable.Map.empty[NewTypeDecl, List[String]] - // As key only ast.FunctionDef and ast.AsyncFunctionDef are used but there - // is no more specific type than ast.istmt. - private val functionDefToMethod = mutable.Map.empty[ast.istmt, nodes.NewMethod] + // As key only ast.FunctionDef and ast.AsyncFunctionDef are used but there + // is no more specific type than ast.istmt. + private val functionDefToMethod = mutable.Map.empty[ast.istmt, nodes.NewMethod] - def getDiffGraph: DiffGraphBuilder = { - diffGraph - } + def getDiffGraph: DiffGraphBuilder = + diffGraph - private def createIdentifierLinks(): Unit = { - contextStack.createIdentifierLinks( - nodeBuilder.localNode, - nodeBuilder.closureBindingNode, - edgeBuilder.astEdge, - edgeBuilder.refEdge, - edgeBuilder.captureEdge - ) - } - - def convert(astNode: ast.iast): NewNode = { - astNode match { - case module: ast.Module => convert(module) - } - } - - def convert(mod: ast.imod): NewNode = { - mod match { - case node: ast.Module => convert(node) - } - } - - // Entry method for the visitor. 
- def convert(module: ast.Module): NewNode = { - val memOpCalculator = new MemoryOperationCalculator() - module.accept(memOpCalculator) - memOpMap = memOpCalculator.astNodeToMemOp - - val fileNode = nodeBuilder.fileNode(relFileName) - val namespaceBlockNode = - nodeBuilder.namespaceBlockNode( - Constants.GLOBAL_NAMESPACE, - relFileName + ":" + Constants.GLOBAL_NAMESPACE, - relFileName - ) - edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1) - contextStack.setFileNamespaceBlock(namespaceBlockNode) - - val methodFullName = calculateFullNameFromContext("") - - val firstLineAndCol = module.stmts.headOption.map(lineAndColOf) - val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf) - val line = firstLineAndCol.map(_.line).getOrElse(1) - val column = firstLineAndCol.map(_.column).getOrElse(1) - val endLine = lastLineAndCol.map(_.endLine).getOrElse(1) - val endColumn = lastLineAndCol.map(_.endColumn).getOrElse(1) - - val moduleMethodNode = - createMethod( - "", - methodFullName, - Some(""), - parameterProvider = () => MethodParameters.empty(), - bodyProvider = () => createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = None, - LineAndColumn(line, column, endLine, endColumn) - ) - - createIdentifierLinks() - - moduleMethodNode - } - - // Create assignments of type references to all builtin identifiers if the identifier appears - // at least once in the module. We filter in order to not generate the complete list of - // assignment in each module. - // That logic still generates assignments in cases where we would not need to, but figuring - // that out would mean we need to wait until all identifiers are linked which is than too - // late to create new identifiers and still use the same link mechanism. We would need - // to rearrange quite some code to accomplish that. So we leave that as an optional TODO. 
- // Note that namesUsedInModule is only calculated from ast.Name nodes! So e.g. new names - // artificially created during lowering are not in that collection which is fine for now. - private def createBuiltinIdentifiers(namesUsedInModule: collection.Set[String]): Iterable[nodes.NewNode] = { - val result = mutable.ArrayBuffer.empty[nodes.NewNode] - val lineAndColumn = LineAndColumn(1, 1, 1, 1) - - val builtinFunctions = mutable.ArrayBuffer.empty[String] - val builtinClasses = mutable.ArrayBuffer.empty[String] - - if (version == PythonV3 || version == PythonV2AndV3) { - builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV3) - builtinClasses.appendAll(PythonAstVisitor.builtinClassesV3) - } - if (version == PythonV2 || version == PythonV2AndV3) { - builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV2) - builtinClasses.appendAll(PythonAstVisitor.builtinClassesV2) - } - - builtinFunctions.distinct.foreach { builtinObjectName => - if (namesUsedInModule.contains(builtinObjectName)) { - val assignmentNode = createAssignment( - createIdentifierNode(builtinObjectName, Store, lineAndColumn), - nodeBuilder - .typeRefNode("__builtins__." + builtinObjectName, builtinPrefix + builtinObjectName, lineAndColumn), - lineAndColumn + private def createIdentifierLinks(): Unit = + contextStack.createIdentifierLinks( + nodeBuilder.localNode, + nodeBuilder.closureBindingNode, + edgeBuilder.astEdge, + edgeBuilder.refEdge, + edgeBuilder.captureEdge ) - result.append(assignmentNode) - } - } - - builtinClasses.distinct.foreach { builtinObjectName => - if (namesUsedInModule.contains(builtinObjectName)) { - val assignmentNode = createAssignment( - createIdentifierNode(builtinObjectName, Store, lineAndColumn), - nodeBuilder.typeRefNode( - "__builtins__." 
+ builtinObjectName, - builtinPrefix + builtinObjectName + metaClassSuffix, - lineAndColumn + def convert(astNode: ast.iast): NewNode = + astNode match + case module: ast.Module => convert(module) + + def convert(mod: ast.imod): NewNode = + mod match + case node: ast.Module => convert(node) + + // Entry method for the visitor. + def convert(module: ast.Module): NewNode = + val memOpCalculator = new MemoryOperationCalculator() + module.accept(memOpCalculator) + memOpMap = memOpCalculator.astNodeToMemOp + + val fileNode = nodeBuilder.fileNode(relFileName) + val namespaceBlockNode = + nodeBuilder.namespaceBlockNode( + Constants.GLOBAL_NAMESPACE, + relFileName + ":" + Constants.GLOBAL_NAMESPACE, + relFileName + ) + edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1) + contextStack.setFileNamespaceBlock(namespaceBlockNode) + + val methodFullName = calculateFullNameFromContext("") + + val firstLineAndCol = module.stmts.headOption.map(lineAndColOf) + val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf) + val line = firstLineAndCol.map(_.line).getOrElse(1) + val column = firstLineAndCol.map(_.column).getOrElse(1) + val endLine = lastLineAndCol.map(_.endLine).getOrElse(1) + val endColumn = lastLineAndCol.map(_.endColumn).getOrElse(1) + + val moduleMethodNode = + createMethod( + "", + methodFullName, + Some(""), + parameterProvider = () => MethodParameters.empty(), + bodyProvider = () => + createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = None, + LineAndColumn(line, column, endLine, endColumn) + ) + + createIdentifierLinks() + + moduleMethodNode + end convert + + // Create assignments of type references to all builtin identifiers if the identifier appears + // at least once in the module. We filter in order to not generate the complete list of + // assignment in each module. 
+ // That logic still generates assignments in cases where we would not need to, but figuring + // that out would mean we need to wait until all identifiers are linked which is than too + // late to create new identifiers and still use the same link mechanism. We would need + // to rearrange quite some code to accomplish that. So we leave that as an optional TODO. + // Note that namesUsedInModule is only calculated from ast.Name nodes! So e.g. new names + // artificially created during lowering are not in that collection which is fine for now. + private def createBuiltinIdentifiers(namesUsedInModule: collection.Set[String]) + : Iterable[nodes.NewNode] = + val result = mutable.ArrayBuffer.empty[nodes.NewNode] + val lineAndColumn = LineAndColumn(1, 1, 1, 1) + + val builtinFunctions = mutable.ArrayBuffer.empty[String] + val builtinClasses = mutable.ArrayBuffer.empty[String] + + if version == PythonV3 || version == PythonV2AndV3 then + builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV3) + builtinClasses.appendAll(PythonAstVisitor.builtinClassesV3) + if version == PythonV2 || version == PythonV2AndV3 then + builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV2) + builtinClasses.appendAll(PythonAstVisitor.builtinClassesV2) + + builtinFunctions.distinct.foreach { builtinObjectName => + if namesUsedInModule.contains(builtinObjectName) then + val assignmentNode = createAssignment( + createIdentifierNode(builtinObjectName, Store, lineAndColumn), + nodeBuilder + .typeRefNode( + "__builtins__." + builtinObjectName, + builtinPrefix + builtinObjectName, + lineAndColumn + ), + lineAndColumn + ) + + result.append(assignmentNode) + } + + builtinClasses.distinct.foreach { builtinObjectName => + if namesUsedInModule.contains(builtinObjectName) then + val assignmentNode = createAssignment( + createIdentifierNode(builtinObjectName, Store, lineAndColumn), + nodeBuilder.typeRefNode( + "__builtins__." 
+ builtinObjectName, + builtinPrefix + builtinObjectName + metaClassSuffix, + lineAndColumn + ), + lineAndColumn + ) + + result.append(assignmentNode) + } + + result + end createBuiltinIdentifiers + + private def unhandled(node: ast.iast with ast.iattributes): NewNode = + val unhandledAsUnknown = true + if unhandledAsUnknown then + nodeBuilder.unknownNode(node.toString, node.getClass.getName, lineAndColOf(node)) + else + throw new NotImplementedError() + + def convert(stmt: ast.istmt): NewNode = + stmt match + case node: ast.FunctionDef => convert(node) + case node: ast.AsyncFunctionDef => convert(node) + case node: ast.ClassDef => convert(node) + case node: ast.Return => convert(node) + case node: ast.Delete => convert(node) + case node: ast.Assign => convert(node) + case node: ast.AnnAssign => convert(node) + case node: ast.AugAssign => convert(node) + case node: ast.For => convert(node) + case node: ast.AsyncFor => convert(node) + case node: ast.While => convert(node) + case node: ast.If => convert(node) + case node: ast.With => convert(node) + case node: ast.AsyncWith => convert(node) + case node: ast.Match => convert(node) + case node: ast.Raise => convert(node) + case node: ast.Try => convert(node) + case node: ast.Assert => convert(node) + case node: ast.Import => convert(node) + case node: ast.ImportFrom => convert(node) + case node: ast.Global => convert(node) + case node: ast.Nonlocal => convert(node) + case node: ast.Expr => convert(node) + case node: ast.Pass => convert(node) + case node: ast.Break => convert(node) + case node: ast.Continue => convert(node) + case node: ast.RaiseP2 => unhandled(node) + case node: ast.ErrorStatement => convert(node) + + def convert(functionDef: ast.FunctionDef): NewNode = + val methodIdentifierNode = + createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) + val (methodNode, methodRefNode) = createMethodAndMethodRef( + functionDef.name, + Some(functionDef.name), + createParameterProcessingFunction( + 
functionDef.args, + isStaticMethod(functionDef.decorator_list) ), - lineAndColumn + () => functionDef.body.map(convert), + functionDef.returns, + isAsync = false, + lineAndColOf(functionDef) + ) + functionDefToMethod.put(functionDef, methodNode) + + val wrappedMethodRefNode = + wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + + createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) + end convert + + /* + * For a decorated function like: + * @f1(arg) + * @f2 + * def func(): pass + * + * The lowering is: + * func = f1(arg)(f2(func)) + * + * This function takes a method ref, wrappes it in the decorator calls and returns the resulting expression. + * In the example case this is: + * f1(arg)(f2(func)) + */ + def wrapMethodRefWithDecorators( + methodRefNode: nodes.NewNode, + decoratorList: Iterable[ast.iexpr] + ): nodes.NewNode = + decoratorList.foldRight(methodRefNode)( + (decorator: ast.iexpr, wrappedMethodRef: nodes.NewNode) => + createCall( + convert(decorator), + "", + lineAndColOf(decorator), + wrappedMethodRef :: Nil, + Nil + ) ) - result.append(assignmentNode) - } - } - - result - } - - private def unhandled(node: ast.iast with ast.iattributes): NewNode = { - val unhandledAsUnknown = true - if (unhandledAsUnknown) { - nodeBuilder.unknownNode(node.toString, node.getClass.getName, lineAndColOf(node)) - } else { - throw new NotImplementedError() - } - } - - def convert(stmt: ast.istmt): NewNode = { - stmt match { - case node: ast.FunctionDef => convert(node) - case node: ast.AsyncFunctionDef => convert(node) - case node: ast.ClassDef => convert(node) - case node: ast.Return => convert(node) - case node: ast.Delete => convert(node) - case node: ast.Assign => convert(node) - case node: ast.AnnAssign => convert(node) - case node: ast.AugAssign => convert(node) - case node: ast.For => convert(node) - case node: ast.AsyncFor => convert(node) - case node: ast.While => convert(node) - case node: ast.If => 
convert(node) - case node: ast.With => convert(node) - case node: ast.AsyncWith => convert(node) - case node: ast.Match => convert(node) - case node: ast.Raise => convert(node) - case node: ast.Try => convert(node) - case node: ast.Assert => convert(node) - case node: ast.Import => convert(node) - case node: ast.ImportFrom => convert(node) - case node: ast.Global => convert(node) - case node: ast.Nonlocal => convert(node) - case node: ast.Expr => convert(node) - case node: ast.Pass => convert(node) - case node: ast.Break => convert(node) - case node: ast.Continue => convert(node) - case node: ast.RaiseP2 => unhandled(node) - case node: ast.ErrorStatement => convert(node) - } - } - - def convert(functionDef: ast.FunctionDef): NewNode = { - val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) - val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction(functionDef.args, isStaticMethod(functionDef.decorator_list)), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = false, - lineAndColOf(functionDef) - ) - functionDefToMethod.put(functionDef, methodNode) - - val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) - - createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) - } - - /* - * For a decorated function like: - * @f1(arg) - * @f2 - * def func(): pass - * - * The lowering is: - * func = f1(arg)(f2(func)) - * - * This function takes a method ref, wrappes it in the decorator calls and returns the resulting expression. 
- * In the example case this is: - * f1(arg)(f2(func)) - */ - def wrapMethodRefWithDecorators(methodRefNode: nodes.NewNode, decoratorList: Iterable[ast.iexpr]): nodes.NewNode = { - decoratorList.foldRight(methodRefNode)((decorator: ast.iexpr, wrappedMethodRef: nodes.NewNode) => - createCall(convert(decorator), "", lineAndColOf(decorator), wrappedMethodRef :: Nil, Nil) - ) - } - - def convert(functionDef: ast.AsyncFunctionDef): NewNode = { - val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) - val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction(functionDef.args, isStaticMethod(functionDef.decorator_list)), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = true, - lineAndColOf(functionDef) - ) - functionDefToMethod.put(functionDef, methodNode) - - val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) - - createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) - } - - private def isStaticMethod(decoratorList: Iterable[ast.iexpr]): Boolean = { - decoratorList.exists { - case name: ast.Name if name.id == "staticmethod" => true - case _ => false - } - } - - private def isClassMethod(decoratorList: Iterable[ast.iexpr]): Boolean = { - decoratorList.exists { - case name: ast.Name if name.id == "classmethod" => true - case _ => false - } - } - - private def createParameterProcessingFunction( - parameters: ast.Arguments, - isStatic: Boolean - ): () => MethodParameters = { - val startIndex = - if (contextStack.isClassContext && !isStatic) - 0 - else - 1 - - () => new MethodParameters(startIndex, convert(parameters, startIndex)) - } - - // TODO handle returns - private def createMethodAndMethodRef( - methodName: String, - scopeName: Option[String], - parameterProvider: () => MethodParameters, - bodyProvider: () => Iterable[nodes.NewNode], 
- returns: Option[ast.iexpr], - isAsync: Boolean, - lineAndColumn: LineAndColumn - ): (nodes.NewMethod, nodes.NewMethodRef) = { - val methodFullName = calculateFullNameFromContext(methodName) - - val methodRefNode = - nodeBuilder.methodRefNode("def " + methodName + "(...)", methodFullName, lineAndColumn) - - val methodNode = - createMethod( - methodName, - methodFullName, - scopeName, - parameterProvider, - bodyProvider, - returns, - isAsync = true, - Some(methodRefNode), - returnTypeHint = None, - lineAndColumn - ) - - (methodNode, methodRefNode) - } - - // It is important that the nodes returned by all provider function are created - // during the function invocation and not in advance. Because only - // than the context information is correct. - private def createMethod( - name: String, - fullName: String, - scopeName: Option[String], - parameterProvider: () => MethodParameters, - bodyProvider: () => Iterable[nodes.NewNode], - returns: Option[ast.iexpr], - isAsync: Boolean, - methodRefNode: Option[nodes.NewMethodRef], - returnTypeHint: Option[String], - lineAndColumn: LineAndColumn - ): nodes.NewMethod = { - val methodNode = nodeBuilder.methodNode(name, fullName, relFileName, lineAndColumn) - edgeBuilder.astEdge(methodNode, contextStack.astParent, contextStack.order.getAndInc) - - val blockNode = nodeBuilder.blockNode("", lineAndColumn) - edgeBuilder.astEdge(blockNode, methodNode, 1) - - contextStack.pushMethod(scopeName, methodNode, blockNode, methodRefNode) - - val virtualModifierNode = nodeBuilder.modifierNode(ModifierTypes.VIRTUAL) - edgeBuilder.astEdge(virtualModifierNode, methodNode, 0) - - val methodParameter = parameterProvider() - val parameterOrder = new AutoIncIndex(methodParameter.posStartIndex) - - methodParameter.positionalParams.foreach { parameterNode => - contextStack.addParameter(parameterNode) - edgeBuilder.astEdge(parameterNode, methodNode, parameterOrder.getAndInc) - } - - val methodReturnNode = - 
nodeBuilder.methodReturnNode(nodeBuilder.extractTypesFromHint(returns), returnTypeHint, lineAndColumn) - edgeBuilder.astEdge(methodReturnNode, methodNode, 2) - - val bodyOrder = new AutoIncIndex(1) - bodyProvider().foreach { bodyStmt => - edgeBuilder.astEdge(bodyStmt, blockNode, bodyOrder.getAndInc) - } - - // For every method we create a corresponding TYPE and TYPE_DECL and - // a binding for the method into TYPE_DECL. - val typeNode = nodeBuilder.typeNode(name, fullName) - val typeDeclNode = nodeBuilder.typeDeclNode(name, fullName, relFileName, Seq(Constants.ANY), lineAndColumn) - - // For every method that is a module, the local variables can be imported by other modules. This behaviour is - // much like fields so they are to be linked as fields to this method type - if (name == "") contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) - - contextStack.pop() - edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) - createBinding(methodNode, typeDeclNode) - - methodNode - } - - // For a classDef we do: - // 1. Create a metaType, metaTypeDecl and metaTypeRef. - // 2. Create a function containing the code of the classDef body. - // 3. Create a block which contains a call to the body function - // and an assignment of the metaTypeRef to an identifier with the class name. - // 4. Create type and typeDecl for the instance class. - // 5. 
Create and link members in metaTypeDecl and instanceTypeDecl - def convert(classDef: ast.ClassDef): NewNode = { - // Create type for the meta class object - val metaTypeDeclName = classDef.name + metaClassSuffix - val metaTypeDeclFullName = calculateFullNameFromContext(metaTypeDeclName) - - val metaTypeNode = nodeBuilder.typeNode(metaTypeDeclName, metaTypeDeclFullName) - val metaTypeDeclNode = - nodeBuilder.typeDeclNode( - metaTypeDeclName, - metaTypeDeclFullName, - relFileName, - Seq(Constants.ANY), - lineAndColOf(classDef) - ) - edgeBuilder.astEdge(metaTypeDeclNode, contextStack.astParent, contextStack.order.getAndInc) - - // Create type for class instances - val instanceTypeDeclName = classDef.name - val instanceTypeDeclFullName = calculateFullNameFromContext(instanceTypeDeclName) - - // TODO for now we just take the code of the base expression and pretend they are full names, converting special - // nodes as we go. - def handleInheritance(fs: List[ast.iexpr]): List[String] = fs match { - case (x: ast.Call) :: xs => - val node = convert(x) - val parent = contextStack.astParent - val tmpVar = createIdentifierNode(getUnusedName(), Store, lineAndColOf(x)) - val assignment = createAssignment(tmpVar, node, lineAndColOf(x)) - diffGraph.addEdge(parent, assignment, EdgeTypes.AST) - tmpVar.name +: handleInheritance(xs) - case x :: xs => - nodeToCode.getCode(x) +: handleInheritance(xs) - case Nil => Nil - } - - val inheritsFrom = handleInheritance(classDef.bases.toList) - - val instanceType = nodeBuilder.typeNode(instanceTypeDeclName, instanceTypeDeclFullName) - val instanceTypeDecl = - nodeBuilder.typeDeclNode( - instanceTypeDeclName, - instanceTypeDeclFullName, - relFileName, - inheritsFrom, - lineAndColOf(classDef) - ) - edgeBuilder.astEdge(instanceTypeDecl, contextStack.astParent, contextStack.order.getAndInc) - - // Create function which contains the code defining the class - contextStack.pushClass(Some(classDef.name), instanceTypeDecl) - val classBodyFunctionName = 
"" - val (_, methodRefNode) = createMethodAndMethodRef( - classBodyFunctionName, - scopeName = None, - parameterProvider = () => MethodParameters.empty(), - bodyProvider = () => classDef.body.map(convert), - None, - isAsync = false, - lineAndColOf(classDef) - ) + def convert(functionDef: ast.AsyncFunctionDef): NewNode = + val methodIdentifierNode = + createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) + val (methodNode, methodRefNode) = createMethodAndMethodRef( + functionDef.name, + Some(functionDef.name), + createParameterProcessingFunction( + functionDef.args, + isStaticMethod(functionDef.decorator_list) + ), + () => functionDef.body.map(convert), + functionDef.returns, + isAsync = true, + lineAndColOf(functionDef) + ) + functionDefToMethod.put(functionDef, methodNode) + + val wrappedMethodRefNode = + wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + + createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) + end convert + + private def isStaticMethod(decoratorList: Iterable[ast.iexpr]): Boolean = + decoratorList.exists { + case name: ast.Name if name.id == "staticmethod" => true + case _ => false + } + + private def isClassMethod(decoratorList: Iterable[ast.iexpr]): Boolean = + decoratorList.exists { + case name: ast.Name if name.id == "classmethod" => true + case _ => false + } + + private def createParameterProcessingFunction( + parameters: ast.Arguments, + isStatic: Boolean + ): () => MethodParameters = + val startIndex = + if contextStack.isClassContext && !isStatic then + 0 + else + 1 + + () => new MethodParameters(startIndex, convert(parameters, startIndex)) + + // TODO handle returns + private def createMethodAndMethodRef( + methodName: String, + scopeName: Option[String], + parameterProvider: () => MethodParameters, + bodyProvider: () => Iterable[nodes.NewNode], + returns: Option[ast.iexpr], + isAsync: Boolean, + lineAndColumn: LineAndColumn + ): (nodes.NewMethod, 
nodes.NewMethodRef) = + val methodFullName = calculateFullNameFromContext(methodName) + + val methodRefNode = + nodeBuilder.methodRefNode("def " + methodName + "(...)", methodFullName, lineAndColumn) + + val methodNode = + createMethod( + methodName, + methodFullName, + scopeName, + parameterProvider, + bodyProvider, + returns, + isAsync = true, + Some(methodRefNode), + returnTypeHint = None, + lineAndColumn + ) + + (methodNode, methodRefNode) + end createMethodAndMethodRef + + // It is important that the nodes returned by all provider function are created + // during the function invocation and not in advance. Because only + // than the context information is correct. + private def createMethod( + name: String, + fullName: String, + scopeName: Option[String], + parameterProvider: () => MethodParameters, + bodyProvider: () => Iterable[nodes.NewNode], + returns: Option[ast.iexpr], + isAsync: Boolean, + methodRefNode: Option[nodes.NewMethodRef], + returnTypeHint: Option[String], + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val methodNode = nodeBuilder.methodNode(name, fullName, relFileName, lineAndColumn) + edgeBuilder.astEdge(methodNode, contextStack.astParent, contextStack.order.getAndInc) + + val blockNode = nodeBuilder.blockNode("", lineAndColumn) + edgeBuilder.astEdge(blockNode, methodNode, 1) + + contextStack.pushMethod(scopeName, methodNode, blockNode, methodRefNode) + + val virtualModifierNode = nodeBuilder.modifierNode(ModifierTypes.VIRTUAL) + edgeBuilder.astEdge(virtualModifierNode, methodNode, 0) + + val methodParameter = parameterProvider() + val parameterOrder = new AutoIncIndex(methodParameter.posStartIndex) + + methodParameter.positionalParams.foreach { parameterNode => + contextStack.addParameter(parameterNode) + edgeBuilder.astEdge(parameterNode, methodNode, parameterOrder.getAndInc) + } - contextStack.pop() - - contextStack.pushClass(Some(classDef.name), metaTypeDeclNode) - - // Create meta class call handling method and bind it to meta 
class type. - val functions = classDef.body.collect { case func: ast.FunctionDef => func } - - // __init__ method has to be in functions because "async def __init__" is invalid. - val initFunctionOption = functions.find(_.name == "__init__") - - val initParameters = initFunctionOption.map(_.args).getOrElse { - // Create arguments of a default __init__ function. - ast.Arguments( - posonlyargs = mutable.Seq.empty[ast.Arg], - args = mutable.Seq(ast.Arg("self", None, None, classDef.attributeProvider)), - vararg = None, - kwonlyargs = mutable.Seq.empty[ast.Arg], - kw_defaults = mutable.Seq.empty[Option[ast.iexpr]], - kw_arg = None, - defaults = mutable.Seq.empty[ast.iexpr] - ) - } - - val metaClassCallHandlerMethod = - createMetaClassCallHandlerMethod(initParameters, metaTypeDeclName, metaTypeDeclFullName, instanceTypeDeclFullName) - - createBinding(metaClassCallHandlerMethod, metaTypeDeclNode) - - // Create fake __new__ regardless whether there is an actual implementation in the code. - // We do this to model the __init__ call in a visible way for the data flow tracker. - // This is done because very often the __init__ call is hidden in a super().__new__ call - // and we cant yet handle super(). - val fakeNewMethod = createFakeNewMethod(initParameters) - - val fakeNewMember = nodeBuilder.memberNode("", fakeNewMethod.fullName) - edgeBuilder.astEdge(fakeNewMember, metaTypeDeclNode, contextStack.order.getAndInc) - - // Create binding into class instance type for each method. - // Also create bindings into meta class type to enable calls like "MyClass.func(obj, p1)". - // For non static methods we create an adapter method which basically only shifts the parameters - // one to the left and makes sure that the meta class object is not passed to func as instance - // parameter. 
- classDef.body.foreach { - case func: ast.FunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode + val methodReturnNode = + nodeBuilder.methodReturnNode( + nodeBuilder.extractTypesFromHint(returns), + returnTypeHint, + lineAndColumn + ) + edgeBuilder.astEdge(methodReturnNode, methodNode, 2) + + val bodyOrder = new AutoIncIndex(1) + bodyProvider().foreach { bodyStmt => + edgeBuilder.astEdge(bodyStmt, blockNode, bodyOrder.getAndInc) + } + + // For every method we create a corresponding TYPE and TYPE_DECL and + // a binding for the method into TYPE_DECL. + val typeNode = nodeBuilder.typeNode(name, fullName) + val typeDeclNode = + nodeBuilder.typeDeclNode(name, fullName, relFileName, Seq(Constants.ANY), lineAndColumn) + + // For every method that is a module, the local variables can be imported by other modules. This behaviour is + // much like fields so they are to be linked as fields to this method type + if name == "" then contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) + + contextStack.pop() + edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) + createBinding(methodNode, typeDeclNode) + + methodNode + end createMethod + + // For a classDef we do: + // 1. Create a metaType, metaTypeDecl and metaTypeRef. + // 2. Create a function containing the code of the classDef body. + // 3. Create a block which contains a call to the body function + // and an assignment of the metaTypeRef to an identifier with the class name. + // 4. Create type and typeDecl for the instance class. + // 5. 
Create and link members in metaTypeDecl and instanceTypeDecl + def convert(classDef: ast.ClassDef): NewNode = + // Create type for the meta class object + val metaTypeDeclName = classDef.name + metaClassSuffix + val metaTypeDeclFullName = calculateFullNameFromContext(metaTypeDeclName) + + val metaTypeNode = nodeBuilder.typeNode(metaTypeDeclName, metaTypeDeclFullName) + val metaTypeDeclNode = + nodeBuilder.typeDeclNode( + metaTypeDeclName, + metaTypeDeclFullName, + relFileName, + Seq(Constants.ANY), + lineAndColOf(classDef) + ) + edgeBuilder.astEdge(metaTypeDeclNode, contextStack.astParent, contextStack.order.getAndInc) + + // Create type for class instances + val instanceTypeDeclName = classDef.name + val instanceTypeDeclFullName = calculateFullNameFromContext(instanceTypeDeclName) + + // TODO for now we just take the code of the base expression and pretend they are full names, converting special + // nodes as we go. + def handleInheritance(fs: List[ast.iexpr]): List[String] = fs match + case (x: ast.Call) :: xs => + val node = convert(x) + val parent = contextStack.astParent + val tmpVar = createIdentifierNode(getUnusedName(), Store, lineAndColOf(x)) + val assignment = createAssignment(tmpVar, node, lineAndColOf(x)) + diffGraph.addEdge(parent, assignment, EdgeTypes.AST) + tmpVar.name +: handleInheritance(xs) + case x :: xs => + nodeToCode.getCode(x) +: handleInheritance(xs) + case Nil => Nil + + val inheritsFrom = handleInheritance(classDef.bases.toList) + + val instanceType = nodeBuilder.typeNode(instanceTypeDeclName, instanceTypeDeclFullName) + val instanceTypeDecl = + nodeBuilder.typeDeclNode( + instanceTypeDeclName, + instanceTypeDeclFullName, + relFileName, + inheritsFrom, + lineAndColOf(classDef) + ) + edgeBuilder.astEdge(instanceTypeDecl, contextStack.astParent, contextStack.order.getAndInc) + + // Create function which contains the code defining the class + contextStack.pushClass(Some(classDef.name), instanceTypeDecl) + val classBodyFunctionName = "" + val 
(_, methodRefNode) = createMethodAndMethodRef( + classBodyFunctionName, + scopeName = None, + parameterProvider = () => MethodParameters.empty(), + bodyProvider = () => classDef.body.map(convert), + None, + isAsync = false, + lineAndColOf(classDef) ) - case func: ast.AsyncFunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode + + contextStack.pop() + + contextStack.pushClass(Some(classDef.name), metaTypeDeclNode) + + // Create meta class call handling method and bind it to meta class type. + val functions = classDef.body.collect { case func: ast.FunctionDef => func } + + // __init__ method has to be in functions because "async def __init__" is invalid. + val initFunctionOption = functions.find(_.name == "__init__") + + val initParameters = initFunctionOption.map(_.args).getOrElse { + // Create arguments of a default __init__ function. + ast.Arguments( + posonlyargs = mutable.Seq.empty[ast.Arg], + args = mutable.Seq(ast.Arg("self", None, None, classDef.attributeProvider)), + vararg = None, + kwonlyargs = mutable.Seq.empty[ast.Arg], + kw_defaults = mutable.Seq.empty[Option[ast.iexpr]], + kw_arg = None, + defaults = mutable.Seq.empty[ast.iexpr] + ) + } + + val metaClassCallHandlerMethod = + createMetaClassCallHandlerMethod( + initParameters, + metaTypeDeclName, + metaTypeDeclFullName, + instanceTypeDeclFullName + ) + + createBinding(metaClassCallHandlerMethod, metaTypeDeclNode) + + // Create fake __new__ regardless whether there is an actual implementation in the code. + // We do this to model the __init__ call in a visible way for the data flow tracker. + // This is done because very often the __init__ call is hidden in a super().__new__ call + // and we cant yet handle super(). 
+ val fakeNewMethod = createFakeNewMethod(initParameters) + + val fakeNewMember = nodeBuilder.memberNode("", fakeNewMethod.fullName) + edgeBuilder.astEdge(fakeNewMember, metaTypeDeclNode, contextStack.order.getAndInc) + + // Create binding into class instance type for each method. + // Also create bindings into meta class type to enable calls like "MyClass.func(obj, p1)". + // For non static methods we create an adapter method which basically only shifts the parameters + // one to the left and makes sure that the meta class object is not passed to func as instance + // parameter. + classDef.body.foreach { + case func: ast.FunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + case func: ast.AsyncFunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + case _ => + // All other body statements are currently ignored. + } + + contextStack.pop() + + // Create call to function and assignment of the meta class object to a identifier named + // like the class. + val callToClassBodyFunction = + createCall(methodRefNode, "", lineAndColOf(classDef), Nil, Nil) + val metaTypeRefNode = + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColOf(classDef)) + val classIdentifierAssignNode = + createAssignmentToIdentifier(classDef.name, metaTypeRefNode, lineAndColOf(classDef)) + + val classBlock = createBlock( + callToClassBodyFunction :: classIdentifierAssignNode :: Nil, + lineAndColOf(classDef) ) - case _ => - // All other body statements are currently ignored. - } - - contextStack.pop() - - // Create call to function and assignment of the meta class object to a identifier named - // like the class. 
- val callToClassBodyFunction = createCall(methodRefNode, "", lineAndColOf(classDef), Nil, Nil) - val metaTypeRefNode = - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColOf(classDef)) - val classIdentifierAssignNode = - createAssignmentToIdentifier(classDef.name, metaTypeRefNode, lineAndColOf(classDef)) - - val classBlock = createBlock(callToClassBodyFunction :: classIdentifierAssignNode :: Nil, lineAndColOf(classDef)) - - classBlock - } - - private def createMemberBindingsAndAdapter( - function: ast.istmt, - functionName: String, - functionArgs: ast.Arguments, - functionDecoratorList: Iterable[ast.iexpr], - instanceTypeDecl: nodes.NewNode, - metaTypeDecl: nodes.NewNode - ): Unit = { - val memberForInstance = - nodeBuilder.memberNode(functionName, functionDefToMethod.apply(function).fullName, lineAndColOf(function)) - edgeBuilder.astEdge(memberForInstance, instanceTypeDecl, contextStack.order.getAndInc) - - val methodForMetaClass = - if (isStaticMethod(functionDecoratorList) || isClassMethod(functionDecoratorList)) { - functionDefToMethod.apply(function) - } else { - createMetaClassAdapterMethod( + + classBlock + end convert + + private def createMemberBindingsAndAdapter( + function: ast.istmt, + functionName: String, + functionArgs: ast.Arguments, + functionDecoratorList: Iterable[ast.iexpr], + instanceTypeDecl: nodes.NewNode, + metaTypeDecl: nodes.NewNode + ): Unit = + val memberForInstance = + nodeBuilder.memberNode( + functionName, + functionDefToMethod.apply(function).fullName, + lineAndColOf(function) + ) + edgeBuilder.astEdge(memberForInstance, instanceTypeDecl, contextStack.order.getAndInc) + + val methodForMetaClass = + if isStaticMethod(functionDecoratorList) || isClassMethod(functionDecoratorList) then + functionDefToMethod.apply(function) + else + createMetaClassAdapterMethod( + functionName, + functionDefToMethod.apply(function).fullName, + functionArgs, + lineAndColOf(function) + ) + + val memberForMeta = nodeBuilder.memberNode( 
functionName, - functionDefToMethod.apply(function).fullName, - functionArgs, + methodForMetaClass.fullName, lineAndColOf(function) ) - } - - val memberForMeta = nodeBuilder.memberNode(functionName, methodForMetaClass.fullName, lineAndColOf(function)) - edgeBuilder.astEdge(memberForMeta, metaTypeDecl, contextStack.order.getAndInc) - } - - /** Creates an adapter method which adapts the meta class version of a method to the instance class version. Consider - * class: class MyClass(): def func(self, p1): pass - * - * The syntax to call func via the meta class is: MyClass.func(someInstance, p1), whereas the call via the instance - * itself is: someInstance.func(p1). To adapt between those two we generate: def func(cls, self, - * p1): return STATIC_CALL(MyClass.func(self, p1)) - * @return - */ - // TODO handle kwArg - private def createMetaClassAdapterMethod( - adaptedMethodName: String, - adaptedMethodFullName: String, - parameters: ast.Arguments, - lineAndColumn: LineAndColumn - ): nodes.NewMethod = { - val adapterMethodName = adaptedMethodName + "" - val adapterMethodFullName = calculateFullNameFromContext(adapterMethodName) - - createMethod( - adapterMethodName, - adapterMethodFullName, - Some(adaptedMethodName), - parameterProvider = () => { - MethodParameters( - 0, - nodeBuilder.methodParameterNode("cls", isVariadic = false, lineAndColumn, Option(0)) :: Nil ++ - convert(parameters, 1) - ) - }, - bodyProvider = () => { - val (arguments, keywordArguments) = createArguments(parameters, lineAndColumn) - val staticCall = - createStaticCall(adaptedMethodName, adaptedMethodFullName, lineAndColumn, arguments, keywordArguments) - val returnNode = createReturn(Some(staticCall), None, lineAndColumn) - returnNode :: Nil - }, - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = None, - lineAndColumn - ) - } - - def createArguments( - arguments: ast.Arguments, - lineAndColumn: LineAndColumn - ): (Iterable[nodes.NewNode], Iterable[(String, 
nodes.NewNode)]) = { - val convertedArgs = mutable.ArrayBuffer.empty[nodes.NewNode] - val convertedKeywordArgs = mutable.ArrayBuffer.empty[(String, nodes.NewNode)] - - arguments.posonlyargs.foreach { arg => - convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) - } - arguments.args.foreach { arg => - convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) - } - arguments.vararg.foreach { arg => - convertedArgs.append( - createStarredUnpackOperatorCall(createIdentifierNode(arg.arg, Load, lineAndColumn), lineAndColumn) - ) - } - arguments.kwonlyargs.foreach { arg => - convertedKeywordArgs.append((arg.arg, createIdentifierNode(arg.arg, Load, lineAndColumn))) - } - - (convertedArgs, convertedKeywordArgs) - } - - /** This function strips the first positional parameter from initParameters, if present. - * @return - * Parameters without first positional parameter and adjusted line and column number information. - */ - private def stripFirstPositionalParameter(initParameters: ast.Arguments): (ast.Arguments, LineAndColumn) = { - if (initParameters.posonlyargs.nonEmpty) { - ( - initParameters.copy(posonlyargs = initParameters.posonlyargs.tail), - lineAndColOf(initParameters.posonlyargs.head) - ) - } else if (initParameters.args.nonEmpty) { - (initParameters.copy(args = initParameters.args.tail), lineAndColOf(initParameters.args.head)) - } else if (initParameters.vararg.nonEmpty) { - (initParameters, lineAndColOf(initParameters.vararg.get)) - } else { - (initParameters, lineAndColOf(initParameters.kw_arg.get)) - } - } - - /** Creates the method which handles a call to the meta class object. This process is also known as creating a new - * instance object, e.g. obj = MyClass(p1). The purpose of the generated function is to adapt between the special - * cased instance creation call and a normal call to __new__ (for now ). The adaption is required to in - * order to provide TYPE_REF(meta class) as instance argument to __new__/. 
So the - * looks like: def (p1): return DYNAMIC_CALL(receiver=TYPE_REF(meta class)., instance - * \= TYPE_REF(meta class), p1) - */ - // TODO handle kwArg - private def createMetaClassCallHandlerMethod( - initParameters: ast.Arguments, - metaTypeDeclName: String, - metaTypeDeclFullName: String, - instanceTypeDeclFullName: String - ): nodes.NewMethod = { - val methodName = "" - val methodFullName = calculateFullNameFromContext(methodName) - - // We need to drop the "self" parameter either from the position only or normal parameters - // because "self" is not passed through but rather created in __new__. - val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) - - createMethod( - methodName, - methodFullName, - Some(methodName), - parameterProvider = () => { - MethodParameters(1, convert(parametersWithoutSelf, 1)) - }, - bodyProvider = () => { - val (arguments, keywordArguments) = createArguments(parametersWithoutSelf, lineAndColumn) - - val fakeNewCall = createInstanceCall( - createFieldAccess( - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), - "", - lineAndColumn - ), - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), - "", - lineAndColumn, - arguments, - keywordArguments + edgeBuilder.astEdge(memberForMeta, metaTypeDecl, contextStack.order.getAndInc) + end createMemberBindingsAndAdapter + + /** Creates an adapter method which adapts the meta class version of a method to the instance + * class version. Consider class: class MyClass(): def func(self, p1): pass + * + * The syntax to call func via the meta class is: MyClass.func(someInstance, p1), whereas the + * call via the instance itself is: someInstance.func(p1). 
To adapt between those two we + * generate: def func(cls, self, p1): return STATIC_CALL(MyClass.func(self, + * p1)) + * @return + */ + // TODO handle kwArg + private def createMetaClassAdapterMethod( + adaptedMethodName: String, + adaptedMethodFullName: String, + parameters: ast.Arguments, + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val adapterMethodName = adaptedMethodName + "" + val adapterMethodFullName = calculateFullNameFromContext(adapterMethodName) + + createMethod( + adapterMethodName, + adapterMethodFullName, + Some(adaptedMethodName), + parameterProvider = () => + MethodParameters( + 0, + nodeBuilder.methodParameterNode( + "cls", + isVariadic = false, + lineAndColumn, + Option(0) + ) :: Nil ++ + convert(parameters, 1) + ), + bodyProvider = () => + val (arguments, keywordArguments) = createArguments(parameters, lineAndColumn) + val staticCall = + createStaticCall( + adaptedMethodName, + adaptedMethodFullName, + lineAndColumn, + arguments, + keywordArguments + ) + val returnNode = createReturn(Some(staticCall), None, lineAndColumn) + returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = None, + lineAndColumn ) + end createMetaClassAdapterMethod - val returnNode = createReturn(Some(fakeNewCall), None, lineAndColumn) + def createArguments( + arguments: ast.Arguments, + lineAndColumn: LineAndColumn + ): (Iterable[nodes.NewNode], Iterable[(String, nodes.NewNode)]) = + val convertedArgs = mutable.ArrayBuffer.empty[nodes.NewNode] + val convertedKeywordArgs = mutable.ArrayBuffer.empty[(String, nodes.NewNode)] - returnNode :: Nil - }, - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = Some(instanceTypeDeclFullName), - lineAndColumn - ) - } - - /** Creates a method which mimics the behaviour of a default __new__ method (the one you would get if no - * implementation is present). 
The reason we use a fake version of the __new__ method it that we wont be able to - * correctly track through most custom __new__ implementations as they usually call "super.__init__()" and we cannot - * yet handle "super". The fake __new__ looks like: def (cls, p1): __newInstance = - * STATIC_CALL(.alloc) cls.__init__(__newIstance, p1) return __newInstance - */ - // TODO handle kwArg - private def createFakeNewMethod(initParameters: ast.Arguments): nodes.NewMethod = { - val newMethodName = "" - val newMethodStubFullName = calculateFullNameFromContext(newMethodName) - - // We need to drop the "self" parameter either from the position only or normal parameters - // because "self" is not passed through but rather created in __new__. - val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) - - createMethod( - newMethodName, - newMethodStubFullName, - Some(newMethodName), - parameterProvider = () => { - MethodParameters( - 0, - nodeBuilder.methodParameterNode("cls", isVariadic = false, lineAndColumn, Some(0)) :: Nil ++ - convert(parametersWithoutSelf, 1) + arguments.posonlyargs.foreach { arg => + convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) + } + arguments.args.foreach { arg => + convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) + } + arguments.vararg.foreach { arg => + convertedArgs.append( + createStarredUnpackOperatorCall( + createIdentifierNode(arg.arg, Load, lineAndColumn), + lineAndColumn + ) + ) + } + arguments.kwonlyargs.foreach { arg => + convertedKeywordArgs.append(( + arg.arg, + createIdentifierNode(arg.arg, Load, lineAndColumn) + )) + } + + (convertedArgs, convertedKeywordArgs) + end createArguments + + /** This function strips the first positional parameter from initParameters, if present. + * @return + * Parameters without first positional parameter and adjusted line and column number + * information. 
+ */ + private def stripFirstPositionalParameter(initParameters: ast.Arguments) + : (ast.Arguments, LineAndColumn) = + if initParameters.posonlyargs.nonEmpty then + ( + initParameters.copy(posonlyargs = initParameters.posonlyargs.tail), + lineAndColOf(initParameters.posonlyargs.head) + ) + else if initParameters.args.nonEmpty then + ( + initParameters.copy(args = initParameters.args.tail), + lineAndColOf(initParameters.args.head) + ) + else if initParameters.vararg.nonEmpty then + (initParameters, lineAndColOf(initParameters.vararg.get)) + else + (initParameters, lineAndColOf(initParameters.kw_arg.get)) + + /** Creates the method which handles a call to the meta class object. This process is also known + * as creating a new instance object, e.g. obj = MyClass(p1). The purpose of the generated + * function is to adapt between the special cased instance creation call and a normal call to + * __new__ (for now ). The adaption is required to in order to provide TYPE_REF(meta + * class) as instance argument to __new__/. So the looks like: + * def (p1): return DYNAMIC_CALL(receiver=TYPE_REF(meta class)., + * instance \= TYPE_REF(meta class), p1) + */ + // TODO handle kwArg + private def createMetaClassCallHandlerMethod( + initParameters: ast.Arguments, + metaTypeDeclName: String, + metaTypeDeclFullName: String, + instanceTypeDeclFullName: String + ): nodes.NewMethod = + val methodName = "" + val methodFullName = calculateFullNameFromContext(methodName) + + // We need to drop the "self" parameter either from the position only or normal parameters + // because "self" is not passed through but rather created in __new__. 
+ val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) + + createMethod( + methodName, + methodFullName, + Some(methodName), + parameterProvider = () => + MethodParameters(1, convert(parametersWithoutSelf, 1)), + bodyProvider = () => + val (arguments, keywordArguments) = + createArguments(parametersWithoutSelf, lineAndColumn) + + val fakeNewCall = createInstanceCall( + createFieldAccess( + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), + "", + lineAndColumn + ), + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), + "", + lineAndColumn, + arguments, + keywordArguments + ) + + val returnNode = createReturn(Some(fakeNewCall), None, lineAndColumn) + + returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = Some(instanceTypeDeclFullName), + lineAndColumn ) - }, - bodyProvider = () => { - val allocatorCall = - createNAryOperatorCall(() => (".alloc", ".alloc"), Nil, lineAndColumn) - val assignmentToNewInstance = - createAssignment(createIdentifierNode("__newInstance", Store, lineAndColumn), allocatorCall, lineAndColumn) - - val (arguments, keywordArguments) = createArguments(parametersWithoutSelf, lineAndColumn) - val argumentWithInstance = mutable.ArrayBuffer.empty[nodes.NewNode] - argumentWithInstance.append(createIdentifierNode("__newInstance", Load, lineAndColumn)) - argumentWithInstance.appendAll(arguments) - - val initCall = createXDotYCall( - () => createIdentifierNode("cls", Load, lineAndColumn), - "__init__", - xMayHaveSideEffects = false, - lineAndColumn, - argumentWithInstance, - keywordArguments + end createMetaClassCallHandlerMethod + + /** Creates a method which mimics the behaviour of a default __new__ method (the one + * you would get if no implementation is present). 
The reason we use a fake version of the + * __new__ method it that we wont be able to correctly track through most custom __new__ + * implementations as they usually call "super.__init__()" and we cannot yet handle "super". + * The fake __new__ looks like: def (cls, p1): __newInstance = + * STATIC_CALL(.alloc) cls.__init__(__newIstance, p1) return __newInstance + */ + // TODO handle kwArg + private def createFakeNewMethod(initParameters: ast.Arguments): nodes.NewMethod = + val newMethodName = "" + val newMethodStubFullName = calculateFullNameFromContext(newMethodName) + + // We need to drop the "self" parameter either from the position only or normal parameters + // because "self" is not passed through but rather created in __new__. + val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) + + createMethod( + newMethodName, + newMethodStubFullName, + Some(newMethodName), + parameterProvider = () => + MethodParameters( + 0, + nodeBuilder.methodParameterNode( + "cls", + isVariadic = false, + lineAndColumn, + Some(0) + ) :: Nil ++ + convert(parametersWithoutSelf, 1) + ), + bodyProvider = () => + val allocatorCall = + createNAryOperatorCall( + () => (".alloc", ".alloc"), + Nil, + lineAndColumn + ) + val assignmentToNewInstance = + createAssignment( + createIdentifierNode("__newInstance", Store, lineAndColumn), + allocatorCall, + lineAndColumn + ) + + val (arguments, keywordArguments) = + createArguments(parametersWithoutSelf, lineAndColumn) + val argumentWithInstance = mutable.ArrayBuffer.empty[nodes.NewNode] + argumentWithInstance.append(createIdentifierNode( + "__newInstance", + Load, + lineAndColumn + )) + argumentWithInstance.appendAll(arguments) + + val initCall = createXDotYCall( + () => createIdentifierNode("cls", Load, lineAndColumn), + "__init__", + xMayHaveSideEffects = false, + lineAndColumn, + argumentWithInstance, + keywordArguments + ) + + val returnNode = + createReturn( + Some(createIdentifierNode("__newInstance", 
Load, lineAndColumn)), + None, + lineAndColumn + ) + + assignmentToNewInstance :: initCall :: returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = None, + lineAndColumn ) + end createFakeNewMethod - val returnNode = - createReturn(Some(createIdentifierNode("__newInstance", Load, lineAndColumn)), None, lineAndColumn) + def convert(ret: ast.Return): NewNode = + createReturn(ret.value.map(convert), Some(nodeToCode.getCode(ret)), lineAndColOf(ret)) - assignmentToNewInstance :: initCall :: returnNode :: Nil - }, - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = None, - lineAndColumn - ) - } - - def convert(ret: ast.Return): NewNode = { - createReturn(ret.value.map(convert), Some(nodeToCode.getCode(ret)), lineAndColOf(ret)) - } - - def convert(delete: ast.Delete): NewNode = { - val deleteArgs = delete.targets.map(convert) - - val code = "del " + deleteArgs.map(codeOf).mkString(", ") - val callNode = nodeBuilder.callNode(code, ".delete", DispatchTypes.STATIC_DISPATCH, lineAndColOf(delete)) - - addAstChildrenAsArguments(callNode, 1, deleteArgs) - callNode - } - - def convert(assign: ast.Assign): nodes.NewNode = { - val loweredNodes = - createValueToTargetsDecomposition(assign.targets, convert(assign.value), lineAndColOf(assign)) - - if (loweredNodes.size == 1) { - // Simple assignment can be returned directly. - loweredNodes.head - } else { - createBlock(loweredNodes, lineAndColOf(assign)) - } - } - - // TODO for now we ignore the annotation part and just emit the pure - // assignment. 
- def convert(annotatedAssign: ast.AnnAssign): NewNode = { - val targetNode = convert(annotatedAssign.target) - - annotatedAssign.value match { - case Some(value) => - val valueNode = convert(value) - createAssignment(targetNode, valueNode, lineAndColOf(annotatedAssign)) - case None => - // If there is no value, this is just an expr: annotation and since - // we for now ignore the annotation we emit just the expr because - // it may have side effects. - targetNode - } - } - - def convert(augAssign: ast.AugAssign): NewNode = { - val targetNode = convert(augAssign.target) - val valueNode = convert(augAssign.value) - - val (operatorCode, operatorFullName) = - augAssign.op match { - case ast.Add => ("+=", Operators.assignmentPlus) - case ast.Sub => ("-=", Operators.assignmentMinus) - case ast.Mult => ("*=", Operators.assignmentMultiplication) - case ast.MatMult => - ("@=", ".assignmentMatMult") // TODO make this a define and add policy for this - case ast.Div => ("/=", Operators.assignmentDivision) - case ast.Mod => ("%=", Operators.assignmentModulo) - case ast.Pow => ("**=", Operators.assignmentExponentiation) - case ast.LShift => ("<<=", Operators.assignmentShiftLeft) - case ast.RShift => ("<<=", Operators.assignmentArithmeticShiftRight) - case ast.BitOr => ("|=", Operators.assignmentOr) - case ast.BitXor => ("^=", Operators.assignmentXor) - case ast.BitAnd => ("&=", Operators.assignmentAnd) - case ast.FloorDiv => - ("//=", ".assignmentFloorDiv") // TODO make this a define and add policy for this - } - - createAugAssignment(targetNode, operatorCode, valueNode, operatorFullName, lineAndColOf(augAssign)) - } - - // TODO write test - def convert(forStmt: ast.For): NewNode = { - createForLowering( - forStmt.target, - forStmt.iter, - Iterable.empty, - forStmt.body.map(convert), - forStmt.orelse.map(convert), - isAsync = false, - lineAndColOf(forStmt) - ) - } - - def convert(forStmt: ast.AsyncFor): NewNode = { - createForLowering( - forStmt.target, - forStmt.iter, - 
Iterable.empty, - forStmt.body.map(convert), - forStmt.orelse.map(convert), - isAsync = true, - lineAndColOf(forStmt) - ) - } - - // Lowering of for x in y: : - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // (x = iterator.__next__()) - // - // } - // If one "if" is present the lowering of for x in y if z: - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // if (!z): continue - // (x = iterator.__next__()) - // - // } - // If multiple "ifs" are present the lowering of for x in y if z if a: ..,: - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // if (!(z and a)): continue - // (x = iterator.__next__()) - // - // } - protected def createForLowering( - target: ast.iexpr, - iter: ast.iexpr, - ifs: Iterable[ast.iexpr], - bodyNodes: Iterable[nodes.NewNode], - orelseNodes: Iterable[nodes.NewNode], - isAsync: Boolean, - lineAndColumn: LineAndColumn - ): nodes.NewNode = { - val iterVariableName = getUnusedName() - val iterExprIterCallNode = - createXDotYCall( - () => convert(iter), - "__iter__", - xMayHaveSideEffects = !iter.isInstanceOf[ast.Name], - lineAndColumn, - Nil, - Nil - ) - val iterAssignNode = - createAssignmentToIdentifier(iterVariableName, iterExprIterCallNode, lineAndColumn) - - val conditionNode = nodeBuilder.unknownNode("iteratorNonEmptyOrException", "", lineAndColumn) - - val controlStructureNode = - nodeBuilder.controlStructureNode("while ... 
: ...", ControlStructureTypes.WHILE, lineAndColumn) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - - val iterNextCallNode = - createXDotYCall( - () => createIdentifierNode(iterVariableName, Load, lineAndColumn), - "__next__", - xMayHaveSideEffects = false, - lineAndColumn, - Nil, - Nil - ) - - val loweredAssignNodes = - createValueToTargetsDecomposition(Iterable.single(target), iterNextCallNode, lineAndColumn) - - val blockStmtNodes = mutable.ArrayBuffer.empty[nodes.NewNode] - blockStmtNodes.appendAll(loweredAssignNodes) - - if (ifs.nonEmpty) { - val conditionNode = - if (ifs.size == 1) { - ifs.head - } else { - ast.BoolOp(ast.And, ifs.to(mutable.Seq), ifs.head.attributeProvider) - } - val ifNotContinueNode = convert( - ast.If( - ast.UnaryOp(ast.Not, conditionNode, ifs.head.attributeProvider), - mutable.ArrayBuffer.empty[ast.istmt].append(ast.Continue(ifs.head.attributeProvider)), - mutable.Seq.empty[ast.istmt], - ifs.head.attributeProvider + def convert(delete: ast.Delete): NewNode = + val deleteArgs = delete.targets.map(convert) + + val code = "del " + deleteArgs.map(codeOf).mkString(", ") + val callNode = nodeBuilder.callNode( + code, + ".delete", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(delete) ) - ) - - blockStmtNodes.append(ifNotContinueNode) - } - bodyNodes.foreach(blockStmtNodes.append) - - val bodyBlockNode = createBlock(blockStmtNodes, lineAndColumn) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - - if (orelseNodes.nonEmpty) { - val elseBlockNode = createBlock(orelseNodes, lineAndColumn) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) - } - - createBlock(iterAssignNode :: controlStructureNode :: Nil, lineAndColumn) - } - - def convert(astWhile: ast.While): nodes.NewNode = { - val conditionNode = convert(astWhile.test) - val bodyStmtNodes = astWhile.body.map(convert) - - val controlStructureNode = - nodeBuilder.controlStructureNode("while ... 
: ...", ControlStructureTypes.WHILE, lineAndColOf(astWhile)) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - - val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astWhile)) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - - if (astWhile.orelse.nonEmpty) { - val elseStmtNodes = astWhile.orelse.map(convert) - val elseBlockNode = - createBlock(elseStmtNodes, lineAndColOf(astWhile.orelse.head)) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) - } - - controlStructureNode - } - - def convert(astIf: ast.If): nodes.NewNode = { - val conditionNode = convert(astIf.test) - val bodyStmtNodes = astIf.body.map(convert) - - val controlStructureNode = - nodeBuilder.controlStructureNode("if ... : ...", ControlStructureTypes.IF, lineAndColOf(astIf)) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - - val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astIf)) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - - if (astIf.orelse.nonEmpty) { - val elseStmtNodes = astIf.orelse.map(convert) - val elseBlockNode = createBlock(elseStmtNodes, lineAndColOf(astIf.orelse.head)) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) - } - - controlStructureNode - } - - def convert(withStmt: ast.With): NewNode = { - val loweredNodes = - withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => - mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) - } - - loweredNodes.head - } - - def convert(withStmt: ast.AsyncWith): NewNode = { - val loweredNodes = - withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => - mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) - } - - loweredNodes.head - } - - // Handles the lowering of a single "with item". E.g. for: with A() as a, B() as b - // "A() as a" is a single "with item". 
- // The lowering for: - // with EXPRESSION as TARGET: - // SUITE - // is: - // manager = (EXPRESSION) - // enter = manager.__enter__ - // exit = manager.__exit__ - // value = enter(manager) - // - // try: - // TARGET = value - // SUITE - // finally: - // exit(manager) - // - // Note that this is not quite semantically correct because we ignore the exit method - // arguments and the exit call return value. This is fine for us because for our data - // flow tracking purposed we dont need that extra information and the AST is anyway - // "broken" because we are heavily lowering. - // For reference the following excerpt taken from the official Python 3.9.5 - // documentation found at - // https://docs.python.org/3/reference/compound_stmts.html#the-with-statement - // shows the semantically correct lowering: - // manager = (EXPRESSION) - // enter = type(manager).__enter__ - // exit = type(manager).__exit__ - // value = enter(manager) - // hit_except = False - // - // try: - // TARGET = value - // SUITE - // except: - // hit_except = True - // if not exit(manager, *sys.exc_info()): - // raise - // finally: - // if not hit_except: - // exit(manager, None, None, None) - private def convertWithItem(withItem: ast.Withitem, suite: collection.Seq[nodes.NewNode]): nodes.NewNode = { - val lineAndCol = lineAndColOf(withItem.context_expr) - val managerIdentifierName = getUnusedName("manager") - - val assignmentToManager = - createAssignmentToIdentifier(managerIdentifierName, convert(withItem.context_expr), lineAndCol) - - val enterIdentifierName = getUnusedName("enter") - val assignmentToEnter = createAssignmentToIdentifier( - enterIdentifierName, - createFieldAccess(createIdentifierNode(managerIdentifierName, Load, lineAndCol), "__enter__", lineAndCol), - lineAndCol - ) - val exitIdentifierName = getUnusedName("exit") - val assignmentToExit = createAssignmentToIdentifier( - exitIdentifierName, - createFieldAccess(createIdentifierNode(managerIdentifierName, Load, lineAndCol), 
"__exit__", lineAndCol), - lineAndCol - ) + addAstChildrenAsArguments(callNode, 1, deleteArgs) + callNode + + def convert(assign: ast.Assign): nodes.NewNode = + val loweredNodes = + createValueToTargetsDecomposition( + assign.targets, + convert(assign.value), + lineAndColOf(assign) + ) + + if loweredNodes.size == 1 then + // Simple assignment can be returned directly. + loweredNodes.head + else + createBlock(loweredNodes, lineAndColOf(assign)) + + // TODO for now we ignore the annotation part and just emit the pure + // assignment. + def convert(annotatedAssign: ast.AnnAssign): NewNode = + val targetNode = convert(annotatedAssign.target) + + annotatedAssign.value match + case Some(value) => + val valueNode = convert(value) + createAssignment(targetNode, valueNode, lineAndColOf(annotatedAssign)) + case None => + // If there is no value, this is just an expr: annotation and since + // we for now ignore the annotation we emit just the expr because + // it may have side effects. + targetNode + + def convert(augAssign: ast.AugAssign): NewNode = + val targetNode = convert(augAssign.target) + val valueNode = convert(augAssign.value) + + val (operatorCode, operatorFullName) = + augAssign.op match + case ast.Add => ("+=", Operators.assignmentPlus) + case ast.Sub => ("-=", Operators.assignmentMinus) + case ast.Mult => ("*=", Operators.assignmentMultiplication) + case ast.MatMult => + ( + "@=", + ".assignmentMatMult" + ) // TODO make this a define and add policy for this + case ast.Div => ("/=", Operators.assignmentDivision) + case ast.Mod => ("%=", Operators.assignmentModulo) + case ast.Pow => ("**=", Operators.assignmentExponentiation) + case ast.LShift => ("<<=", Operators.assignmentShiftLeft) + case ast.RShift => ("<<=", Operators.assignmentArithmeticShiftRight) + case ast.BitOr => ("|=", Operators.assignmentOr) + case ast.BitXor => ("^=", Operators.assignmentXor) + case ast.BitAnd => ("&=", Operators.assignmentAnd) + case ast.FloorDiv => + ( + "//=", + 
".assignmentFloorDiv" + ) // TODO make this a define and add policy for this + + createAugAssignment( + targetNode, + operatorCode, + valueNode, + operatorFullName, + lineAndColOf(augAssign) + ) + end convert + + // TODO write test + def convert(forStmt: ast.For): NewNode = + createForLowering( + forStmt.target, + forStmt.iter, + Iterable.empty, + forStmt.body.map(convert), + forStmt.orelse.map(convert), + isAsync = false, + lineAndColOf(forStmt) + ) - val valueIdentifierName = getUnusedName("value") - val assignmentToValue = createAssignmentToIdentifier( - valueIdentifierName, - createInstanceCall( - createIdentifierNode(enterIdentifierName, Load, lineAndCol), - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "", - lineAndCol, - Nil, - Nil - ), - lineAndCol - ) + def convert(forStmt: ast.AsyncFor): NewNode = + createForLowering( + forStmt.target, + forStmt.iter, + Iterable.empty, + forStmt.body.map(convert), + forStmt.orelse.map(convert), + isAsync = true, + lineAndColOf(forStmt) + ) - val tryBody = - withItem.optional_vars match { - case Some(optionalVar) => - val loweredTargetAssignNodes = createValueToTargetsDecomposition( - withItem.optional_vars, - createIdentifierNode(valueIdentifierName, Load, lineAndCol), + // Lowering of for x in y: : + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // (x = iterator.__next__()) + // + // } + // If one "if" is present the lowering of for x in y if z: + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // if (!z): continue + // (x = iterator.__next__()) + // + // } + // If multiple "ifs" are present the lowering of for x in y if z if a: ..,: + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // if (!(z and a)): continue + // (x = iterator.__next__()) + // + // } + protected def createForLowering( + target: ast.iexpr, + iter: ast.iexpr, + ifs: Iterable[ast.iexpr], + bodyNodes: Iterable[nodes.NewNode], + orelseNodes: Iterable[nodes.NewNode], + isAsync: 
Boolean, + lineAndColumn: LineAndColumn + ): nodes.NewNode = + val iterVariableName = getUnusedName() + val iterExprIterCallNode = + createXDotYCall( + () => convert(iter), + "__iter__", + xMayHaveSideEffects = !iter.isInstanceOf[ast.Name], + lineAndColumn, + Nil, + Nil + ) + val iterAssignNode = + createAssignmentToIdentifier(iterVariableName, iterExprIterCallNode, lineAndColumn) + + val conditionNode = + nodeBuilder.unknownNode("iteratorNonEmptyOrException", "", lineAndColumn) + + val controlStructureNode = + nodeBuilder.controlStructureNode( + "while ... : ...", + ControlStructureTypes.WHILE, + lineAndColumn + ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) + + val iterNextCallNode = + createXDotYCall( + () => createIdentifierNode(iterVariableName, Load, lineAndColumn), + "__next__", + xMayHaveSideEffects = false, + lineAndColumn, + Nil, + Nil + ) + + val loweredAssignNodes = + createValueToTargetsDecomposition( + Iterable.single(target), + iterNextCallNode, + lineAndColumn + ) + + val blockStmtNodes = mutable.ArrayBuffer.empty[nodes.NewNode] + blockStmtNodes.appendAll(loweredAssignNodes) + + if ifs.nonEmpty then + val conditionNode = + if ifs.size == 1 then + ifs.head + else + ast.BoolOp(ast.And, ifs.to(mutable.Seq), ifs.head.attributeProvider) + val ifNotContinueNode = convert( + ast.If( + ast.UnaryOp(ast.Not, conditionNode, ifs.head.attributeProvider), + mutable.ArrayBuffer.empty[ast.istmt].append( + ast.Continue(ifs.head.attributeProvider) + ), + mutable.Seq.empty[ast.istmt], + ifs.head.attributeProvider + ) + ) + + blockStmtNodes.append(ifNotContinueNode) + bodyNodes.foreach(blockStmtNodes.append) + + val bodyBlockNode = createBlock(blockStmtNodes, lineAndColumn) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) + + if orelseNodes.nonEmpty then + val elseBlockNode = createBlock(orelseNodes, lineAndColumn) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) + + createBlock(iterAssignNode :: 
controlStructureNode :: Nil, lineAndColumn) + end createForLowering + + def convert(astWhile: ast.While): nodes.NewNode = + val conditionNode = convert(astWhile.test) + val bodyStmtNodes = astWhile.body.map(convert) + + val controlStructureNode = + nodeBuilder.controlStructureNode( + "while ... : ...", + ControlStructureTypes.WHILE, + lineAndColOf(astWhile) + ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) + + val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astWhile)) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) + + if astWhile.orelse.nonEmpty then + val elseStmtNodes = astWhile.orelse.map(convert) + val elseBlockNode = + createBlock(elseStmtNodes, lineAndColOf(astWhile.orelse.head)) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) + + controlStructureNode + end convert + + def convert(astIf: ast.If): nodes.NewNode = + val conditionNode = convert(astIf.test) + val bodyStmtNodes = astIf.body.map(convert) + + val controlStructureNode = + nodeBuilder.controlStructureNode( + "if ... 
: ...", + ControlStructureTypes.IF, + lineAndColOf(astIf) + ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) + + val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astIf)) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) + + if astIf.orelse.nonEmpty then + val elseStmtNodes = astIf.orelse.map(convert) + val elseBlockNode = createBlock(elseStmtNodes, lineAndColOf(astIf.orelse.head)) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) + + controlStructureNode + end convert + + def convert(withStmt: ast.With): NewNode = + val loweredNodes = + withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => + mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) + } + + loweredNodes.head + + def convert(withStmt: ast.AsyncWith): NewNode = + val loweredNodes = + withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => + mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) + } + + loweredNodes.head + + // Handles the lowering of a single "with item". E.g. for: with A() as a, B() as b + // "A() as a" is a single "with item". + // The lowering for: + // with EXPRESSION as TARGET: + // SUITE + // is: + // manager = (EXPRESSION) + // enter = manager.__enter__ + // exit = manager.__exit__ + // value = enter(manager) + // + // try: + // TARGET = value + // SUITE + // finally: + // exit(manager) + // + // Note that this is not quite semantically correct because we ignore the exit method + // arguments and the exit call return value. This is fine for us because for our data + // flow tracking purposed we dont need that extra information and the AST is anyway + // "broken" because we are heavily lowering. 
+ // For reference the following excerpt taken from the official Python 3.9.5 + // documentation found at + // https://docs.python.org/3/reference/compound_stmts.html#the-with-statement + // shows the semantically correct lowering: + // manager = (EXPRESSION) + // enter = type(manager).__enter__ + // exit = type(manager).__exit__ + // value = enter(manager) + // hit_except = False + // + // try: + // TARGET = value + // SUITE + // except: + // hit_except = True + // if not exit(manager, *sys.exc_info()): + // raise + // finally: + // if not hit_except: + // exit(manager, None, None, None) + private def convertWithItem( + withItem: ast.Withitem, + suite: collection.Seq[nodes.NewNode] + ): nodes.NewNode = + val lineAndCol = lineAndColOf(withItem.context_expr) + val managerIdentifierName = getUnusedName("manager") + + val assignmentToManager = + createAssignmentToIdentifier( + managerIdentifierName, + convert(withItem.context_expr), + lineAndCol + ) + + val enterIdentifierName = getUnusedName("enter") + val assignmentToEnter = createAssignmentToIdentifier( + enterIdentifierName, + createFieldAccess( + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "__enter__", lineAndCol - ) - - loweredTargetAssignNodes ++ suite - case None => - suite - } - - // TODO For the except handler we currently lower as: - // hit_except = True - // exit(manager) - // instead of: - // hit_except = True - // if not exit(manager, *sys.exc_info()): - // raise - - val finalBlockStmts = - createInstanceCall( - createIdentifierNode("__exit__", Load, lineAndCol), - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "", - lineAndCol, - Nil, - Nil - ) :: Nil - - val tryBlock = createTry(tryBody, Nil, finalBlockStmts, Nil, lineAndCol) - - val blockStmts = mutable.ArrayBuffer.empty[nodes.NewNode] - blockStmts.append(assignmentToManager) - blockStmts.append(assignmentToEnter) - blockStmts.append(assignmentToExit) - blockStmts.append(assignmentToValue) - 
blockStmts.append(tryBlock) - - createBlock(blockStmts, lineAndCol) - } - - // TODO add case pattern and guard statements to cpg - def convert(matchStmt: ast.Match): NewNode = { - val controlStructureNode = - nodeBuilder.controlStructureNode("match ... : ...", ControlStructureTypes.SWITCH, lineAndColOf(matchStmt)) - - val matchSubject = convert(matchStmt.subject) - - val caseBlocks = matchStmt.cases.map { caseStmt => - val bodyNodes = caseStmt.body.map(convert) - createBlock(bodyNodes, lineAndColOf(caseStmt.pattern)) - } - - edgeBuilder.conditionEdge(matchSubject, controlStructureNode) - addAstChildNodes(controlStructureNode, 1, matchSubject) - addAstChildNodes(controlStructureNode, 2, caseBlocks) - - controlStructureNode - } - - def convert(raise: ast.Raise): NewNode = { - val excNodeOption = raise.exc.map(convert) - val causeNodeOption = raise.cause.map(convert) - - val args = mutable.ArrayBuffer.empty[nodes.NewNode] - args.appendAll(excNodeOption) - args.appendAll(causeNodeOption) - - val code = "raise" + - excNodeOption.map(excNode => " " + codeOf(excNode)).getOrElse("") + - causeNodeOption.map(causeNode => " from " + codeOf(causeNode)).getOrElse("") - - val callNode = nodeBuilder.callNode(code, ".raise", DispatchTypes.STATIC_DISPATCH, lineAndColOf(raise)) - - addAstChildrenAsArguments(callNode, 1, args) - - callNode - } - - def convert(tryStmt: ast.Try): NewNode = { - createTry( - tryStmt.body.map(convert), - tryStmt.handlers.map(convert), - tryStmt.finalbody.map(convert), - tryStmt.orelse.map(convert), - lineAndColOf(tryStmt) - ) - } - - def convert(assert: ast.Assert): NewNode = { - val testNode = convert(assert.test) - val msgNode = assert.msg.map(convert) - - val code = "assert " + codeOf(testNode) + msgNode.map(m => ", " + codeOf(m)).getOrElse("") - val callNode = nodeBuilder.callNode(code, ".assert", DispatchTypes.STATIC_DISPATCH, lineAndColOf(assert)) - - addAstChildrenAsArguments(callNode, 1, testNode) - if (msgNode.isDefined) { - 
addAstChildrenAsArguments(callNode, 2, msgNode) - } - callNode - } - - // Lowering of import x: - // x = import("", "x") - // Lowering of import x as y: - // y = import("", "x") - // Lowering of import x, y: - // { - // x = import("", "x") - // y = import("", "y") - // } - def convert(importStmt: ast.Import): NewNode = { - createTransformedImport("", importStmt.names, lineAndColOf(importStmt)) - } - - // Lowering of from x import y: - // y = import("x", "y") - // Lowering of from x import y as z: - // z = import("x", "y") - // Lowering of from x import y, z: - // { - // y = import("x", "y") - // z = import("x", "z") - // } - def convert(importFrom: ast.ImportFrom): NewNode = { - var moduleName = "" - - for (i <- 0 until importFrom.level) { - moduleName = moduleName.appended('.') - } - moduleName += importFrom.module.getOrElse("") - - createTransformedImport(moduleName, importFrom.names, lineAndColOf(importFrom)) - } - - def convert(global: ast.Global): NewNode = { - global.names.foreach(contextStack.addGlobalVariable) - val code = global.names.mkString("global ", ", ", "") - nodeBuilder.unknownNode(code, global.getClass.getName, lineAndColOf(global)) - } - - def convert(nonLocal: ast.Nonlocal): NewNode = { - nonLocal.names.foreach(contextStack.addNonLocalVariable) - val code = nonLocal.names.mkString("nonlocal ", ", ", "") - nodeBuilder.unknownNode(code, nonLocal.getClass.getName, lineAndColOf(nonLocal)) - } - - def convert(expr: ast.Expr): nodes.NewNode = { - convert(expr.value) - } - - def convert(pass: ast.Pass): nodes.NewNode = { - nodeBuilder.callNode("pass", ".pass", DispatchTypes.STATIC_DISPATCH, lineAndColOf(pass)) - } - - def convert(astBreak: ast.Break): nodes.NewNode = { - nodeBuilder.controlStructureNode("break", ControlStructureTypes.BREAK, lineAndColOf(astBreak)) - } - - def convert(astContinue: ast.Continue): nodes.NewNode = { - nodeBuilder.controlStructureNode("continue", ControlStructureTypes.CONTINUE, lineAndColOf(astContinue)) - } - - def 
convert(raise: ast.RaiseP2): NewNode = ??? - - def convert(errorStatement: ast.ErrorStatement): NewNode = { - nodeBuilder.unknownNode(errorStatement.toString, errorStatement.getClass.getName, lineAndColOf(errorStatement)) - } - - def convert(expr: ast.iexpr): NewNode = { - expr match { - case node: ast.BoolOp => convert(node) - case node: ast.NamedExpr => convert(node) - case node: ast.BinOp => convert(node) - case node: ast.UnaryOp => convert(node) - case node: ast.Lambda => convert(node) - case node: ast.IfExp => convert(node) - case node: ast.Dict => convert(node) - case node: ast.Set => convert(node) - case node: ast.ListComp => convert(node) - case node: ast.SetComp => convert(node) - case node: ast.DictComp => convert(node) - case node: ast.GeneratorExp => convert(node) - case node: ast.Await => convert(node) - case node: ast.Yield => unhandled(node) - case node: ast.YieldFrom => unhandled(node) - case node: ast.Compare => convert(node) - case node: ast.Call => convert(node) - case node: ast.FormattedValue => convert(node) - case node: ast.JoinedString => convert(node) - case node: ast.Constant => convert(node) - case node: ast.Attribute => convert(node) - case node: ast.Subscript => convert(node) - case node: ast.Starred => convert(node) - case node: ast.Name => convert(node) - case node: ast.List => convert(node) - case node: ast.Tuple => convert(node) - case node: ast.Slice => unhandled(node) - case node: ast.StringExpList => convert(node) - } - } - - def convert(boolOp: ast.BoolOp): nodes.NewNode = { - def boolOpToCodeAndFullName(operator: ast.iboolop): () => (String, String) = { () => - { - operator match { - case ast.And => ("and", Operators.logicalAnd) - case ast.Or => ("or", Operators.logicalOr) - } - } - } - - val operandNodes = boolOp.values.map(convert) - createNAryOperatorCall(boolOpToCodeAndFullName(boolOp.op), operandNodes, lineAndColOf(boolOp)) - } - - // TODO test - def convert(namedExpr: ast.NamedExpr): NewNode = { - val targetNode = 
convert(namedExpr.target) - val valueNode = convert(namedExpr.value) - - createAssignment(targetNode, valueNode, lineAndColOf(namedExpr)) - } - - def convert(binOp: ast.BinOp): nodes.NewNode = { - val lhsNode = convert(binOp.left) - val rhsNode = convert(binOp.right) - - val opCodeAndFullName = - binOp.op match { - case ast.Add => ("+", Operators.addition) - case ast.Sub => ("-", Operators.subtraction) - case ast.Mult => ("*", Operators.multiplication) - case ast.MatMult => - ("@", ".matMult") // TODO make this a define and add policy for this - case ast.Div => ("/", Operators.division) - case ast.Mod => ("%", Operators.modulo) - case ast.Pow => ("**", Operators.exponentiation) - case ast.LShift => ("<<", Operators.shiftLeft) - case ast.RShift => (">>", Operators.arithmeticShiftRight) - case ast.BitOr => ("|", Operators.or) - case ast.BitXor => ("^", Operators.xor) - case ast.BitAnd => ("&", Operators.and) - case ast.FloorDiv => - ("//", ".floorDiv") // TODO make this a define and add policy for this - } - - createBinaryOperatorCall(lhsNode, () => opCodeAndFullName, rhsNode, lineAndColOf(binOp)) - } - - def convert(unaryOp: ast.UnaryOp): nodes.NewNode = { - val operandNode = convert(unaryOp.operand) - - val (operatorCode, methodFullName) = - unaryOp.op match { - case ast.Invert => ("~", Operators.not) - case ast.Not => ("not ", Operators.logicalNot) - case ast.UAdd => ("+", Operators.plus) - case ast.USub => ("-", Operators.minus) - } - - val code = operatorCode + codeOf(operandNode) - val callNode = nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColOf(unaryOp)) - - addAstChildrenAsArguments(callNode, 1, operandNode) - - callNode - } - - def convert(lambda: ast.Lambda): NewNode = { - // TODO test lambda expression. 
- val lambdaCounter = contextStack.getAndIncLambdaCounter() - val lambdaNumberSuffix = - if (lambdaCounter == 0) { - "" - } else { - lambdaCounter.toString - } - - val name = "" + lambdaNumberSuffix - val (_, methodRefNode) = createMethodAndMethodRef( - name, - Some(name), - createParameterProcessingFunction(lambda.args, isStatic = false), - () => Iterable.single(convert(new ast.Return(lambda.body, lambda.attributeProvider))), - returns = None, - isAsync = false, - lineAndColOf(lambda) - ) - methodRefNode - } - - // TODO test - def convert(ifExp: ast.IfExp): NewNode = { - val bodyNode = convert(ifExp.body) - val testNode = convert(ifExp.test) - val orElseNode = convert(ifExp.orelse) - - val code = codeOf(bodyNode) + " if " + codeOf(testNode) + " else " + codeOf(orElseNode) - val callNode = nodeBuilder.callNode(code, Operators.conditional, DispatchTypes.STATIC_DISPATCH, lineAndColOf(ifExp)) - - // testNode is first argument to match semantics of Operators.conditional. - addAstChildrenAsArguments(callNode, 1, testNode, bodyNode, orElseNode) - - callNode - } - - /** Lowering of {x:1, y:2, **z}: { tmp = {} tmp[x] = 1 tmp[y] = 2 tmp.update(z) tmp } - */ - // TODO test - def convert(dict: ast.Dict): NewNode = { - val tmpVariableName = getUnusedName() - val dictOperatorCall = - createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dict)) - val dictVariableAssigNode = - createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dict)) - - val dictElementAssignNodes = dict.keys.zip(dict.values).map { case (key, value) => - key match { - case Some(key) => - val indexAccessNode = createIndexAccess( - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), - convert(key), - lineAndColOf(dict) - ) - - createAssignment(indexAccessNode, convert(value), lineAndColOf(dict)) - case None => - createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), - "update", - xMayHaveSideEffects = false, - lineAndColOf(dict), - 
convert(value) :: Nil, + ), + lineAndCol + ) + + val exitIdentifierName = getUnusedName("exit") + val assignmentToExit = createAssignmentToIdentifier( + exitIdentifierName, + createFieldAccess( + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "__exit__", + lineAndCol + ), + lineAndCol + ) + + val valueIdentifierName = getUnusedName("value") + val assignmentToValue = createAssignmentToIdentifier( + valueIdentifierName, + createInstanceCall( + createIdentifierNode(enterIdentifierName, Load, lineAndCol), + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "", + lineAndCol, + Nil, Nil - ) - } - } - - val dictInstanceReturnIdentifierNode = - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)) - - val blockElements = mutable.ArrayBuffer.empty[nodes.NewNode] - blockElements.append(dictVariableAssigNode) - blockElements.appendAll(dictElementAssignNodes) - blockElements.append(dictInstanceReturnIdentifierNode) - createBlock(blockElements, lineAndColOf(dict)) - } - - // TODO test - def convert(set: ast.Set): nodes.NewNode = { - val setElementNodes = set.elts.map(convert) - val code = setElementNodes.map(codeOf).mkString("{", ", ", "}") - - val callNode = nodeBuilder.callNode(code, ".setLiteral", DispatchTypes.STATIC_DISPATCH, lineAndColOf(set)) - - addAstChildrenAsArguments(callNode, 1, setElementNodes) - - callNode - } - - /** Lowering of [x for y in l for x in y]: { tmp = [] ( for y in l: for x in y: tmp.append(x) ) tmp } - */ - // TODO test - def convert(listComp: ast.ListComp): NewNode = { - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - // Create tmp = list() - val listOperatorCall = - createLiteralOperatorCall("[", "]", ".listLiteral", lineAndColOf(listComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, listOperatorCall, lineAndColOf(listComp)) - - // Create tmp.append(x) - val listVarAppendCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, 
Load, lineAndColOf(listComp)), - "append", - xMayHaveSideEffects = false, - lineAndColOf(listComp), - convert(listComp.elt) :: Nil, - Nil - ) + ), + lineAndCol + ) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - listVarAppendCallNode, - listComp.generators, - lineAndColOf(listComp) - ) + val tryBody = + withItem.optional_vars match + case Some(optionalVar) => + val loweredTargetAssignNodes = createValueToTargetsDecomposition( + withItem.optional_vars, + createIdentifierNode(valueIdentifierName, Load, lineAndCol), + lineAndCol + ) + + loweredTargetAssignNodes ++ suite + case None => + suite + + // TODO For the except handler we currently lower as: + // hit_except = True + // exit(manager) + // instead of: + // hit_except = True + // if not exit(manager, *sys.exc_info()): + // raise + + val finalBlockStmts = + createInstanceCall( + createIdentifierNode("__exit__", Load, lineAndCol), + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "", + lineAndCol, + Nil, + Nil + ) :: Nil + + val tryBlock = createTry(tryBody, Nil, finalBlockStmts, Nil, lineAndCol) + + val blockStmts = mutable.ArrayBuffer.empty[nodes.NewNode] + blockStmts.append(assignmentToManager) + blockStmts.append(assignmentToEnter) + blockStmts.append(assignmentToExit) + blockStmts.append(assignmentToValue) + blockStmts.append(tryBlock) + + createBlock(blockStmts, lineAndCol) + end convertWithItem + + // TODO add case pattern and guard statements to cpg + def convert(matchStmt: ast.Match): NewNode = + val controlStructureNode = + nodeBuilder.controlStructureNode( + "match ... 
: ...", + ControlStructureTypes.SWITCH, + lineAndColOf(matchStmt) + ) + + val matchSubject = convert(matchStmt.subject) + + val caseBlocks = matchStmt.cases.map { caseStmt => + val bodyNodes = caseStmt.body.map(convert) + createBlock(bodyNodes, lineAndColOf(caseStmt.pattern)) + } - contextStack.pop() - - comprehensionBlockNode - } - - /** Lowering of {x for y in l for x in y}: { tmp = {} ( for y in l: for x in y: tmp.add(x) ) tmp } - */ - // TODO test - def convert(setComp: ast.SetComp): NewNode = { - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - val setOperatorCall = - createLiteralOperatorCall("{", "}", ".setLiteral", lineAndColOf(setComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, setOperatorCall, lineAndColOf(setComp)) - - // Create tmp.add(x) - val setVarAddCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(setComp)), - "add", - xMayHaveSideEffects = false, - lineAndColOf(setComp), - convert(setComp.elt) :: Nil, - Nil - ) + edgeBuilder.conditionEdge(matchSubject, controlStructureNode) + addAstChildNodes(controlStructureNode, 1, matchSubject) + addAstChildNodes(controlStructureNode, 2, caseBlocks) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - setVarAddCallNode, - setComp.generators, - lineAndColOf(setComp) - ) + controlStructureNode + end convert - contextStack.pop() - - comprehensionBlockNode - } - - /** Lowering of {k:v for y in l for k, v in y}: { tmp = {} ( for y in l: for k, v in y: tmp[k] = v ) tmp } - */ - // TODO test - def convert(dictComp: ast.DictComp): NewNode = { - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - val dictOperatorCall = - createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dictComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dictComp)) - - // Create tmp[k] = v - val 
dictAssigNode = createAssignment( - createIndexAccess( - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dictComp)), - convert(dictComp.key), - lineAndColOf(dictComp) - ), - convert(dictComp.value), - lineAndColOf(dictComp) - ) + def convert(raise: ast.Raise): NewNode = + val excNodeOption = raise.exc.map(convert) + val causeNodeOption = raise.cause.map(convert) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - dictAssigNode, - dictComp.generators, - lineAndColOf(dictComp) - ) + val args = mutable.ArrayBuffer.empty[nodes.NewNode] + args.appendAll(excNodeOption) + args.appendAll(causeNodeOption) - contextStack.pop() - - comprehensionBlockNode - } - - /** Lowering of (x for y in l for x in y): { tmp = .genExp ( for y in l: for x in y: - * tmp.append(x) ) tmp } This lowering is not quite correct as it ignores the lazy evaluation of the generator - * expression. Instead it just mimics the list comprehension lowering but for now this is good enough. 
- */ - // TODO test - def convert(generatorExp: ast.GeneratorExp): NewNode = { - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - // Create tmp = list() - val genExpOperatorCall = - nodeBuilder.callNode( - ".genExp", - ".genExp", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(generatorExp) - ) - - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, genExpOperatorCall, lineAndColOf(generatorExp)) - - // Create tmp.append(x) - val genExpAppendCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(generatorExp)), - "append", - xMayHaveSideEffects = false, - lineAndColOf(generatorExp), - convert(generatorExp.elt) :: Nil, - Nil - ) + val code = "raise" + + excNodeOption.map(excNode => " " + codeOf(excNode)).getOrElse("") + + causeNodeOption.map(causeNode => " from " + codeOf(causeNode)).getOrElse("") - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - genExpAppendCallNode, - generatorExp.generators, - lineAndColOf(generatorExp) - ) + val callNode = nodeBuilder.callNode( + code, + ".raise", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(raise) + ) + + addAstChildrenAsArguments(callNode, 1, args) - contextStack.pop() - - comprehensionBlockNode - } - - def convert(await: ast.Await): NewNode = { - // Since the CPG format does not provide means to model async/await, - // we for now treat it as non existing. - convert(await.value) - } - - def convert(yieldExpr: ast.Yield): NewNode = ??? - - def convert(yieldFrom: ast.YieldFrom): NewNode = ??? - - // In case of a single compare operation there is no lowering applied. - // So e.g. x < y stay untouched. 
- // Otherwise the lowering is as follows: - // Src AST: - // x < y < z < a - // Lowering: - // { - // tmp1 = y - // x < tmp1 && { - // tmp2 = z - // tmp1 < tmp2 && { - // tmp2 < a - // } - // } - // } - def convert(compare: ast.Compare): NewNode = { - assert(compare.ops.size == compare.comparators.size) - var lhsNode = convert(compare.left) - - val topLevelExprNodes = - lowerComparatorChain(lhsNode, compare.ops, compare.comparators, lineAndColOf(compare)) - if (topLevelExprNodes.size > 1) { - createBlock(topLevelExprNodes, lineAndColOf(compare)) - } else { - topLevelExprNodes.head - } - } - - private def compopToOpCodeAndFullName(compareOp: ast.icompop): () => (String, String) = { () => - { - compareOp match { - case ast.Eq => ("==", Operators.equals) - case ast.NotEq => ("!=", Operators.notEquals) - case ast.Lt => ("<", Operators.lessThan) - case ast.LtE => ("<=", Operators.lessEqualsThan) - case ast.Gt => (">", Operators.greaterThan) - case ast.GtE => (">=", Operators.greaterEqualsThan) - case ast.Is => ("is", ".is") - case ast.IsNot => ("is not", ".isNot") - case ast.In => ("in", ".in") - case ast.NotIn => ("not in", ".notIn") - } - } - } - - def lowerComparatorChain( - lhsNode: nodes.NewNode, - compOperators: Iterable[ast.icompop], - comparators: Iterable[ast.iexpr], - lineAndColumn: LineAndColumn - ): Iterable[nodes.NewNode] = { - val rhsNode = convert(comparators.head) - - if (compOperators.size == 1) { - val compareNode = - createBinaryOperatorCall(lhsNode, compopToOpCodeAndFullName(compOperators.head), rhsNode, lineAndColumn) - Iterable.single(compareNode) - } else { - val tmpVariableName = getUnusedName() - val assignmentNode = createAssignmentToIdentifier(tmpVariableName, rhsNode, lineAndColumn) - - val tmpIdentifierCompare1 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val compareNode = createBinaryOperatorCall( - lhsNode, - compopToOpCodeAndFullName(compOperators.head), - tmpIdentifierCompare1, - lineAndColumn - ) - - val 
tmpIdentifierCompare2 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val childNodes = lowerComparatorChain(tmpIdentifierCompare2, compOperators.tail, comparators.tail, lineAndColumn) - - val blockNode = createBlock(childNodes, lineAndColumn) - - Iterable(assignmentNode, createBinaryOperatorCall(compareNode, andOpCodeAndFullName(), blockNode, lineAndColumn)) - } - } - - private def andOpCodeAndFullName(): () => (String, String) = { () => - ("and", Operators.logicalAnd) - } - - /** TODO For now this function compromises on the correctness of the lowering in order to get some data flow tracking - * going. - * 1. For constructs like x.func() we assume x to be the instance which is passed into func. This is not true since - * the instance method object gets the instance already bound/captured during function access. This becomes - * relevant for constructs like: x.func = y.func <- y.func is class method object x.func() In this case the - * instance passed into func is y and not x. We cannot represent this in th CPG and thus stick to the assumption - * that the part before the "." and the bound/captured instance will be the same. For reference see: - * https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy search for "Instance methods" - */ - def convert(call: ast.Call): nodes.NewNode = { - val argumentNodes = call.args.map(convert).toSeq - val keywordArgNodes = call.keywords.flatMap { keyword => - if (keyword.arg.isDefined) { - Some((keyword.arg.get, convert(keyword.value))) - } else { - // keyword.arg == None. This is the case for func(**dict) style arguments. - // TODO implement handling for this case. 
- None - } - } - - call.func match { - case attribute: ast.Attribute => - createXDotYCall( - () => convert(attribute.value), - attribute.attr, - xMayHaveSideEffects = !attribute.value.isInstanceOf[ast.Name], - lineAndColOf(call), - argumentNodes, - keywordArgNodes + callNode + end convert + + def convert(tryStmt: ast.Try): NewNode = + createTry( + tryStmt.body.map(convert), + tryStmt.handlers.map(convert), + tryStmt.finalbody.map(convert), + tryStmt.orelse.map(convert), + lineAndColOf(tryStmt) ) - case _ => - val receiverNode = convert(call.func) - val name = call.func match { - case ast.Name(id, _) => id - case _ => "" - } - createCall(receiverNode, name, lineAndColOf(call), argumentNodes, keywordArgNodes) - } - } - - def convert(formattedValue: ast.FormattedValue): nodes.NewNode = { - val valueNode = convert(formattedValue.value) - - val equalSignStr = if (formattedValue.equalSign) "=" else "" - val conversionStr = formattedValue.conversion match { - case -1 => "" - case 115 => "!s" - case 114 => "!r" - case 97 => "!a" - } - - val formatSpecStr = formattedValue.format_spec match { - case Some(formatSpec) => ":" + formatSpec - case None => "" - } - - val code = "{" + codeOf(valueNode) + equalSignStr + conversionStr + formatSpecStr + "}" - - val callNode = nodeBuilder.callNode( - code, - ".formattedValue", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(formattedValue) - ) - addAstChildrenAsArguments(callNode, 1, valueNode) + def convert(assert: ast.Assert): NewNode = + val testNode = convert(assert.test) + val msgNode = assert.msg.map(convert) - callNode - } + val code = "assert " + codeOf(testNode) + msgNode.map(m => ", " + codeOf(m)).getOrElse("") + val callNode = nodeBuilder.callNode( + code, + ".assert", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(assert) + ) - def convert(joinedString: ast.JoinedString): nodes.NewNode = { - val argumentNodes = joinedString.values.map(convert) + addAstChildrenAsArguments(callNode, 1, testNode) + if msgNode.isDefined then + 
addAstChildrenAsArguments(callNode, 2, msgNode) + callNode + + // Lowering of import x: + // x = import("", "x") + // Lowering of import x as y: + // y = import("", "x") + // Lowering of import x, y: + // { + // x = import("", "x") + // y = import("", "y") + // } + def convert(importStmt: ast.Import): NewNode = + createTransformedImport("", importStmt.names, lineAndColOf(importStmt)) + + // Lowering of from x import y: + // y = import("x", "y") + // Lowering of from x import y as z: + // z = import("x", "y") + // Lowering of from x import y, z: + // { + // y = import("x", "y") + // z = import("x", "z") + // } + def convert(importFrom: ast.ImportFrom): NewNode = + var moduleName = "" + + for i <- 0 until importFrom.level do + moduleName = moduleName.appended('.') + moduleName += importFrom.module.getOrElse("") + + createTransformedImport(moduleName, importFrom.names, lineAndColOf(importFrom)) + + def convert(global: ast.Global): NewNode = + global.names.foreach(contextStack.addGlobalVariable) + val code = global.names.mkString("global ", ", ", "") + nodeBuilder.unknownNode(code, global.getClass.getName, lineAndColOf(global)) + + def convert(nonLocal: ast.Nonlocal): NewNode = + nonLocal.names.foreach(contextStack.addNonLocalVariable) + val code = nonLocal.names.mkString("nonlocal ", ", ", "") + nodeBuilder.unknownNode(code, nonLocal.getClass.getName, lineAndColOf(nonLocal)) + + def convert(expr: ast.Expr): nodes.NewNode = + convert(expr.value) + + def convert(pass: ast.Pass): nodes.NewNode = + nodeBuilder.callNode( + "pass", + ".pass", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(pass) + ) - val code = joinedString.prefix + joinedString.quote + argumentNodes - .map(codeOf) - .mkString("") + joinedString.quote + def convert(astBreak: ast.Break): nodes.NewNode = + nodeBuilder.controlStructureNode( + "break", + ControlStructureTypes.BREAK, + lineAndColOf(astBreak) + ) - val callNode = - nodeBuilder.callNode(code, ".formatString", DispatchTypes.STATIC_DISPATCH, 
lineAndColOf(joinedString)) + def convert(astContinue: ast.Continue): nodes.NewNode = + nodeBuilder.controlStructureNode( + "continue", + ControlStructureTypes.CONTINUE, + lineAndColOf(astContinue) + ) - addAstChildrenAsArguments(callNode, 1, argumentNodes) + def convert(raise: ast.RaiseP2): NewNode = ??? - callNode - } + def convert(errorStatement: ast.ErrorStatement): NewNode = + nodeBuilder.unknownNode( + errorStatement.toString, + errorStatement.getClass.getName, + lineAndColOf(errorStatement) + ) - def convert(constant: ast.Constant): nodes.NewNode = { - constant.value match { - case stringConstant: ast.StringConstant => - nodeBuilder.stringLiteralNode( - stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote, - lineAndColOf(constant) + def convert(expr: ast.iexpr): NewNode = + expr match + case node: ast.BoolOp => convert(node) + case node: ast.NamedExpr => convert(node) + case node: ast.BinOp => convert(node) + case node: ast.UnaryOp => convert(node) + case node: ast.Lambda => convert(node) + case node: ast.IfExp => convert(node) + case node: ast.Dict => convert(node) + case node: ast.Set => convert(node) + case node: ast.ListComp => convert(node) + case node: ast.SetComp => convert(node) + case node: ast.DictComp => convert(node) + case node: ast.GeneratorExp => convert(node) + case node: ast.Await => convert(node) + case node: ast.Yield => unhandled(node) + case node: ast.YieldFrom => unhandled(node) + case node: ast.Compare => convert(node) + case node: ast.Call => convert(node) + case node: ast.FormattedValue => convert(node) + case node: ast.JoinedString => convert(node) + case node: ast.Constant => convert(node) + case node: ast.Attribute => convert(node) + case node: ast.Subscript => convert(node) + case node: ast.Starred => convert(node) + case node: ast.Name => convert(node) + case node: ast.List => convert(node) + case node: ast.Tuple => convert(node) + case node: ast.Slice => unhandled(node) + case node: 
ast.StringExpList => convert(node) + + def convert(boolOp: ast.BoolOp): nodes.NewNode = + def boolOpToCodeAndFullName(operator: ast.iboolop): () => (String, String) = () => + operator match + case ast.And => ("and", Operators.logicalAnd) + case ast.Or => ("or", Operators.logicalOr) + + val operandNodes = boolOp.values.map(convert) + createNAryOperatorCall( + boolOpToCodeAndFullName(boolOp.op), + operandNodes, + lineAndColOf(boolOp) ) - case stringConstant: ast.JoinedStringConstant => - nodeBuilder.stringLiteralNode(stringConstant.value, lineAndColOf(constant)) - case boolConstant: ast.BoolConstant => - val boolStr = if (boolConstant.value) "True" else "False" - nodeBuilder.stringLiteralNode(boolStr, lineAndColOf(constant)) - case intConstant: ast.IntConstant => - nodeBuilder.numberLiteralNode(intConstant.value, lineAndColOf(constant)) - case floatConstant: ast.FloatConstant => - nodeBuilder.numberLiteralNode(floatConstant.value, lineAndColOf(constant)) - case imaginaryConstant: ast.ImaginaryConstant => - nodeBuilder.numberLiteralNode(imaginaryConstant.value + "j", lineAndColOf(constant)) - case ast.NoneConstant => - nodeBuilder.numberLiteralNode("None", lineAndColOf(constant)) - case ast.EllipsisConstant => - nodeBuilder.numberLiteralNode("...", lineAndColOf(constant)) - } - } - - /** TODO We currently ignore possible attribute access provider/interception mechanisms like __getattr__, - * __getattribute__ and __get__. 
- */ - def convert(attribute: ast.Attribute): nodes.NewNode = { - val baseNode = convert(attribute.value) - val fieldName = attribute.attr - val lineAndCol = lineAndColOf(attribute) - - val fieldAccess = createFieldAccess(baseNode, fieldName, lineAndCol) - - attribute.value match { - case name: ast.Name if name.id == "self" => - createAndRegisterMember(fieldName, lineAndCol) - case _ => - } - - fieldAccess - } - - private def createAndRegisterMember(name: String, lineAndCol: LineAndColumn): Unit = { - contextStack.findEnclosingTypeDecl() match { - case Some(typeDecl: NewTypeDecl) => - if (!members.contains(typeDecl) || !members(typeDecl).contains(name)) { - val member = nodeBuilder.memberNode(name, lineAndCol) - edgeBuilder.astEdge(member, typeDecl, contextStack.order.getAndInc) - members(typeDecl) = members.getOrElse(typeDecl, List()) ++ List(name) + + // TODO test + def convert(namedExpr: ast.NamedExpr): NewNode = + val targetNode = convert(namedExpr.target) + val valueNode = convert(namedExpr.value) + + createAssignment(targetNode, valueNode, lineAndColOf(namedExpr)) + + def convert(binOp: ast.BinOp): nodes.NewNode = + val lhsNode = convert(binOp.left) + val rhsNode = convert(binOp.right) + + val opCodeAndFullName = + binOp.op match + case ast.Add => ("+", Operators.addition) + case ast.Sub => ("-", Operators.subtraction) + case ast.Mult => ("*", Operators.multiplication) + case ast.MatMult => + ("@", ".matMult") // TODO make this a define and add policy for this + case ast.Div => ("/", Operators.division) + case ast.Mod => ("%", Operators.modulo) + case ast.Pow => ("**", Operators.exponentiation) + case ast.LShift => ("<<", Operators.shiftLeft) + case ast.RShift => (">>", Operators.arithmeticShiftRight) + case ast.BitOr => ("|", Operators.or) + case ast.BitXor => ("^", Operators.xor) + case ast.BitAnd => ("&", Operators.and) + case ast.FloorDiv => + ("//", ".floorDiv") // TODO make this a define and add policy for this + + createBinaryOperatorCall(lhsNode, () 
=> opCodeAndFullName, rhsNode, lineAndColOf(binOp)) + end convert + + def convert(unaryOp: ast.UnaryOp): nodes.NewNode = + val operandNode = convert(unaryOp.operand) + + val (operatorCode, methodFullName) = + unaryOp.op match + case ast.Invert => ("~", Operators.not) + case ast.Not => ("not ", Operators.logicalNot) + case ast.UAdd => ("+", Operators.plus) + case ast.USub => ("-", Operators.minus) + + val code = operatorCode + codeOf(operandNode) + val callNode = nodeBuilder.callNode( + code, + methodFullName, + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(unaryOp) + ) + + addAstChildrenAsArguments(callNode, 1, operandNode) + + callNode + end convert + + def convert(lambda: ast.Lambda): NewNode = + // TODO test lambda expression. + val lambdaCounter = contextStack.getAndIncLambdaCounter() + val lambdaNumberSuffix = + if lambdaCounter == 0 then + "" + else + lambdaCounter.toString + + val name = "" + lambdaNumberSuffix + val (_, methodRefNode) = createMethodAndMethodRef( + name, + Some(name), + createParameterProcessingFunction(lambda.args, isStatic = false), + () => Iterable.single(convert(new ast.Return(lambda.body, lambda.attributeProvider))), + returns = None, + isAsync = false, + lineAndColOf(lambda) + ) + methodRefNode + end convert + + // TODO test + def convert(ifExp: ast.IfExp): NewNode = + val bodyNode = convert(ifExp.body) + val testNode = convert(ifExp.test) + val orElseNode = convert(ifExp.orelse) + + val code = codeOf(bodyNode) + " if " + codeOf(testNode) + " else " + codeOf(orElseNode) + val callNode = nodeBuilder.callNode( + code, + Operators.conditional, + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(ifExp) + ) + + // testNode is first argument to match semantics of Operators.conditional. 
+ addAstChildrenAsArguments(callNode, 1, testNode, bodyNode, orElseNode) + + callNode + + /** Lowering of {x:1, y:2, **z}: { tmp = {} tmp[x] = 1 tmp[y] = 2 tmp.update(z) tmp } + */ + // TODO test + def convert(dict: ast.Dict): NewNode = + val tmpVariableName = getUnusedName() + val dictOperatorCall = + createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dict)) + val dictVariableAssigNode = + createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dict)) + + val dictElementAssignNodes = dict.keys.zip(dict.values).map { case (key, value) => + key match + case Some(key) => + val indexAccessNode = createIndexAccess( + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), + convert(key), + lineAndColOf(dict) + ) + + createAssignment(indexAccessNode, convert(value), lineAndColOf(dict)) + case None => + createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), + "update", + xMayHaveSideEffects = false, + lineAndColOf(dict), + convert(value) :: Nil, + Nil + ) } - case _ => - } - } - - def convert(subscript: ast.Subscript): NewNode = { - createIndexAccess(convert(subscript.value), convert(subscript.slice), lineAndColOf(subscript)) - } - - def convert(starred: ast.Starred): NewNode = { - val memoryOperation = memOpMap.get(starred).get - - memoryOperation match { - case Load => - val unrollOperand = convert(starred.value) - createStarredUnpackOperatorCall(unrollOperand, lineAndColOf(starred)) - case Store => - unhandled(starred) - case Del => - // This case is not possible since star operator is not allowed in delete statement. 
- unhandled(starred) - } - } - - def convert(name: ast.Name): nodes.NewNode = { - val memoryOperation = memOpMap.get(name).get - val identifier = createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) - if (contextStack.isClassContext && memoryOperation == Store) { - createAndRegisterMember(identifier.name, lineAndColOf(name)) - } - identifier - } - - // TODO test - def convert(list: ast.List): nodes.NewNode = { - // Must be a List as part of a Load memory operation because a List literal - // is not permitted as argument to a Del and List as part of a Store does not - // reach here. - assert(memOpMap.get(list).get == Load) - val listElementNodes = list.elts.map(convert) - val code = listElementNodes.map(codeOf).mkString("[", ", ", "]") - - val callNode = - nodeBuilder.callNode(code, ".listLiteral", DispatchTypes.STATIC_DISPATCH, lineAndColOf(list)) - - addAstChildrenAsArguments(callNode, 1, listElementNodes) - - callNode - } - - // TODO test - def convert(tuple: ast.Tuple): NewNode = { - // Must be a tuple as part of a Load or Del memory operation because Tuples in - // store contexts are not supposed to reach here. They need to be lowered by - // createValueToTargetsDecomposition. - assert(memOpMap.get(tuple).get == Load || memOpMap.get(tuple).get == Del) - val tupleElementNodes = tuple.elts.map(convert) - val code = if (tupleElementNodes.size != 1) { - tupleElementNodes.map(codeOf).mkString("(", ", ", ")") - } else { - "(" + codeOf(tupleElementNodes.head) + ",)" - } - - val callNode = - nodeBuilder.callNode(code, ".tupleLiteral", DispatchTypes.STATIC_DISPATCH, lineAndColOf(tuple)) - - addAstChildrenAsArguments(callNode, 1, tupleElementNodes) - - callNode - } - - def convert(slice: ast.Slice): NewNode = ??? 
- - def convert(stringExpList: ast.StringExpList): NewNode = { - val stringNodes = stringExpList.elts.map(convert) - val code = stringNodes.map(codeOf).mkString(" ") - - val callNode = nodeBuilder.callNode( - code, - ".stringExpressionList", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(stringExpList) - ) - addAstChildrenAsArguments(callNode, 1, stringNodes) - - callNode - } - - // TODO Since there is now real concept of reflecting exception handlers - // semantically in the CPG we just make sure that the variable scoping - // is right and that we convert the exception handler body. - // TODO tests - def convert(exceptHandler: ast.ExceptHandler): NewNode = { - contextStack.pushSpecialContext() - val specialTargetLocals = mutable.ArrayBuffer.empty[nodes.NewLocal] - if (exceptHandler.name.isDefined) { - val localNode = nodeBuilder.localNode(exceptHandler.name.get, None) - specialTargetLocals.append(localNode) - contextStack.addSpecialVariable(localNode) - } - - val blockNode = createBlock(exceptHandler.body.map(convert), lineAndColOf(exceptHandler)) - addAstChildNodes(blockNode, 1, specialTargetLocals) - - contextStack.pop() - - blockNode - } - - def convert(parameters: ast.Arguments, startIndex: Int): Iterable[nodes.NewMethodParameterIn] = { - val autoIncIndex = new AutoIncIndex(startIndex) - - parameters.posonlyargs.map(convertPosOnlyArg(_, autoIncIndex)) ++ - parameters.args.map(convertNormalArg(_, autoIncIndex)) ++ - parameters.vararg.map(convertVarArg(_, autoIncIndex)) ++ - parameters.kwonlyargs.map(convertKeywordOnlyArg) ++ - parameters.kw_arg.map(convertKwArg) - } - - // TODO for now the different arg convert functions are all the same but - // will all be slightly different in the future when we can represent the - // different types in the cpg. 
- def convertPosOnlyArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = { - nodeBuilder.methodParameterNode( - arg.arg, - isVariadic = false, - lineAndColOf(arg), - Option(index.getAndInc), - arg.annotation - ) - } - - def convertNormalArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = { - nodeBuilder.methodParameterNode( - arg.arg, - isVariadic = false, - lineAndColOf(arg), - Option(index.getAndInc), - arg.annotation - ) - } + val dictInstanceReturnIdentifierNode = + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)) + + val blockElements = mutable.ArrayBuffer.empty[nodes.NewNode] + blockElements.append(dictVariableAssigNode) + blockElements.appendAll(dictElementAssignNodes) + blockElements.append(dictInstanceReturnIdentifierNode) + createBlock(blockElements, lineAndColOf(dict)) + end convert + + // TODO test + def convert(set: ast.Set): nodes.NewNode = + val setElementNodes = set.elts.map(convert) + val code = setElementNodes.map(codeOf).mkString("{", ", ", "}") + + val callNode = nodeBuilder.callNode( + code, + ".setLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(set) + ) - def convertVarArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = { - nodeBuilder.methodParameterNode(arg.arg, isVariadic = true, lineAndColOf(arg), Option(index.getAndInc)) - } + addAstChildrenAsArguments(callNode, 1, setElementNodes) - def convertKeywordOnlyArg(arg: ast.Arg): nodes.NewMethodParameterIn = { - nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) - } + callNode - def convertKwArg(arg: ast.Arg): nodes.NewMethodParameterIn = { - nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) - } + /** Lowering of [x for y in l for x in y]: { tmp = [] ( for y in l: for x in y: + * tmp.append(x) ) tmp } + */ + // TODO test + def convert(listComp: ast.ListComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() - def 
convert(keyword: ast.Keyword): NewNode = ??? + // Create tmp = list() + val listOperatorCall = + createLiteralOperatorCall("[", "]", ".listLiteral", lineAndColOf(listComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, listOperatorCall, lineAndColOf(listComp)) - def convert(alias: ast.Alias): NewNode = ??? + // Create tmp.append(x) + val listVarAppendCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(listComp)), + "append", + xMayHaveSideEffects = false, + lineAndColOf(listComp), + convert(listComp.elt) :: Nil, + Nil + ) - def convert(typeIgnore: ast.TypeIgnore): NewNode = ??? + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + listVarAppendCallNode, + listComp.generators, + lineAndColOf(listComp) + ) - private def calculateFullNameFromContext(name: String): String = { - val contextQualName = contextStack.qualName - if (contextQualName != "") { - relFileName + ":" + contextQualName + "." + name - } else { - relFileName + ":" + name - } - } -} + contextStack.pop() -object PythonAstVisitor { - val builtinPrefix = "__builtin." - val typingPrefix = "typing." - val metaClassSuffix = "" - - // This list contains all functions from https://docs.python.org/3/library/functions.html#built-in-funcs - // for python version 3.9.5. - // There is a corresponding list in policies which needs to be updated if this one is updated and vice versa. 
- val builtinFunctionsV3: Iterable[String] = Iterable( - "abs", - "aiter", - "all", - "anext", - "any", - "ascii", - "bin", - "breakpoint", - "callable", - "chr", - "classmethod", - "compile", - "delattr", - "dir", - "divmod", - "enumerate", - "eval", - "exec", - "filter", - "format", - "getattr", - "globals", - "hasattr", - "hash", - "help", - "hex", - "id", - "input", - "isinstance", - "issubclass", - "iter", - "len", - "locals", - "map", - "max", - "memoryview", - "min", - "next", - "oct", - "open", - "ord", - "pow", - "print", - "repr", - "reversed", - "round", - "setattr", - "sorted", - "staticmethod", - "sum", - "super", - "vars", - "zip", - "__import__" - ) - // This list contains all classes from https://docs.python.org/3/library/functions.html#built-in-funcs - // for python version 3.9.5. - val builtinClassesV3: Iterable[String] = Iterable( - "bool", - "bytearray", - "bytes", - "complex", - "dict", - "float", - "frozenset", - "int", - "list", - "memoryview", - "object", - "property", - "range", - "set", - "slice", - "str", - "tuple", - "type" - ) - // This list contains all functions from https://docs.python.org/2.7/library/functions.html - val builtinFunctionsV2: Iterable[String] = Iterable( - "abs", - "all", - "any", - "bin", - "callable", - "chr", - "classmethod", - "cmp", - "compile", - "delattr", - "dir", - "divmod", - "enumerate", - "eval", - // This one is special because it is not from the above mentioned list. - // This is because exec is a statement type in V2 but our parser provides - // it to us as a normal call so that we can model it as builtin. 
- "exec", - "execfile", - "filter", - "format", - "getattr", - "globals", - "hasattr", - "hash", - "help", - "hex", - "id", - "input", - "isinstance", - "issubclass", - "iter", - "len", - "locals", - "map", - "max", - "min", - "next", - "oct", - "open", - "ord", - "pow", - "print", - "range", - "raw_input", - "reduce", - "reload", - "repr", - "reversed", - "round", - "setattr", - "sorted", - "staticmethod", - "sum", - "super", - "unichr", - "vars", - "zip", - "__import__" - ) - // This list contains all classes from https://docs.python.org/2.7/library/functions.html - val builtinClassesV2: Iterable[String] = Iterable( - "bool", - "bytearray", - "complex", - "dict", - "file", - "float", - "frozenset", - "int", - "list", - "long", - "memoryview", - "object", - "property", - "set", - "slice", - "str", - "tuple", - "type", - "unicode", - "xrange" - ) - - lazy val allBuiltinClasses: Set[String] = (builtinClassesV2 ++ builtinClassesV3).toSet - - lazy val typingClassesV3: Set[String] = Set( - "Annotated", - "Any", - "Callable", - "ClassVar", - "Final", - "ForwardRef", - "Generic", - "Literal", - "Optional", - "Protocol", - "Tuple", - "Type", - "TypeVar", - "Union", - "AbstractSet", - "ByteString", - "Container", - "ContextManager", - "Hashable", - "ItemsView", - "Iterable", - "Iterator", - "KeysView", - "Mapping", - "MappingView", - "MutableMapping", - "MutableSequence", - "MutableSet", - "Sequence", - "Sized", - "ValuesView", - "Awaitable", - "AsyncIterator", - "AsyncIterable", - "Coroutine", - "Collection", - "AsyncGenerator", - "AsyncContextManager", - "Reversible", - "SupportsAbs", - "SupportsBytes", - "SupportsComplex", - "SupportsFloat", - "SupportsIndex", - "SupportsInt", - "SupportsRound", - "ChainMap", - "Counter", - "Deque", - "Dict", - "DefaultDict", - "List", - "OrderedDict", - "Set", - "FrozenSet", - "NamedTuple", - "TypedDict", - "Generator", - "BinaryIO", - "IO", - "Match", - "Pattern", - "TextIO", - "AnyStr", - "cast", - "final", - "get_args", - 
"get_origin", - "get_type_hints", - "NewType", - "no_type_check", - "no_type_check_decorator", - "NoReturn", - "overload", - "runtime_checkable", - "Text", - "TYPE_CHECKING" - ) -} + comprehensionBlockNode + end convert + + /** Lowering of {x for y in l for x in y}: { tmp = {} ( for y in l: for x in y: + * tmp.add(x) ) tmp } + */ + // TODO test + def convert(setComp: ast.SetComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + val setOperatorCall = + createLiteralOperatorCall("{", "}", ".setLiteral", lineAndColOf(setComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, setOperatorCall, lineAndColOf(setComp)) + + // Create tmp.add(x) + val setVarAddCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(setComp)), + "add", + xMayHaveSideEffects = false, + lineAndColOf(setComp), + convert(setComp.elt) :: Nil, + Nil + ) + + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + setVarAddCallNode, + setComp.generators, + lineAndColOf(setComp) + ) + + contextStack.pop() + + comprehensionBlockNode + end convert + + /** Lowering of {k:v for y in l for k, v in y}: { tmp = {} ( for y in l: for k, v in + * y: tmp[k] = v ) tmp } + */ + // TODO test + def convert(dictComp: ast.DictComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + val dictOperatorCall = + createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dictComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dictComp)) + + // Create tmp[k] = v + val dictAssigNode = createAssignment( + createIndexAccess( + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dictComp)), + convert(dictComp.key), + lineAndColOf(dictComp) + ), + convert(dictComp.value), + lineAndColOf(dictComp) + ) + + val comprehensionBlockNode = createComprehensionLowering( + 
tmpVariableName, + variableAssignNode, + dictAssigNode, + dictComp.generators, + lineAndColOf(dictComp) + ) + + contextStack.pop() + + comprehensionBlockNode + end convert + + /** Lowering of (x for y in l for x in y): { tmp = .genExp ( for y in l: + * for x in y: tmp.append(x) ) tmp } This lowering is not quite correct as it ignores the lazy + * evaluation of the generator expression. Instead it just mimics the list comprehension + * lowering but for now this is good enough. + */ + // TODO test + def convert(generatorExp: ast.GeneratorExp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + // Create tmp = list() + val genExpOperatorCall = + nodeBuilder.callNode( + ".genExp", + ".genExp", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(generatorExp) + ) + + val variableAssignNode = + createAssignmentToIdentifier( + tmpVariableName, + genExpOperatorCall, + lineAndColOf(generatorExp) + ) + + // Create tmp.append(x) + val genExpAppendCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(generatorExp)), + "append", + xMayHaveSideEffects = false, + lineAndColOf(generatorExp), + convert(generatorExp.elt) :: Nil, + Nil + ) + + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + genExpAppendCallNode, + generatorExp.generators, + lineAndColOf(generatorExp) + ) + + contextStack.pop() + + comprehensionBlockNode + end convert + + def convert(await: ast.Await): NewNode = + // Since the CPG format does not provide means to model async/await, + // we for now treat it as non existing. + convert(await.value) + + def convert(yieldExpr: ast.Yield): NewNode = ??? + + def convert(yieldFrom: ast.YieldFrom): NewNode = ??? + + // In case of a single compare operation there is no lowering applied. + // So e.g. x < y stay untouched. 
+ // Otherwise the lowering is as follows: + // Src AST: + // x < y < z < a + // Lowering: + // { + // tmp1 = y + // x < tmp1 && { + // tmp2 = z + // tmp1 < tmp2 && { + // tmp2 < a + // } + // } + // } + def convert(compare: ast.Compare): NewNode = + assert(compare.ops.size == compare.comparators.size) + var lhsNode = convert(compare.left) + + val topLevelExprNodes = + lowerComparatorChain(lhsNode, compare.ops, compare.comparators, lineAndColOf(compare)) + if topLevelExprNodes.size > 1 then + createBlock(topLevelExprNodes, lineAndColOf(compare)) + else + topLevelExprNodes.head + + private def compopToOpCodeAndFullName(compareOp: ast.icompop): () => (String, String) = () => + compareOp match + case ast.Eq => ("==", Operators.equals) + case ast.NotEq => ("!=", Operators.notEquals) + case ast.Lt => ("<", Operators.lessThan) + case ast.LtE => ("<=", Operators.lessEqualsThan) + case ast.Gt => (">", Operators.greaterThan) + case ast.GtE => (">=", Operators.greaterEqualsThan) + case ast.Is => ("is", ".is") + case ast.IsNot => ("is not", ".isNot") + case ast.In => ("in", ".in") + case ast.NotIn => ("not in", ".notIn") + + def lowerComparatorChain( + lhsNode: nodes.NewNode, + compOperators: Iterable[ast.icompop], + comparators: Iterable[ast.iexpr], + lineAndColumn: LineAndColumn + ): Iterable[nodes.NewNode] = + val rhsNode = convert(comparators.head) + + if compOperators.size == 1 then + val compareNode = + createBinaryOperatorCall( + lhsNode, + compopToOpCodeAndFullName(compOperators.head), + rhsNode, + lineAndColumn + ) + Iterable.single(compareNode) + else + val tmpVariableName = getUnusedName() + val assignmentNode = + createAssignmentToIdentifier(tmpVariableName, rhsNode, lineAndColumn) + + val tmpIdentifierCompare1 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val compareNode = createBinaryOperatorCall( + lhsNode, + compopToOpCodeAndFullName(compOperators.head), + tmpIdentifierCompare1, + lineAndColumn + ) + + val tmpIdentifierCompare2 = 
createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val childNodes = lowerComparatorChain( + tmpIdentifierCompare2, + compOperators.tail, + comparators.tail, + lineAndColumn + ) + + val blockNode = createBlock(childNodes, lineAndColumn) + + Iterable( + assignmentNode, + createBinaryOperatorCall( + compareNode, + andOpCodeAndFullName(), + blockNode, + lineAndColumn + ) + ) + end if + end lowerComparatorChain + + private def andOpCodeAndFullName(): () => (String, String) = () => + ("and", Operators.logicalAnd) + + /** TODO For now this function compromises on the correctness of the lowering in order to + * get some data flow tracking going. + * 1. For constructs like x.func() we assume x to be the instance which is passed into + * func. This is not true since the instance method object gets the instance already + * bound/captured during function access. This becomes relevant for constructs like: + * x.func = y.func <- y.func is class method object x.func() In this case the instance + * passed into func is y and not x. We cannot represent this in th CPG and thus stick + * to the assumption that the part before the "." and the bound/captured instance will + * be the same. For reference see: + * https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy + * search for "Instance methods" + */ + def convert(call: ast.Call): nodes.NewNode = + val argumentNodes = call.args.map(convert).toSeq + val keywordArgNodes = call.keywords.flatMap { keyword => + if keyword.arg.isDefined then + Some((keyword.arg.get, convert(keyword.value))) + else + // keyword.arg == None. This is the case for func(**dict) style arguments. + // TODO implement handling for this case. 
+ None + } + + call.func match + case attribute: ast.Attribute => + createXDotYCall( + () => convert(attribute.value), + attribute.attr, + xMayHaveSideEffects = !attribute.value.isInstanceOf[ast.Name], + lineAndColOf(call), + argumentNodes, + keywordArgNodes + ) + case _ => + val receiverNode = convert(call.func) + val name = call.func match + case ast.Name(id, _) => id + case _ => "" + createCall(receiverNode, name, lineAndColOf(call), argumentNodes, keywordArgNodes) + end convert + + def convert(formattedValue: ast.FormattedValue): nodes.NewNode = + val valueNode = convert(formattedValue.value) + + val equalSignStr = if formattedValue.equalSign then "=" else "" + val conversionStr = formattedValue.conversion match + case -1 => "" + case 115 => "!s" + case 114 => "!r" + case 97 => "!a" + + val formatSpecStr = formattedValue.format_spec match + case Some(formatSpec) => ":" + formatSpec + case None => "" + + val code = "{" + codeOf(valueNode) + equalSignStr + conversionStr + formatSpecStr + "}" + + val callNode = nodeBuilder.callNode( + code, + ".formattedValue", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(formattedValue) + ) + + addAstChildrenAsArguments(callNode, 1, valueNode) + + callNode + end convert + + def convert(joinedString: ast.JoinedString): nodes.NewNode = + val argumentNodes = joinedString.values.map(convert) + + val code = joinedString.prefix + joinedString.quote + argumentNodes + .map(codeOf) + .mkString("") + joinedString.quote + + val callNode = + nodeBuilder.callNode( + code, + ".formatString", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(joinedString) + ) + + addAstChildrenAsArguments(callNode, 1, argumentNodes) + + callNode + + def convert(constant: ast.Constant): nodes.NewNode = + constant.value match + case stringConstant: ast.StringConstant => + nodeBuilder.stringLiteralNode( + stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote, + lineAndColOf(constant) + ) + case stringConstant: 
ast.JoinedStringConstant => + nodeBuilder.stringLiteralNode(stringConstant.value, lineAndColOf(constant)) + case boolConstant: ast.BoolConstant => + val boolStr = if boolConstant.value then "True" else "False" + nodeBuilder.stringLiteralNode(boolStr, lineAndColOf(constant)) + case intConstant: ast.IntConstant => + nodeBuilder.numberLiteralNode(intConstant.value, lineAndColOf(constant)) + case floatConstant: ast.FloatConstant => + nodeBuilder.numberLiteralNode(floatConstant.value, lineAndColOf(constant)) + case imaginaryConstant: ast.ImaginaryConstant => + nodeBuilder.numberLiteralNode(imaginaryConstant.value + "j", lineAndColOf(constant)) + case ast.NoneConstant => + nodeBuilder.numberLiteralNode("None", lineAndColOf(constant)) + case ast.EllipsisConstant => + nodeBuilder.numberLiteralNode("...", lineAndColOf(constant)) + + /** TODO We currently ignore possible attribute access provider/interception mechanisms like + * __getattr__, __getattribute__ and __get__. + */ + def convert(attribute: ast.Attribute): nodes.NewNode = + val baseNode = convert(attribute.value) + val fieldName = attribute.attr + val lineAndCol = lineAndColOf(attribute) + + val fieldAccess = createFieldAccess(baseNode, fieldName, lineAndCol) + + attribute.value match + case name: ast.Name if name.id == "self" => + createAndRegisterMember(fieldName, lineAndCol) + case _ => + + fieldAccess + + private def createAndRegisterMember(name: String, lineAndCol: LineAndColumn): Unit = + contextStack.findEnclosingTypeDecl() match + case Some(typeDecl: NewTypeDecl) => + if !members.contains(typeDecl) || !members(typeDecl).contains(name) then + val member = nodeBuilder.memberNode(name, lineAndCol) + edgeBuilder.astEdge(member, typeDecl, contextStack.order.getAndInc) + members(typeDecl) = members.getOrElse(typeDecl, List()) ++ List(name) + case _ => + + def convert(subscript: ast.Subscript): NewNode = + createIndexAccess( + convert(subscript.value), + convert(subscript.slice), + lineAndColOf(subscript) + ) + + 
def convert(starred: ast.Starred): NewNode = + val memoryOperation = memOpMap.get(starred).get + + memoryOperation match + case Load => + val unrollOperand = convert(starred.value) + createStarredUnpackOperatorCall(unrollOperand, lineAndColOf(starred)) + case Store => + unhandled(starred) + case Del => + // This case is not possible since star operator is not allowed in delete statement. + unhandled(starred) + + def convert(name: ast.Name): nodes.NewNode = + val memoryOperation = memOpMap.get(name).get + val identifier = createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) + if contextStack.isClassContext && memoryOperation == Store then + createAndRegisterMember(identifier.name, lineAndColOf(name)) + identifier + + // TODO test + def convert(list: ast.List): nodes.NewNode = + // Must be a List as part of a Load memory operation because a List literal + // is not permitted as argument to a Del and List as part of a Store does not + // reach here. + assert(memOpMap.get(list).get == Load) + val listElementNodes = list.elts.map(convert) + val code = listElementNodes.map(codeOf).mkString("[", ", ", "]") + + val callNode = + nodeBuilder.callNode( + code, + ".listLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(list) + ) + + addAstChildrenAsArguments(callNode, 1, listElementNodes) + + callNode + + // TODO test + def convert(tuple: ast.Tuple): NewNode = + // Must be a tuple as part of a Load or Del memory operation because Tuples in + // store contexts are not supposed to reach here. They need to be lowered by + // createValueToTargetsDecomposition. 
+ assert(memOpMap.get(tuple).get == Load || memOpMap.get(tuple).get == Del) + val tupleElementNodes = tuple.elts.map(convert) + val code = if tupleElementNodes.size != 1 then + tupleElementNodes.map(codeOf).mkString("(", ", ", ")") + else + "(" + codeOf(tupleElementNodes.head) + ",)" + + val callNode = + nodeBuilder.callNode( + code, + ".tupleLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(tuple) + ) + + addAstChildrenAsArguments(callNode, 1, tupleElementNodes) + + callNode + end convert + + def convert(slice: ast.Slice): NewNode = ??? + + def convert(stringExpList: ast.StringExpList): NewNode = + val stringNodes = stringExpList.elts.map(convert) + val code = stringNodes.map(codeOf).mkString(" ") + + val callNode = nodeBuilder.callNode( + code, + ".stringExpressionList", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(stringExpList) + ) + + addAstChildrenAsArguments(callNode, 1, stringNodes) + + callNode + + // TODO Since there is now real concept of reflecting exception handlers + // semantically in the CPG we just make sure that the variable scoping + // is right and that we convert the exception handler body. 
+ // TODO tests + def convert(exceptHandler: ast.ExceptHandler): NewNode = + contextStack.pushSpecialContext() + val specialTargetLocals = mutable.ArrayBuffer.empty[nodes.NewLocal] + if exceptHandler.name.isDefined then + val localNode = nodeBuilder.localNode(exceptHandler.name.get, None) + specialTargetLocals.append(localNode) + contextStack.addSpecialVariable(localNode) + + val blockNode = createBlock(exceptHandler.body.map(convert), lineAndColOf(exceptHandler)) + addAstChildNodes(blockNode, 1, specialTargetLocals) + + contextStack.pop() + + blockNode + + def convert(parameters: ast.Arguments, startIndex: Int): Iterable[nodes.NewMethodParameterIn] = + val autoIncIndex = new AutoIncIndex(startIndex) + + parameters.posonlyargs.map(convertPosOnlyArg(_, autoIncIndex)) ++ + parameters.args.map(convertNormalArg(_, autoIncIndex)) ++ + parameters.vararg.map(convertVarArg(_, autoIncIndex)) ++ + parameters.kwonlyargs.map(convertKeywordOnlyArg) ++ + parameters.kw_arg.map(convertKwArg) + + // TODO for now the different arg convert functions are all the same but + // will all be slightly different in the future when we can represent the + // different types in the cpg. 
+ def convertPosOnlyArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = false, + lineAndColOf(arg), + Option(index.getAndInc), + arg.annotation + ) + + def convertNormalArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = false, + lineAndColOf(arg), + Option(index.getAndInc), + arg.annotation + ) + + def convertVarArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = true, + lineAndColOf(arg), + Option(index.getAndInc) + ) + + def convertKeywordOnlyArg(arg: ast.Arg): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + + def convertKwArg(arg: ast.Arg): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + + def convert(keyword: ast.Keyword): NewNode = ??? + + def convert(alias: ast.Alias): NewNode = ??? + + def convert(typeIgnore: ast.TypeIgnore): NewNode = ??? + + private def calculateFullNameFromContext(name: String): String = + val contextQualName = contextStack.qualName + if contextQualName != "" then + relFileName + ":" + contextQualName + "." + name + else + relFileName + ":" + name +end PythonAstVisitor + +object PythonAstVisitor: + val builtinPrefix = "__builtin." + val typingPrefix = "typing." + val metaClassSuffix = "" + + // This list contains all functions from https://docs.python.org/3/library/functions.html#built-in-funcs + // for python version 3.9.5. + // There is a corresponding list in policies which needs to be updated if this one is updated and vice versa. 
+ val builtinFunctionsV3: Iterable[String] = Iterable( + "abs", + "aiter", + "all", + "anext", + "any", + "ascii", + "bin", + "breakpoint", + "callable", + "chr", + "classmethod", + "compile", + "delattr", + "dir", + "divmod", + "enumerate", + "eval", + "exec", + "filter", + "format", + "getattr", + "globals", + "hasattr", + "hash", + "help", + "hex", + "id", + "input", + "isinstance", + "issubclass", + "iter", + "len", + "locals", + "map", + "max", + "memoryview", + "min", + "next", + "oct", + "open", + "ord", + "pow", + "print", + "repr", + "reversed", + "round", + "setattr", + "sorted", + "staticmethod", + "sum", + "super", + "vars", + "zip", + "__import__" + ) + // This list contains all classes from https://docs.python.org/3/library/functions.html#built-in-funcs + // for python version 3.9.5. + val builtinClassesV3: Iterable[String] = Iterable( + "bool", + "bytearray", + "bytes", + "complex", + "dict", + "float", + "frozenset", + "int", + "list", + "memoryview", + "object", + "property", + "range", + "set", + "slice", + "str", + "tuple", + "type" + ) + // This list contains all functions from https://docs.python.org/2.7/library/functions.html + val builtinFunctionsV2: Iterable[String] = Iterable( + "abs", + "all", + "any", + "bin", + "callable", + "chr", + "classmethod", + "cmp", + "compile", + "delattr", + "dir", + "divmod", + "enumerate", + "eval", + // This one is special because it is not from the above mentioned list. + // This is because exec is a statement type in V2 but our parser provides + // it to us as a normal call so that we can model it as builtin. 
+ "exec", + "execfile", + "filter", + "format", + "getattr", + "globals", + "hasattr", + "hash", + "help", + "hex", + "id", + "input", + "isinstance", + "issubclass", + "iter", + "len", + "locals", + "map", + "max", + "min", + "next", + "oct", + "open", + "ord", + "pow", + "print", + "range", + "raw_input", + "reduce", + "reload", + "repr", + "reversed", + "round", + "setattr", + "sorted", + "staticmethod", + "sum", + "super", + "unichr", + "vars", + "zip", + "__import__" + ) + // This list contains all classes from https://docs.python.org/2.7/library/functions.html + val builtinClassesV2: Iterable[String] = Iterable( + "bool", + "bytearray", + "complex", + "dict", + "file", + "float", + "frozenset", + "int", + "list", + "long", + "memoryview", + "object", + "property", + "set", + "slice", + "str", + "tuple", + "type", + "unicode", + "xrange" + ) + + lazy val allBuiltinClasses: Set[String] = (builtinClassesV2 ++ builtinClassesV3).toSet + + lazy val typingClassesV3: Set[String] = Set( + "Annotated", + "Any", + "Callable", + "ClassVar", + "Final", + "ForwardRef", + "Generic", + "Literal", + "Optional", + "Protocol", + "Tuple", + "Type", + "TypeVar", + "Union", + "AbstractSet", + "ByteString", + "Container", + "ContextManager", + "Hashable", + "ItemsView", + "Iterable", + "Iterator", + "KeysView", + "Mapping", + "MappingView", + "MutableMapping", + "MutableSequence", + "MutableSet", + "Sequence", + "Sized", + "ValuesView", + "Awaitable", + "AsyncIterator", + "AsyncIterable", + "Coroutine", + "Collection", + "AsyncGenerator", + "AsyncContextManager", + "Reversible", + "SupportsAbs", + "SupportsBytes", + "SupportsComplex", + "SupportsFloat", + "SupportsIndex", + "SupportsInt", + "SupportsRound", + "ChainMap", + "Counter", + "Deque", + "Dict", + "DefaultDict", + "List", + "OrderedDict", + "Set", + "FrozenSet", + "NamedTuple", + "TypedDict", + "Generator", + "BinaryIO", + "IO", + "Match", + "Pattern", + "TextIO", + "AnyStr", + "cast", + "final", + "get_args", + 
"get_origin", + "get_type_hints", + "NewType", + "no_type_check", + "no_type_check_decorator", + "NoReturn", + "overload", + "runtime_checkable", + "Text", + "TYPE_CHECKING" + ) +end PythonAstVisitor diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala index 6a6bc63a..4d4ab3f6 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala @@ -8,586 +8,666 @@ import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Dispatch import scala.collection.immutable.{::, Nil} import scala.collection.mutable -trait PythonAstVisitorHelpers { this: PythonAstVisitor => - - protected def codeOf(node: NewNode): String = { - node.asInstanceOf[AstNodeNew].code - } - - protected def lineAndColOf(node: ast.iattributes): LineAndColumn = { - // node.end_col_offset - 1 because the end column offset of the parser points - // behind the last symbol. - LineAndColumn(node.lineno, node.col_offset, node.end_lineno, node.end_col_offset - 1) - } - - private var tmpCounter = 0 - - protected def getUnusedName(prefix: String = null): String = { - // TODO check that result name does not collide with existing variables. 
- val result = if (prefix != null) { - s"${prefix}_tmp$tmpCounter" - } else { - s"tmp$tmpCounter" - } - tmpCounter += 1 - result - } - - protected def createTry( - body: Iterable[NewNode], - handlers: Iterable[NewNode], - finalBlock: Iterable[NewNode], - orElseBlock: Iterable[NewNode], - lineAndColumn: LineAndColumn - ): NewNode = { - val controlStructureNode = - nodeBuilder.controlStructureNode("try: ...", ControlStructureTypes.TRY, lineAndColumn) - - val bodyBlockNode = createBlock(body, lineAndColumn) - val handlersBlockNode = createBlock(handlers, lineAndColumn) - val finalBlockNode = createBlock(finalBlock, lineAndColumn) - val orElseBlockNode = createBlock(orElseBlock, lineAndColumn) - - addAstChildNodes(controlStructureNode, 1, bodyBlockNode, handlersBlockNode, finalBlockNode, orElseBlockNode) - - controlStructureNode - } - - protected def createTransformedImport( - from: String, - names: Iterable[ast.Alias], - lineAndCol: LineAndColumn - ): NewNode = { - val importAssignNodes = - names.map { alias => - val importedAsIdentifierName = alias.asName.getOrElse(alias.name) - val importAssignLhsIdentifierNode = - createIdentifierNode(importedAsIdentifierName, Store, lineAndCol) - - val arguments = Seq( - nodeBuilder.stringLiteralNode(from, lineAndCol), - nodeBuilder.stringLiteralNode(alias.name, lineAndCol) - ) ++ (alias.asName match { - case Some(aliasName) => Seq(nodeBuilder.stringLiteralNode(aliasName, lineAndCol)) - case None => Seq() - }) - - val importCallNode = - createCall(createIdentifierNode("import", Load, lineAndCol), "import", lineAndCol, arguments, Nil) - - val assignNode = createAssignment(importAssignLhsIdentifierNode, importCallNode, lineAndCol) - assignNode - } - - if (importAssignNodes.size > 1) { - createBlock(importAssignNodes, lineAndCol) - } else { - // Empty importAssignNodes cannot happen. - importAssignNodes.head - } - } - - // Used for assign statements, for loop target assignment and - // for comprehension target assignment. 
- // TODO handle Starred target - protected def createValueToTargetsDecomposition( - targets: Iterable[ast.iexpr], - valueNode: NewNode, - lineAndColumn: LineAndColumn - ): Iterable[NewNode] = { - if ( - targets.size == 1 && - !targets.head.isInstanceOf[ast.Tuple] && - !targets.head.isInstanceOf[ast.List] - ) { - // No lowering or wrapping in a block is required if we have a single target and - // no decomposition. - val targetNode = convert(targets.head) - - Iterable.single(createAssignment(targetNode, valueNode, lineAndColumn)) - } else { - // Lowering of x, (y,z) = a = b = c: - // Note: No surrounding block is created. This is the duty of the caller. - // tmp = c - // x = tmp[0] - // y = tmp[1][0] - // z = tmp[1][1] - // a = c - // b = c - // Lowering of for x, (y, z) in c: - // tmp = c - // x = tmp[0] - // y = tmp[1][0] - // z = tmp[1][1] - val tmpVariableName = getUnusedName() - - val tmpVariableAssignNode = - createAssignmentToIdentifier(tmpVariableName, valueNode, lineAndColumn) - - val loweredAssignNodes = mutable.ArrayBuffer.empty[NewNode] - loweredAssignNodes.append(tmpVariableAssignNode) - - targets.foreach { target => - val targetWithAccessChains = getTargetsWithAccessChains(target) - targetWithAccessChains.foreach { case (trgt, accessChain) => - val targetNode = convert(trgt) - val tmpIdentifierNode = - createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val indexTmpIdentifierNode = createIndexAccessChain(tmpIdentifierNode, accessChain, lineAndColumn) - - val targetAssignNode = createAssignment(targetNode, indexTmpIdentifierNode, lineAndColumn) - loweredAssignNodes.append(targetAssignNode) - } - } - loweredAssignNodes - } - } - - protected def getTargetsWithAccessChains(target: ast.iexpr): Iterable[(ast.iexpr, List[Int])] = { - val result = mutable.ArrayBuffer.empty[(ast.iexpr, List[Int])] - getTargetsInternal(target, Nil) - - def getTargetsInternal(target: ast.iexpr, indexChain: List[Int]): Unit = { - target match { - case tuple: ast.Tuple => 
- var index = 0 - tuple.elts.foreach { element => - getTargetsInternal(element, index :: indexChain) +trait PythonAstVisitorHelpers: + this: PythonAstVisitor => + + protected def codeOf(node: NewNode): String = + node.asInstanceOf[AstNodeNew].code + + protected def lineAndColOf(node: ast.iattributes): LineAndColumn = + // node.end_col_offset - 1 because the end column offset of the parser points + // behind the last symbol. + LineAndColumn(node.lineno, node.col_offset, node.end_lineno, node.end_col_offset - 1) + + private var tmpCounter = 0 + + protected def getUnusedName(prefix: String = null): String = + // TODO check that result name does not collide with existing variables. + val result = if prefix != null then + s"${prefix}_tmp$tmpCounter" + else + s"tmp$tmpCounter" + tmpCounter += 1 + result + + protected def createTry( + body: Iterable[NewNode], + handlers: Iterable[NewNode], + finalBlock: Iterable[NewNode], + orElseBlock: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + val controlStructureNode = + nodeBuilder.controlStructureNode("try: ...", ControlStructureTypes.TRY, lineAndColumn) + + val bodyBlockNode = createBlock(body, lineAndColumn) + val handlersBlockNode = createBlock(handlers, lineAndColumn) + val finalBlockNode = createBlock(finalBlock, lineAndColumn) + val orElseBlockNode = createBlock(orElseBlock, lineAndColumn) + + addAstChildNodes( + controlStructureNode, + 1, + bodyBlockNode, + handlersBlockNode, + finalBlockNode, + orElseBlockNode + ) + + controlStructureNode + end createTry + + protected def createTransformedImport( + from: String, + names: Iterable[ast.Alias], + lineAndCol: LineAndColumn + ): NewNode = + val importAssignNodes = + names.map { alias => + val importedAsIdentifierName = alias.asName.getOrElse(alias.name) + val importAssignLhsIdentifierNode = + createIdentifierNode(importedAsIdentifierName, Store, lineAndCol) + + val arguments = Seq( + nodeBuilder.stringLiteralNode(from, lineAndCol), + 
nodeBuilder.stringLiteralNode(alias.name, lineAndCol) + ) ++ (alias.asName match + case Some(aliasName) => + Seq(nodeBuilder.stringLiteralNode(aliasName, lineAndCol)) + case None => Seq() + ) + + val importCallNode = + createCall( + createIdentifierNode("import", Load, lineAndCol), + "import", + lineAndCol, + arguments, + Nil + ) + + val assignNode = + createAssignment(importAssignLhsIdentifierNode, importCallNode, lineAndCol) + assignNode + } + + if importAssignNodes.size > 1 then + createBlock(importAssignNodes, lineAndCol) + else + // Empty importAssignNodes cannot happen. + importAssignNodes.head + end createTransformedImport + + // Used for assign statements, for loop target assignment and + // for comprehension target assignment. + // TODO handle Starred target + protected def createValueToTargetsDecomposition( + targets: Iterable[ast.iexpr], + valueNode: NewNode, + lineAndColumn: LineAndColumn + ): Iterable[NewNode] = + if + targets.size == 1 && + !targets.head.isInstanceOf[ast.Tuple] && + !targets.head.isInstanceOf[ast.List] + then + // No lowering or wrapping in a block is required if we have a single target and + // no decomposition. + val targetNode = convert(targets.head) + + Iterable.single(createAssignment(targetNode, valueNode, lineAndColumn)) + else + // Lowering of x, (y,z) = a = b = c: + // Note: No surrounding block is created. This is the duty of the caller. 
+ // tmp = c + // x = tmp[0] + // y = tmp[1][0] + // z = tmp[1][1] + // a = c + // b = c + // Lowering of for x, (y, z) in c: + // tmp = c + // x = tmp[0] + // y = tmp[1][0] + // z = tmp[1][1] + val tmpVariableName = getUnusedName() + + val tmpVariableAssignNode = + createAssignmentToIdentifier(tmpVariableName, valueNode, lineAndColumn) + + val loweredAssignNodes = mutable.ArrayBuffer.empty[NewNode] + loweredAssignNodes.append(tmpVariableAssignNode) + + targets.foreach { target => + val targetWithAccessChains = getTargetsWithAccessChains(target) + targetWithAccessChains.foreach { case (trgt, accessChain) => + val targetNode = convert(trgt) + val tmpIdentifierNode = + createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val indexTmpIdentifierNode = + createIndexAccessChain(tmpIdentifierNode, accessChain, lineAndColumn) + + val targetAssignNode = + createAssignment(targetNode, indexTmpIdentifierNode, lineAndColumn) + loweredAssignNodes.append(targetAssignNode) + } + } + loweredAssignNodes + + protected def getTargetsWithAccessChains(target: ast.iexpr): Iterable[(ast.iexpr, List[Int])] = + val result = mutable.ArrayBuffer.empty[(ast.iexpr, List[Int])] + getTargetsInternal(target, Nil) + + def getTargetsInternal(target: ast.iexpr, indexChain: List[Int]): Unit = + target match + case tuple: ast.Tuple => + var index = 0 + tuple.elts.foreach { element => + getTargetsInternal(element, index :: indexChain) + index += 1 + } + case list: ast.List => + var index = 0 + list.elts.foreach { element => + getTargetsInternal(element, index :: indexChain) + index += 1 + } + case _ => + result.append((target, indexChain)) + + result + end getTargetsWithAccessChains + + protected def createComprehensionLowering( + tmpVariableName: String, + containerInitAssignNode: NewNode, + innerMostLoopNode: NewNode, + comprehensions: Iterable[ast.Comprehension], + lineAndColumn: LineAndColumn + ): NewNode = + val specialTargetLocals = mutable.ArrayBuffer.empty[NewLocal] + + // Innermost 
generator is transformed first and becomes the body of the + // generator one layer up. The body of the innermost generator is the + // list comprehensions element expression wrapped in an tmp.append() call. + val nestedLoopBlockNode = + comprehensions.foldRight(innerMostLoopNode) { case (comprehension, loopBodyNode) => + extractComprehensionSpecialVariableNames(comprehension.target).foreach { name => + // For the target names we need to create special scoped variables. + val localNode = nodeBuilder.localNode(name.id, None) + specialTargetLocals.append(localNode) + contextStack.addSpecialVariable(localNode) + } + createForLowering( + comprehension.target, + comprehension.iter, + comprehension.ifs, + Iterable.single(loopBodyNode), + Iterable.empty, + comprehension.is_async, + lineAndColumn + ) + } + + val returnIdentifierNode = createIdentifierNode(tmpVariableName, Load, lineAndColumn) + + val blockNode = + createBlock( + containerInitAssignNode :: nestedLoopBlockNode :: returnIdentifierNode :: Nil, + lineAndColumn + ) + + addAstChildNodes(blockNode, 1, specialTargetLocals) + + blockNode + end createComprehensionLowering + + // Extracts plain names, starred names and name or starred name elements from tuples and lists. 
+ private def extractComprehensionSpecialVariableNames(target: ast.iexpr): Iterable[ast.Name] = + target match + case name: ast.Name => + name :: Nil + case starred: ast.Starred => + extractComprehensionSpecialVariableNames(starred.value) + case tuple: ast.Tuple => + tuple.elts.flatMap(extractComprehensionSpecialVariableNames) + case list: ast.List => + list.elts.flatMap(extractComprehensionSpecialVariableNames) + case _ => + Nil + + protected def createBlock( + blockElements: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + val blockCode = blockElements.map(codeOf).mkString("\n") + val blockNode = nodeBuilder.blockNode(blockCode, lineAndColumn) + + val orderIndex = new AutoIncIndex(1) + addAstChildNodes(blockNode, orderIndex, blockElements) + + blockNode + + protected def createCall( + receiverNode: NewNode, + name: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewCall = + val code = codeOf(receiverNode) + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) + + edgeBuilder.astEdge(receiverNode, callNode, 0) + edgeBuilder.receiverEdge(receiverNode, callNode) + + var index = 1 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, order = index) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex = index) index += 1 - } - case list: ast.List => - var index = 0 - list.elts.foreach { element => - getTargetsInternal(element, index :: indexChain) + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = index) + edgeBuilder.argumentEdge(argumentNode, callNode, 
argName = keyword) index += 1 - } - case _ => - result.append((target, indexChain)) - } - } - - result - } - - protected def createComprehensionLowering( - tmpVariableName: String, - containerInitAssignNode: NewNode, - innerMostLoopNode: NewNode, - comprehensions: Iterable[ast.Comprehension], - lineAndColumn: LineAndColumn - ): NewNode = { - val specialTargetLocals = mutable.ArrayBuffer.empty[NewLocal] - - // Innermost generator is transformed first and becomes the body of the - // generator one layer up. The body of the innermost generator is the - // list comprehensions element expression wrapped in an tmp.append() call. - val nestedLoopBlockNode = - comprehensions.foldRight(innerMostLoopNode) { case (comprehension, loopBodyNode) => - extractComprehensionSpecialVariableNames(comprehension.target).foreach { name => - // For the target names we need to create special scoped variables. - val localNode = nodeBuilder.localNode(name.id, None) - specialTargetLocals.append(localNode) - contextStack.addSpecialVariable(localNode) } - createForLowering( - comprehension.target, - comprehension.iter, - comprehension.ifs, - Iterable.single(loopBodyNode), - Iterable.empty, - comprehension.is_async, + + callNode + end createCall + + protected def createInstanceCall( + receiverNode: NewNode, + instanceNode: NewNode, + name: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewCall = + val code = codeOf(receiverNode) + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) + + edgeBuilder.astEdge(receiverNode, callNode, 0) + edgeBuilder.receiverEdge(receiverNode, callNode) + edgeBuilder.astEdge(instanceNode, 
callNode, 1) + edgeBuilder.argumentEdge(instanceNode, callNode, 0) + + var argIndex = 1 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, argIndex + 1) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) + argIndex += 1 + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = argIndex + 1) + edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) + argIndex += 1 + } + + callNode + end createInstanceCall + + // NOTE if xMayHaveSideEffects == false, function x must return a distinct + // tree/node for each invocation!!! + // Otherwise the same tree/node may get placed in different places of the AST + // which is invalid and in this concrete case here triggered setting the + // argumentIndex twice once for x being receiver and once for x being the + // instance. + // If x may have side effects we lower as follows: x.y() => + // { + // tmp = x + // CALL(recv = tmp.y, inst = tmp, args=) + // } + protected def createXDotYCall( + x: () => NewNode, + y: String, + xMayHaveSideEffects: Boolean, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewNode = + if xMayHaveSideEffects then + val tmpVarName = getUnusedName() + val tmpAssignCall = createAssignmentToIdentifier(tmpVarName, x(), lineAndColumn) + val receiverNode = + createFieldAccess( + createIdentifierNode(tmpVarName, Load, lineAndColumn), + y, + lineAndColumn + ) + val instanceNode = createIdentifierNode(tmpVarName, Load, lineAndColumn) + val instanceCallNode = + createInstanceCall( + receiverNode, + instanceNode, + y, + lineAndColumn, + argumentNodes, + keywordArguments + ) + createBlock(tmpAssignCall :: instanceCallNode :: Nil, lineAndColumn) + else + val receiverNode = createFieldAccess(x(), y, lineAndColumn) + createInstanceCall(receiverNode, x(), y, lineAndColumn, argumentNodes, keywordArguments) + + // NOTE: The 
argument indicies start from 0! + protected def createStaticCall( + name: String, + methodFullName: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewNode = + val code = name + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + var argIndex = 0 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, argIndex) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) + argIndex += 1 + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = argIndex) + edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) + argIndex += 1 + } + + callNode + end createStaticCall + + protected def createNAryOperatorCall( + opCodeAndFullName: () => (String, String), + operands: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + + val (operatorCode, methodFullName) = opCodeAndFullName() + val code = + operands.map(operandNode => codeOf(operandNode)).mkString(" " + operatorCode + " ") + val callNode = + nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + addAstChildrenAsArguments(callNode, 1, operands) + + callNode + + protected def createBinaryOperatorCall( + lhsNode: NewNode, + opCodeAndFullName: () => (String, String), + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewCall = + val (opCode, opFullName) = opCodeAndFullName() + + val code = codeOf(lhsNode) + " " + opCode + " " + codeOf(rhsNode) + val callNode = + nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + 
addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + callNode + + protected def createLiteralOperatorCall( + codeStart: String, + codeEnd: String, + opFullName: String, + lineAndColumn: LineAndColumn, + operands: NewNode* + ): NewCall = + val code = operands.map(codeOf).mkString(codeStart, ", ", codeEnd) + val callNode = + nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + addAstChildrenAsArguments(callNode, 1, operands) + + callNode + + protected def createStarredUnpackOperatorCall( + unpackOperand: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = "*" + codeOf(unpackOperand) + val callNode = nodeBuilder.callNode( + code, + ".starredUnpack", + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, unpackOperand) + callNode + + protected def createAssignment( + lhsNode: NewNode, + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(lhsNode) + " = " + codeOf(rhsNode) + val callNode = nodeBuilder.callNode( + code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + // Do not include imports or function pointers + if !codeOf(rhsNode).startsWith("import(") && codeOf(rhsNode) != s"def ${codeOf(lhsNode)}(...)" + then + contextStack.considerAsGlobalVariable(lhsNode) + + callNode + end createAssignment + + protected def createAssignmentToIdentifier( + identifierName: String, + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val identifierNode = createIdentifierNode(identifierName, Store, lineAndColumn) + createAssignment(identifierNode, rhsNode, lineAndColumn) + + protected def createAugAssignment( + lhsNode: NewNode, + operatorCode: String, + rhsNode: NewNode, + operatorFullName: String, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(lhsNode) + " " + operatorCode + " " + codeOf(rhsNode) + val callNode = 
nodeBuilder.callNode( + code, + operatorFullName, + DispatchTypes.STATIC_DISPATCH, lineAndColumn ) - } - - val returnIdentifierNode = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - - val blockNode = - createBlock(containerInitAssignNode :: nestedLoopBlockNode :: returnIdentifierNode :: Nil, lineAndColumn) - - addAstChildNodes(blockNode, 1, specialTargetLocals) - - blockNode - } - - // Extracts plain names, starred names and name or starred name elements from tuples and lists. - private def extractComprehensionSpecialVariableNames(target: ast.iexpr): Iterable[ast.Name] = { - target match { - case name: ast.Name => - name :: Nil - case starred: ast.Starred => - extractComprehensionSpecialVariableNames(starred.value) - case tuple: ast.Tuple => - tuple.elts.flatMap(extractComprehensionSpecialVariableNames) - case list: ast.List => - list.elts.flatMap(extractComprehensionSpecialVariableNames) - case _ => - Nil - } - } - - protected def createBlock(blockElements: Iterable[NewNode], lineAndColumn: LineAndColumn): NewNode = { - val blockCode = blockElements.map(codeOf).mkString("\n") - val blockNode = nodeBuilder.blockNode(blockCode, lineAndColumn) - - val orderIndex = new AutoIncIndex(1) - addAstChildNodes(blockNode, orderIndex, blockElements) - - blockNode - } - - protected def createCall( - receiverNode: NewNode, - name: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewCall = { - val code = codeOf(receiverNode) + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if (argumentNodes.nonEmpty && keywordArguments.nonEmpty) ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) - - edgeBuilder.astEdge(receiverNode, callNode, 0) - edgeBuilder.receiverEdge(receiverNode, callNode) - - var index = 1 
- argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, order = index) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex = index) - index += 1 - } - - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = index) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - index += 1 - } - - callNode - } - - protected def createInstanceCall( - receiverNode: NewNode, - instanceNode: NewNode, - name: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewCall = { - val code = codeOf(receiverNode) + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if (argumentNodes.nonEmpty && keywordArguments.nonEmpty) ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) - - edgeBuilder.astEdge(receiverNode, callNode, 0) - edgeBuilder.receiverEdge(receiverNode, callNode) - edgeBuilder.astEdge(instanceNode, callNode, 1) - edgeBuilder.argumentEdge(instanceNode, callNode, 0) - - var argIndex = 1 - argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, argIndex + 1) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) - argIndex += 1 - } - - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = argIndex + 1) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - argIndex += 1 - } - - callNode - } - - // NOTE if xMayHaveSideEffects == false, function x must return a distinct - // tree/node for each invocation!!! 
- // Otherwise the same tree/node may get placed in different places of the AST - // which is invalid and in this concrete case here triggered setting the - // argumentIndex twice once for x being receiver and once for x being the - // instance. - // If x may have side effects we lower as follows: x.y() => - // { - // tmp = x - // CALL(recv = tmp.y, inst = tmp, args=) - // } - protected def createXDotYCall( - x: () => NewNode, - y: String, - xMayHaveSideEffects: Boolean, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewNode = { - if (xMayHaveSideEffects) { - val tmpVarName = getUnusedName() - val tmpAssignCall = createAssignmentToIdentifier(tmpVarName, x(), lineAndColumn) - val receiverNode = - createFieldAccess(createIdentifierNode(tmpVarName, Load, lineAndColumn), y, lineAndColumn) - val instanceNode = createIdentifierNode(tmpVarName, Load, lineAndColumn) - val instanceCallNode = - createInstanceCall(receiverNode, instanceNode, y, lineAndColumn, argumentNodes, keywordArguments) - createBlock(tmpAssignCall :: instanceCallNode :: Nil, lineAndColumn) - } else { - val receiverNode = createFieldAccess(x(), y, lineAndColumn) - createInstanceCall(receiverNode, x(), y, lineAndColumn, argumentNodes, keywordArguments) - } - } - - // NOTE: The argument indicies start from 0! 
- protected def createStaticCall( - name: String, - methodFullName: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewNode = { - val code = name + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if (argumentNodes.nonEmpty && keywordArguments.nonEmpty) ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - var argIndex = 0 - argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, argIndex) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) - argIndex += 1 - } - - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = argIndex) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - argIndex += 1 - } - - callNode - } - - protected def createNAryOperatorCall( - opCodeAndFullName: () => (String, String), - operands: Iterable[NewNode], - lineAndColumn: LineAndColumn - ): NewNode = { - - val (operatorCode, methodFullName) = opCodeAndFullName() - val code = operands.map(operandNode => codeOf(operandNode)).mkString(" " + operatorCode + " ") - val callNode = nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, operands) - - callNode - } - - protected def createBinaryOperatorCall( - lhsNode: NewNode, - opCodeAndFullName: () => (String, String), - rhsNode: NewNode, - lineAndColumn: LineAndColumn - ): NewCall = { - val (opCode, opFullName) = opCodeAndFullName() - - val code = codeOf(lhsNode) + " " + opCode + " " + codeOf(rhsNode) - val callNode = nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - 
callNode - } - - protected def createLiteralOperatorCall( - codeStart: String, - codeEnd: String, - opFullName: String, - lineAndColumn: LineAndColumn, - operands: NewNode* - ): NewCall = { - val code = operands.map(codeOf).mkString(codeStart, ", ", codeEnd) - val callNode = nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, operands) - - callNode - } - - protected def createStarredUnpackOperatorCall(unpackOperand: NewNode, lineAndColumn: LineAndColumn): NewNode = { - val code = "*" + codeOf(unpackOperand) - val callNode = nodeBuilder.callNode(code, ".starredUnpack", DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, unpackOperand) - callNode - } - - protected def createAssignment(lhsNode: NewNode, rhsNode: NewNode, lineAndColumn: LineAndColumn): NewNode = { - val code = codeOf(lhsNode) + " = " + codeOf(rhsNode) - val callNode = nodeBuilder.callNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - // Do not include imports or function pointers - if (!codeOf(rhsNode).startsWith("import(") && codeOf(rhsNode) != s"def ${codeOf(lhsNode)}(...)") - contextStack.considerAsGlobalVariable(lhsNode) - - callNode - } - - protected def createAssignmentToIdentifier( - identifierName: String, - rhsNode: NewNode, - lineAndColumn: LineAndColumn - ): NewNode = { - val identifierNode = createIdentifierNode(identifierName, Store, lineAndColumn) - createAssignment(identifierNode, rhsNode, lineAndColumn) - } - - protected def createAugAssignment( - lhsNode: NewNode, - operatorCode: String, - rhsNode: NewNode, - operatorFullName: String, - lineAndColumn: LineAndColumn - ): NewNode = { - val code = codeOf(lhsNode) + " " + operatorCode + " " + codeOf(rhsNode) - val callNode = nodeBuilder.callNode(code, operatorFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - 
addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - - callNode - } - - // Always use this method to create an identifier node instead of - // nodeBuilder.identifierNode() directly to avoid missing to add - // the variable reference. - protected def createIdentifierNode( - name: String, - memOp: MemoryOperation, - lineAndColumn: LineAndColumn - ): NewIdentifier = { - val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn) - contextStack.addVariableReference(identifierNode, memOp) - identifierNode - } - - protected def createIndexAccess(baseNode: NewNode, indexNode: NewNode, lineAndColumn: LineAndColumn): NewNode = { - val code = codeOf(baseNode) + "[" + codeOf(indexNode) + "]" - val indexAccessNode = - nodeBuilder.callNode(code, Operators.indexAccess, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(indexAccessNode, 1, baseNode, indexNode) - - indexAccessNode - } - - protected def createIndexAccessChain( - rootNode: NewNode, - accessChain: List[Int], - lineAndColumn: LineAndColumn - ): NewNode = { - accessChain match { - case accessIndex :: tail => - val baseNode = createIndexAccessChain(rootNode, tail, lineAndColumn) - val indexNode = nodeBuilder.numberLiteralNode(accessIndex, lineAndColumn) - - createIndexAccess(baseNode, indexNode, lineAndColumn) - case Nil => - rootNode - } - } - - protected def createFieldAccess(baseNode: NewNode, fieldName: String, lineAndColumn: LineAndColumn): NewCall = { - val fieldIdNode = nodeBuilder.fieldIdentifierNode(fieldName, lineAndColumn) - - val code = codeOf(baseNode) + "." 
+ codeOf(fieldIdNode) - val callNode = nodeBuilder.callNode(code, Operators.fieldAccess, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, baseNode, fieldIdNode) - callNode - } - - protected def createTypeRef(typeName: String, typeFullName: String, lineAndColumn: LineAndColumn): NewTypeRef = { - nodeBuilder.typeRefNode("class " + typeName + "(...)", typeFullName, lineAndColumn) - } - - protected def createBinding(methodNode: NewMethod, typeDeclNode: NewTypeDecl): NewBinding = { - val bindingNode = nodeBuilder.bindingNode() - edgeBuilder.bindsEdge(bindingNode, typeDeclNode) - edgeBuilder.refEdge(methodNode, bindingNode) - - bindingNode - } - - protected def createReturn( - returnExprOption: Option[NewNode], - codeOption: Option[String], - lineAndColumn: LineAndColumn - ): NewReturn = { - val code = codeOption.getOrElse { - returnExprOption match { - case Some(returnExpr) => - "return " + codeOf(returnExpr) - case None => - "return" - } - } - val returnNode = nodeBuilder.returnNode(code, lineAndColumn) - returnExprOption.foreach { returnExpr => addAstChildrenAsArguments(returnNode, 1, returnExpr) } - - returnNode - } - - protected def addAstChildNodes(parentNode: NewNode, startIndex: AutoIncIndex, childNodes: Iterable[NewNode]): Unit = { - childNodes.foreach { childNode => - val orderIndex = startIndex.getAndInc - edgeBuilder.astEdge(childNode, parentNode, orderIndex) - } - } - - protected def addAstChildNodes(parentNode: NewNode, startIndex: Int, childNodes: Iterable[NewNode]): Unit = { - addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) - } - - protected def addAstChildNodes(parentNode: NewNode, startIndex: AutoIncIndex, childNodes: NewNode*): Unit = { - addAstChildNodes(parentNode, startIndex, childNodes) - } - - protected def addAstChildNodes(parentNode: NewNode, startIndex: Int, childNodes: NewNode*): Unit = { - addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) - } - - protected def 
addAstChildrenAsArguments( - parentNode: NewNode, - startIndex: AutoIncIndex, - childNodes: Iterable[NewNode] - ): Unit = { - childNodes.foreach { childNode => - val orderAndArgIndex = startIndex.getAndInc - edgeBuilder.astEdge(childNode, parentNode, orderAndArgIndex) - edgeBuilder.argumentEdge(childNode, parentNode, orderAndArgIndex) - } - } - - protected def addAstChildrenAsArguments(parentNode: NewNode, startIndex: Int, childNodes: Iterable[NewNode]): Unit = { - addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) - } - - protected def addAstChildrenAsArguments(parentNode: NewNode, startIndex: AutoIncIndex, childNodes: NewNode*): Unit = { - addAstChildrenAsArguments(parentNode, startIndex, childNodes) - } - - protected def addAstChildrenAsArguments(parentNode: NewNode, startIndex: Int, childNodes: NewNode*): Unit = { - addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) - } -} + + addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + + callNode + + // Always use this method to create an identifier node instead of + // nodeBuilder.identifierNode() directly to avoid missing to add + // the variable reference. 
+ protected def createIdentifierNode( + name: String, + memOp: MemoryOperation, + lineAndColumn: LineAndColumn + ): NewIdentifier = + val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn) + contextStack.addVariableReference(identifierNode, memOp) + identifierNode + + protected def createIndexAccess( + baseNode: NewNode, + indexNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(baseNode) + "[" + codeOf(indexNode) + "]" + val indexAccessNode = + nodeBuilder.callNode( + code, + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(indexAccessNode, 1, baseNode, indexNode) + + indexAccessNode + + protected def createIndexAccessChain( + rootNode: NewNode, + accessChain: List[Int], + lineAndColumn: LineAndColumn + ): NewNode = + accessChain match + case accessIndex :: tail => + val baseNode = createIndexAccessChain(rootNode, tail, lineAndColumn) + val indexNode = nodeBuilder.numberLiteralNode(accessIndex, lineAndColumn) + + createIndexAccess(baseNode, indexNode, lineAndColumn) + case Nil => + rootNode + + protected def createFieldAccess( + baseNode: NewNode, + fieldName: String, + lineAndColumn: LineAndColumn + ): NewCall = + val fieldIdNode = nodeBuilder.fieldIdentifierNode(fieldName, lineAndColumn) + + val code = codeOf(baseNode) + "." 
+ codeOf(fieldIdNode) + val callNode = nodeBuilder.callNode( + code, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, baseNode, fieldIdNode) + callNode + + protected def createTypeRef( + typeName: String, + typeFullName: String, + lineAndColumn: LineAndColumn + ): NewTypeRef = + nodeBuilder.typeRefNode("class " + typeName + "(...)", typeFullName, lineAndColumn) + + protected def createBinding(methodNode: NewMethod, typeDeclNode: NewTypeDecl): NewBinding = + val bindingNode = nodeBuilder.bindingNode() + edgeBuilder.bindsEdge(bindingNode, typeDeclNode) + edgeBuilder.refEdge(methodNode, bindingNode) + + bindingNode + + protected def createReturn( + returnExprOption: Option[NewNode], + codeOption: Option[String], + lineAndColumn: LineAndColumn + ): NewReturn = + val code = codeOption.getOrElse { + returnExprOption match + case Some(returnExpr) => + "return " + codeOf(returnExpr) + case None => + "return" + } + val returnNode = nodeBuilder.returnNode(code, lineAndColumn) + returnExprOption.foreach { returnExpr => + addAstChildrenAsArguments(returnNode, 1, returnExpr) + } + + returnNode + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: Iterable[NewNode] + ): Unit = + childNodes.foreach { childNode => + val orderIndex = startIndex.getAndInc + edgeBuilder.astEdge(childNode, parentNode, orderIndex) + } + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: Int, + childNodes: Iterable[NewNode] + ): Unit = + addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: NewNode* + ): Unit = + addAstChildNodes(parentNode, startIndex, childNodes) + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: Int, + childNodes: NewNode* + ): Unit = + addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) + + 
protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: Iterable[NewNode] + ): Unit = + childNodes.foreach { childNode => + val orderAndArgIndex = startIndex.getAndInc + edgeBuilder.astEdge(childNode, parentNode, orderAndArgIndex) + edgeBuilder.argumentEdge(childNode, parentNode, orderAndArgIndex) + } + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: Int, + childNodes: Iterable[NewNode] + ): Unit = + addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: NewNode* + ): Unit = + addAstChildrenAsArguments(parentNode, startIndex, childNodes) + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: Int, + childNodes: NewNode* + ): Unit = + addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) +end PythonAstVisitorHelpers diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala index 9e2d74ec..473c4f70 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala @@ -3,12 +3,10 @@ package io.appthreat.pysrc2cpg import io.appthreat.x2cpg.passes.frontend.XInheritanceFullNamePass import io.shiftleft.codepropertygraph.Cpg -/** Using some basic heuristics, will try to resolve type full names from types found within the CPG. Requires - * ImportPass as a pre-requisite. +/** Using some basic heuristics, will try to resolve type full names from types found within the + * CPG. Requires ImportPass as a pre-requisite. 
*/ -class PythonInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg) { +class PythonInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg): - override val moduleName: String = "" - override val fileExt: String = ".py" - -} + override val moduleName: String = "" + override val fileExt: String = ".py" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala index 049c9219..10830407 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala @@ -6,25 +6,24 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.semanticcpg.language.* -class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { +class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override def calls: Iterator[Call] = super.calls.nameNot("^(import).*") + override def calls: Iterator[Call] = super.calls.nameNot("^(import).*") - override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map { - // Python call from a type - case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__" - // Python call from a function pointer - case typ => typ - } - - override def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = { - if (methodNames.sizeIs == 1) { - super.setCallees(call, methodNames, builder) - } else if (methodNames.sizeIs > 1) { - val nonDummyMethodNames = - methodNames.filterNot(x => isDummyType(x) || x.startsWith(PythonAstVisitor.builtinPrefix + "None")) - super.setCallees(call, nonDummyMethodNames, builder) + override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map { + // Python 
call from a type + case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__" + // Python call from a function pointer + case typ => typ } - } -} + override def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = + if methodNames.sizeIs == 1 then + super.setCallees(call, methodNames, builder) + else if methodNames.sizeIs > 1 then + val nonDummyMethodNames = + methodNames.filterNot(x => + isDummyType(x) || x.startsWith(PythonAstVisitor.builtinPrefix + "None") + ) + super.setCallees(call, nonDummyMethodNames, builder) +end PythonTypeHintCallLinker diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala index ef19ca95..17ab2a3f 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala @@ -1,219 +1,247 @@ package io.appthreat.pysrc2cpg -import io.appthreat.x2cpg.passes.frontend._ +import io.appthreat.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder class PythonTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { + extends XTypeRecoveryPass[File](cpg, config): - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new PythonTypeRecovery(cpg, state) -} + override protected def 
generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = + new PythonTypeRecovery(cpg, state) -private class PythonTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { +private class PythonTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) + extends XTypeRecovery[File](cpg, state): - override def compilationUnit: Iterator[File] = cpg.file.iterator + override def compilationUnit: Iterator[File] = cpg.file.iterator - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForPythonFile(cpg, unit, builder, state.copy(config = newConfig)) - } - -} + override def generateRecoveryForCompilationUnitTask( + unit: File, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[File] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForPythonFile(cpg, unit, builder, state.copy(config = newConfig)) /** Performs type recovery from the root of a compilation unit level */ -private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { - - /** Replaces the `this` prefix with the Pythonic `self` prefix for instance methods of functions local to this - * compilation unit. 
- */ - private def fromNodeToLocalPythonKey(node: AstNode): Option[LocalKey] = - node match { - case n: Method => Option(CallAlias(n.name, Option("self"))) - case _ => SBKey.fromNodeToLocalKey(node) - } - - override val symbolTable: SymbolTable[LocalKey] = new SymbolTable[LocalKey](fromNodeToLocalPythonKey) - - override def visitImport(i: Import): Unit = { - if (i.importedAs.isDefined && i.importedEntity.isDefined) { - import io.appthreat.x2cpg.passes.frontend.ImportsPass._ - - val entityName = i.importedAs.get - i.call.tag.flatMap(ResolvedImport.tagToResolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => symbolTable.put(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => symbolTable.put(LocalVar(entityName), fullName) - case ResolvedMember(basePath, memberName, _) => - val memberTypes = cpg.typeDecl - .fullNameExact(basePath) - .member - .nameExact(memberName) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot(_ == "ANY") - .toSet - symbolTable.put(LocalVar(entityName), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.put(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.put(LocalVar(entityName), fullName) - case UnknownImport(path, _) => - symbolTable.put(CallAlias(entityName), path) - symbolTable.put(LocalVar(entityName), path) - } - } - } - - override def visitAssignments(a: OpNodes.Assignment): Set[String] = { - a.argumentOut.l match { - case List(i: Identifier, c: Call) if c.name.isBlank && c.signature.isBlank => - // This is usually some decorator wrapper - c.argument.isMethodRef.headOption match { - case Some(mRef) => visitIdentifierAssignedToMethodRef(i, mRef) - case None => super.visitAssignments(a) - } - case _ => super.visitAssignments(a) - } - } - - /** Determines if a function call is a constructor by following the heuristic that Python classes are typically - * camel-case and start with an upper-case 
character. - */ - override def isConstructor(c: Call): Boolean = - isConstructor(c.name) && c.code.endsWith(")") - - override protected def isConstructor(name: String): Boolean = - name.nonEmpty && name.charAt(0).isUpper - - /** If the parent method is module then it can be used as a field. - */ - override def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate(i.id(), i.method.name.matches("(|__init__)") || super.isField(i)) - - override def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { - operation match { - case ".listLiteral" => associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}list")) - case ".tupleLiteral" => associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}tuple")) - case ".dictLiteral" => associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}dict")) - case ".setLiteral" => associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}set")) - case Operators.conditional => associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}bool")) - case _ => super.visitIdentifierAssignedToOperator(i, c, operation) - } - } - - override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { - val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) - associateTypes(i, constructorPaths) - } - - override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { - // Ignore legacy import representation - if (c.name.equals("import")) Set.empty - // Stop custom annotation representation from hitting superclass - else if (c.name.isBlank) Set.empty - else super.visitIdentifierAssignedToCall(i, c) - } - - override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = { - val fieldParents = getFieldParents(fa) - fa.astChildren.l match { - case List(base: Identifier, fi: FieldIdentifier) if base.name.equals("self") && fieldParents.nonEmpty => - val referencedFields = 
cpg.typeDecl.fullNameExact(fieldParents.toSeq: _*).member.nameExact(fi.canonicalName) - val globalTypes = - referencedFields.flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName).filterNot(_ == Constants.ANY).toSet - associateTypes(i, globalTypes) - case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) - } - } - - override def getTypesFromCall(c: Call): Set[String] = c.name match { - case ".listLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}list") - case ".tupleLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}tuple") - case ".dictLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}dict") - case ".setLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}set") - case _ => super.getTypesFromCall(c) - } - - override def getFieldParents(fa: FieldAccess): Set[String] = { - if (fa.method.name == "") { - Set(fa.method.fullName) - } else if (fa.method.typeDecl.nonEmpty) { - val parentTypes = fa.method.typeDecl.fullName.toSet - val baseTypeFullNames = cpg.typeDecl.fullNameExact(parentTypes.toSeq: _*).inheritsFromTypeFullName.toSet - (parentTypes ++ baseTypeFullNames).filterNot(_.matches("(?i)(any|object)")) - } else { - super.getFieldParents(fa) - } - } - - private def isPyString(s: String): Boolean = - (s.startsWith("\"") || s.startsWith("'")) && (s.endsWith("\"") || s.endsWith("'")) - - override def getLiteralType(l: Literal): Set[String] = { - (l.code match { - case code if code.toIntOption.isDefined => Some(s"${PythonAstVisitor.builtinPrefix}int") - case code if code.toDoubleOption.isDefined => Some(s"${PythonAstVisitor.builtinPrefix}float") - case code if "True".equals(code) || "False".equals(code) => Some(s"${PythonAstVisitor.builtinPrefix}bool") - case code if code.equals("None") => Some(s"${PythonAstVisitor.builtinPrefix}None") - case code if isPyString(code) => Some(s"${PythonAstVisitor.builtinPrefix}str") - case _ => None - }).toSet - } - - override def createCallFromIdentifierTypeFullName(typeFullName: String, callName: String): String = { - lazy val 
tName = typeFullName.split("\\.").lastOption.getOrElse(typeFullName) - typeFullName match { - case t if t.matches(".*(<\\w+>)$") => super.createCallFromIdentifierTypeFullName(typeFullName, callName) - case t if t.matches(".*\\.<(member|returnValue|indexAccess)>(\\(.*\\))?") => - super.createCallFromIdentifierTypeFullName(typeFullName, callName) - case t if isConstructor(tName) => - Seq(t, callName).mkString(pathSep.toString) - case _ => super.createCallFromIdentifierTypeFullName(typeFullName, callName) - } - } - - override protected def postSetTypeInformation(): Unit = - cu.typeDecl - .map(t => t -> t.inheritsFromTypeFullName.partition(itf => symbolTable.contains(LocalVar(itf)))) - .foreach { case (t, (identifierTypes, otherTypes)) => - val existingTypes = (identifierTypes ++ otherTypes).distinct - val resolvedTypes = identifierTypes.map(LocalVar.apply).flatMap(symbolTable.get) - if (existingTypes != resolvedTypes && resolvedTypes.nonEmpty) { - state.changesWereMade.compareAndExchange(false, true) - builder.setNodeProperty(t, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, resolvedTypes) +private class RecoverForPythonFile( + cpg: Cpg, + cu: File, + builder: DiffGraphBuilder, + state: XTypeRecoveryState +) extends RecoverForXCompilationUnit[File](cpg, cu, builder, state): + + /** Replaces the `this` prefix with the Pythonic `self` prefix for instance methods of functions + * local to this compilation unit. 
+ */ + private def fromNodeToLocalPythonKey(node: AstNode): Option[LocalKey] = + node match + case n: Method => Option(CallAlias(n.name, Option("self"))) + case _ => SBKey.fromNodeToLocalKey(node) + + override val symbolTable: SymbolTable[LocalKey] = + new SymbolTable[LocalKey](fromNodeToLocalPythonKey) + + override def visitImport(i: Import): Unit = + if i.importedAs.isDefined && i.importedEntity.isDefined then + import io.appthreat.x2cpg.passes.frontend.ImportsPass.* + + val entityName = i.importedAs.get + i.call.tag.flatMap(ResolvedImport.tagToResolvedImport).foreach { + case ResolvedMethod(fullName, alias, receiver, _) => + symbolTable.put(CallAlias(alias, receiver), fullName) + case ResolvedTypeDecl(fullName, _) => + symbolTable.put(LocalVar(entityName), fullName) + case ResolvedMember(basePath, memberName, _) => + val memberTypes = cpg.typeDecl + .fullNameExact(basePath) + .member + .nameExact(memberName) + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .filterNot(_ == "ANY") + .toSet + symbolTable.put(LocalVar(entityName), memberTypes) + case UnknownMethod(fullName, alias, receiver, _) => + symbolTable.put(CallAlias(alias, receiver), fullName) + case UnknownTypeDecl(fullName, _) => + symbolTable.put(LocalVar(entityName), fullName) + case UnknownImport(path, _) => + symbolTable.put(CallAlias(entityName), path) + symbolTable.put(LocalVar(entityName), path) + } + + override def visitAssignments(a: OpNodes.Assignment): Set[String] = + a.argumentOut.l match + case List(i: Identifier, c: Call) if c.name.isBlank && c.signature.isBlank => + // This is usually some decorator wrapper + c.argument.isMethodRef.headOption match + case Some(mRef) => visitIdentifierAssignedToMethodRef(i, mRef) + case None => super.visitAssignments(a) + case _ => super.visitAssignments(a) + + /** Determines if a function call is a constructor by following the heuristic that Python + * classes are typically camel-case and start with an upper-case character. 
+ */ + override def isConstructor(c: Call): Boolean = + isConstructor(c.name) && c.code.endsWith(")") + + override protected def isConstructor(name: String): Boolean = + name.nonEmpty && name.charAt(0).isUpper + + /** If the parent method is module then it can be used as a field. + */ + override def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + i.method.name.matches("(|__init__)") || super.isField(i) + ) + + override def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case ".listLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}list")) + case ".tupleLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}tuple")) + case ".dictLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}dict")) + case ".setLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}set")) + case Operators.conditional => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}bool")) + case _ => super.visitIdentifierAssignedToOperator(i, c, operation) + + override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = + val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) + associateTypes(i, constructorPaths) + + override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + // Ignore legacy import representation + if c.name.equals("import") then Set.empty + // Stop custom annotation representation from hitting superclass + else if c.name.isBlank then Set.empty + else super.visitIdentifierAssignedToCall(i, c) + + override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = + val fieldParents = getFieldParents(fa) + fa.astChildren.l match + case List(base: Identifier, fi: FieldIdentifier) + if base.name.equals("self") && fieldParents.nonEmpty => + val referencedFields = cpg.typeDecl.fullNameExact( + 
fieldParents.toSeq* + ).member.nameExact(fi.canonicalName) + val globalTypes = + referencedFields.flatMap(m => + m.typeFullName +: m.dynamicTypeHintFullName + ).filterNot(_ == Constants.ANY).toSet + associateTypes(i, globalTypes) + case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) + + override def getTypesFromCall(c: Call): Set[String] = c.name match + case ".listLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}list") + case ".tupleLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}tuple") + case ".dictLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}dict") + case ".setLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}set") + case _ => super.getTypesFromCall(c) + + override def getFieldParents(fa: FieldAccess): Set[String] = + if fa.method.name == "" then + Set(fa.method.fullName) + else if fa.method.typeDecl.nonEmpty then + val parentTypes = fa.method.typeDecl.fullName.toSet + val baseTypeFullNames = + cpg.typeDecl.fullNameExact(parentTypes.toSeq*).inheritsFromTypeFullName.toSet + (parentTypes ++ baseTypeFullNames).filterNot(_.matches("(?i)(any|object)")) + else + super.getFieldParents(fa) + + private def isPyString(s: String): Boolean = + (s.startsWith("\"") || s.startsWith("'")) && (s.endsWith("\"") || s.endsWith("'")) + + override def getLiteralType(l: Literal): Set[String] = + (l.code match + case code if code.toIntOption.isDefined => Some(s"${PythonAstVisitor.builtinPrefix}int") + case code if code.toDoubleOption.isDefined => + Some(s"${PythonAstVisitor.builtinPrefix}float") + case code if "True".equals(code) || "False".equals(code) => + Some(s"${PythonAstVisitor.builtinPrefix}bool") + case code if code.equals("None") => Some(s"${PythonAstVisitor.builtinPrefix}None") + case code if isPyString(code) => Some(s"${PythonAstVisitor.builtinPrefix}str") + case _ => None + ).toSet + + override def createCallFromIdentifierTypeFullName( + typeFullName: String, + callName: String + ): String = + lazy val tName = 
typeFullName.split("\\.").lastOption.getOrElse(typeFullName) + typeFullName match + case t if t.matches(".*(<\\w+>)$") => + super.createCallFromIdentifierTypeFullName(typeFullName, callName) + case t if t.matches(".*\\.<(member|returnValue|indexAccess)>(\\(.*\\))?") => + super.createCallFromIdentifierTypeFullName(typeFullName, callName) + case t if isConstructor(tName) => + Seq(t, callName).mkString(pathSep.toString) + case _ => super.createCallFromIdentifierTypeFullName(typeFullName, callName) + + override protected def postSetTypeInformation(): Unit = + cu.typeDecl + .map(t => + t -> t.inheritsFromTypeFullName.partition(itf => + symbolTable.contains(LocalVar(itf)) + ) + ) + .foreach { case (t, (identifierTypes, otherTypes)) => + val existingTypes = (identifierTypes ++ otherTypes).distinct + val resolvedTypes = identifierTypes.map(LocalVar.apply).flatMap(symbolTable.get) + if existingTypes != resolvedTypes && resolvedTypes.nonEmpty then + state.changesWereMade.compareAndExchange(false, true) + builder.setNodeProperty( + t, + PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, + resolvedTypes + ) + } + + override def prepopulateSymbolTable(): Unit = + cu.ast.isMethodRef.where( + _.astSiblings.isIdentifier.nameExact("classmethod") + ).referencedMethod.foreach { + classMethod => + classMethod.parameter + .nameExact("cls") + .foreach { cls => + val clsPath = classMethod.typeDecl.fullName.toSet + symbolTable.put(LocalVar(cls.name), clsPath) + if cls.typeFullName == "ANY" then + builder.setNodeProperty( + cls, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + clsPath.toSeq + ) + } } - } - - override def prepopulateSymbolTable(): Unit = { - cu.ast.isMethodRef.where(_.astSiblings.isIdentifier.nameExact("classmethod")).referencedMethod.foreach { - classMethod => - classMethod.parameter - .nameExact("cls") - .foreach { cls => - val clsPath = classMethod.typeDecl.fullName.toSet - symbolTable.put(LocalVar(cls.name), clsPath) - if (cls.typeFullName == "ANY") - 
builder.setNodeProperty(cls, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, clsPath.toSeq) - } - } - super.prepopulateSymbolTable() - } - - override protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef, rec: Option[String]): Set[String] = - t.typ.referencedTypeDecl - .map(_.fullName.stripSuffix("")) - .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) - .headOption - .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) - -} + super.prepopulateSymbolTable() + end prepopulateSymbolTable + + override protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] + ): Set[String] = + t.typ.referencedTypeDecl + .map(_.fullName.stripSuffix("")) + .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) + .headOption + .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) +end RecoverForPythonFile diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala index 2adea304..b8619d25 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala @@ -3,32 +3,25 @@ package io.appthreat.pysrc2cpg.memop import io.appthreat.pythonparser.ast import scala.collection.mutable -class AstNodeToMemoryOperationMap { - private class IdentityHashWrapper(private val astNode: ast.iast) { - override def equals(o: Any): Boolean = { - o match { - case wrapper: IdentityHashWrapper => - astNode eq wrapper.astNode - case _ => - false - } - } +class AstNodeToMemoryOperationMap: + private class IdentityHashWrapper(private val astNode: ast.iast): + override def equals(o: Any): Boolean = + o match + case wrapper: IdentityHashWrapper => + astNode eq wrapper.astNode + case _ => + 
false - override def hashCode(): Int = { - System.identityHashCode(astNode) - } + override def hashCode(): Int = + System.identityHashCode(astNode) - override def toString: String = astNode.toString - } + override def toString: String = astNode.toString - private val astNodeToMemOp = mutable.HashMap.empty[IdentityHashWrapper, MemoryOperation] + private val astNodeToMemOp = mutable.HashMap.empty[IdentityHashWrapper, MemoryOperation] - def put(astNode: ast.iast, memOp: MemoryOperation): Unit = { - astNodeToMemOp.put(new IdentityHashWrapper(astNode), memOp) - } + def put(astNode: ast.iast, memOp: MemoryOperation): Unit = + astNodeToMemOp.put(new IdentityHashWrapper(astNode), memOp) - def get(astNode: ast.iast): Option[MemoryOperation] = { - astNodeToMemOp.get(new IdentityHashWrapper(astNode)) - } - -} + def get(astNode: ast.iast): Option[MemoryOperation] = + astNodeToMemOp.get(new IdentityHashWrapper(astNode)) +end AstNodeToMemoryOperationMap diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala index 1647420b..89665bf0 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala @@ -1,8 +1,7 @@ package io.appthreat.pysrc2cpg.memop -sealed trait MemoryOperation { - override def toString: String = getClass.getSimpleName -} +sealed trait MemoryOperation: + override def toString: String = getClass.getSimpleName object Store extends MemoryOperation object Load extends MemoryOperation object Del extends MemoryOperation diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala index 0225c00b..9306312c 100644 --- 
a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala @@ -2,594 +2,529 @@ package io.appthreat.pysrc2cpg.memop import io.appthreat.pythonparser.{AstVisitor, ast} import io.appthreat.pythonparser.ast.{ - AnnAssign, - Assert, - Assign, - AsyncFor, - AsyncFunctionDef, - AsyncWith, - AugAssign, - BinOp, - BoolOp, - Break, - ClassDef, - Continue, - Delete, - Dict, - ErrorStatement, - Expr, - For, - FormattedValue, - FunctionDef, - Global, - If, - IfExp, - Import, - ImportFrom, - JoinedString, - JoinedStringConstant, - Lambda, - Match, - MatchAs, - MatchCase, - MatchClass, - MatchMapping, - MatchOr, - MatchSequence, - MatchSingleton, - MatchStar, - MatchValue, - Module, - NamedExpr, - Nonlocal, - Pass, - Raise, - RaiseP2, - Return, - Try, - UnaryOp, - While, - With, - iast, - iexpr, - imod, - istmt + AnnAssign, + Assert, + Assign, + AsyncFor, + AsyncFunctionDef, + AsyncWith, + AugAssign, + BinOp, + BoolOp, + Break, + ClassDef, + Continue, + Delete, + Dict, + ErrorStatement, + Expr, + For, + FormattedValue, + FunctionDef, + Global, + If, + IfExp, + Import, + ImportFrom, + JoinedString, + JoinedStringConstant, + Lambda, + Match, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue, + Module, + NamedExpr, + Nonlocal, + Pass, + Raise, + RaiseP2, + Return, + Try, + UnaryOp, + While, + With, + iast, + iexpr, + imod, + istmt } import io.appthreat.pythonparser.ast.{ - FormattedValue, - JoinedString, - JoinedStringConstant, - MatchAs, - MatchCase, - MatchClass, - MatchMapping, - MatchOr, - MatchSequence, - MatchSingleton, - MatchStar, - MatchValue + FormattedValue, + JoinedString, + JoinedStringConstant, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue } import 
scala.collection.mutable -class MemoryOperationCalculator extends AstVisitor[Unit] { - private val stack = mutable.Stack.empty[MemoryOperation] - val astNodeToMemOp = new AstNodeToMemoryOperationMap() - val names = mutable.Set.empty[String] - - private def accept(astNode: iast): Unit = { - astNode.accept(this) - } - - private def accept(astNodes: Iterable[iast]): Unit = { - astNodes.foreach(accept) - } - - private def push(memOp: MemoryOperation): Unit = { - stack.push(memOp) - } - - private def pop(): Unit = { - stack.pop() - } - - override def visit(astNode: iast): Unit = ??? - - override def visit(mod: imod): Unit = ??? - - override def visit(module: Module): Unit = { - accept(module.stmts) - } - - override def visit(stmt: istmt): Unit = ??? - - override def visit(functionDef: FunctionDef): Unit = { - push(Load) - accept(functionDef.decorator_list) - accept(functionDef.args) - accept(functionDef.returns) - pop() - accept(functionDef.body) - } - - override def visit(functionDef: AsyncFunctionDef): Unit = { - push(Load) - accept(functionDef.decorator_list) - accept(functionDef.args) - accept(functionDef.returns) - pop() - accept(functionDef.body) - } - - override def visit(classDef: ClassDef): Unit = { - push(Load) - accept(classDef.decorator_list) - accept(classDef.bases) - accept(classDef.keywords) - pop() - accept(classDef.body) - } - - override def visit(ret: Return): Unit = { - push(Load) - accept(ret.value) - pop() - } - - override def visit(delete: Delete): Unit = { - push(Del) - accept(delete.targets) - pop() - } - - override def visit(assign: Assign): Unit = { - push(Store) - accept(assign.targets) - pop() - push(Load) - accept(assign.value) - pop() - } - - override def visit(annAssign: AnnAssign): Unit = { - push(Store) - accept(annAssign.target) - pop() - push(Load) - accept(annAssign.annotation) - accept(annAssign.value) - pop() - } - - override def visit(augAssign: AugAssign): Unit = { - push(Store) - accept(augAssign.target) - pop() - push(Load) - 
accept(augAssign.value) - pop() - } - - override def visit(forStmt: For): Unit = { - push(Store) - accept(forStmt.target) - pop() - push(Load) - accept(forStmt.iter) - pop() - accept(forStmt.body) - accept(forStmt.orelse) - } - - override def visit(forStmt: AsyncFor): Unit = { - push(Store) - accept(forStmt.target) - pop() - push(Load) - accept(forStmt.iter) - pop() - accept(forStmt.body) - accept(forStmt.orelse) - } - - override def visit(whileStmt: While): Unit = { - push(Load) - accept(whileStmt.test) - pop() - accept(whileStmt.body) - accept(whileStmt.orelse) - } - - override def visit(ifStmt: If): Unit = { - push(Load) - accept(ifStmt.test) - pop() - accept(ifStmt.body) - accept(ifStmt.orelse) - } - - override def visit(withStmt: With): Unit = { - accept(withStmt.items) - accept(withStmt.body) - } - - override def visit(withStmt: AsyncWith): Unit = { - accept(withStmt.items) - accept(withStmt.body) - } - - override def visit(matchStmt: Match): Unit = { - push(Load) - accept(matchStmt.subject) - accept(matchStmt.cases) - pop() - } - - override def visit(raise: Raise): Unit = { - push(Load) - accept(raise.exc) - accept(raise.cause) - pop() - } - - override def visit(tryStmt: Try): Unit = { - accept(tryStmt.body) - accept(tryStmt.handlers) - accept(tryStmt.orelse) - accept(tryStmt.finalbody) - } - - override def visit(assert: Assert): Unit = { - push(Load) - accept(assert.test) - accept(assert.msg) - pop() - } - - override def visit(importStmt: Import): Unit = {} - - override def visit(importFrom: ImportFrom): Unit = {} - - override def visit(global: Global): Unit = {} - - override def visit(nonlocal: Nonlocal): Unit = {} - - override def visit(expr: Expr): Unit = { - push(Load) - expr.value.accept(this) - pop() - } - - override def visit(pass: Pass): Unit = {} - - override def visit(break: Break): Unit = {} - - override def visit(continue: Continue): Unit = {} - - override def visit(raise: RaiseP2): Unit = { - push(Load) - accept(raise.typ) - accept(raise.inst) 
- accept(raise.tback) - pop() - } - - override def visit(errorStatement: ErrorStatement): Unit = {} - - override def visit(expr: iexpr): Unit = ??? - - override def visit(boolOp: BoolOp): Unit = { - accept(boolOp.values) - } - - override def visit(namedExpr: NamedExpr): Unit = { - push(Store) - accept(namedExpr.target) - pop() - accept(namedExpr.value) - } - - override def visit(binOp: BinOp): Unit = { - accept(binOp.left) - accept(binOp.right) - } - - override def visit(unaryOp: UnaryOp): Unit = { - accept(unaryOp.operand) - } - - override def visit(lambda: Lambda): Unit = { - push(Load) - accept(lambda.args) - pop() - accept(lambda.body) - } - - override def visit(ifExp: IfExp): Unit = { - accept(ifExp.test) - accept(ifExp.body) - accept(ifExp.orelse) - } - - override def visit(dict: Dict): Unit = { - accept(dict.keys.collect { case Some(key) => key }) - accept(dict.values) - } - - override def visit(set: ast.Set): Unit = { - accept(set.elts) - } - - override def visit(listComp: ast.ListComp): Unit = { - accept(listComp.elt) - accept(listComp.generators) - } - - override def visit(setComp: ast.SetComp): Unit = { - accept(setComp.elt) - accept(setComp.generators) - } - - override def visit(dictComp: ast.DictComp): Unit = { - accept(dictComp.key) - accept(dictComp.value) - accept(dictComp.generators) - } - - override def visit(generatorExp: ast.GeneratorExp): Unit = { - accept(generatorExp.elt) - accept(generatorExp.generators) - } - - override def visit(await: ast.Await): Unit = { - accept(await.value) - } +class MemoryOperationCalculator extends AstVisitor[Unit]: + private val stack = mutable.Stack.empty[MemoryOperation] + val astNodeToMemOp = new AstNodeToMemoryOperationMap() + val names = mutable.Set.empty[String] + + private def accept(astNode: iast): Unit = + astNode.accept(this) + + private def accept(astNodes: Iterable[iast]): Unit = + astNodes.foreach(accept) + + private def push(memOp: MemoryOperation): Unit = + stack.push(memOp) + + private def pop(): 
Unit = + stack.pop() + + override def visit(astNode: iast): Unit = ??? + + override def visit(mod: imod): Unit = ??? + + override def visit(module: Module): Unit = + accept(module.stmts) + + override def visit(stmt: istmt): Unit = ??? + + override def visit(functionDef: FunctionDef): Unit = + push(Load) + accept(functionDef.decorator_list) + accept(functionDef.args) + accept(functionDef.returns) + pop() + accept(functionDef.body) + + override def visit(functionDef: AsyncFunctionDef): Unit = + push(Load) + accept(functionDef.decorator_list) + accept(functionDef.args) + accept(functionDef.returns) + pop() + accept(functionDef.body) + + override def visit(classDef: ClassDef): Unit = + push(Load) + accept(classDef.decorator_list) + accept(classDef.bases) + accept(classDef.keywords) + pop() + accept(classDef.body) + + override def visit(ret: Return): Unit = + push(Load) + accept(ret.value) + pop() + + override def visit(delete: Delete): Unit = + push(Del) + accept(delete.targets) + pop() + + override def visit(assign: Assign): Unit = + push(Store) + accept(assign.targets) + pop() + push(Load) + accept(assign.value) + pop() + + override def visit(annAssign: AnnAssign): Unit = + push(Store) + accept(annAssign.target) + pop() + push(Load) + accept(annAssign.annotation) + accept(annAssign.value) + pop() + + override def visit(augAssign: AugAssign): Unit = + push(Store) + accept(augAssign.target) + pop() + push(Load) + accept(augAssign.value) + pop() + + override def visit(forStmt: For): Unit = + push(Store) + accept(forStmt.target) + pop() + push(Load) + accept(forStmt.iter) + pop() + accept(forStmt.body) + accept(forStmt.orelse) + + override def visit(forStmt: AsyncFor): Unit = + push(Store) + accept(forStmt.target) + pop() + push(Load) + accept(forStmt.iter) + pop() + accept(forStmt.body) + accept(forStmt.orelse) + + override def visit(whileStmt: While): Unit = + push(Load) + accept(whileStmt.test) + pop() + accept(whileStmt.body) + accept(whileStmt.orelse) + + override 
def visit(ifStmt: If): Unit = + push(Load) + accept(ifStmt.test) + pop() + accept(ifStmt.body) + accept(ifStmt.orelse) + + override def visit(withStmt: With): Unit = + accept(withStmt.items) + accept(withStmt.body) + + override def visit(withStmt: AsyncWith): Unit = + accept(withStmt.items) + accept(withStmt.body) + + override def visit(matchStmt: Match): Unit = + push(Load) + accept(matchStmt.subject) + accept(matchStmt.cases) + pop() + + override def visit(raise: Raise): Unit = + push(Load) + accept(raise.exc) + accept(raise.cause) + pop() + + override def visit(tryStmt: Try): Unit = + accept(tryStmt.body) + accept(tryStmt.handlers) + accept(tryStmt.orelse) + accept(tryStmt.finalbody) + + override def visit(assert: Assert): Unit = + push(Load) + accept(assert.test) + accept(assert.msg) + pop() + + override def visit(importStmt: Import): Unit = {} + + override def visit(importFrom: ImportFrom): Unit = {} + + override def visit(global: Global): Unit = {} + + override def visit(nonlocal: Nonlocal): Unit = {} + + override def visit(expr: Expr): Unit = + push(Load) + expr.value.accept(this) + pop() + + override def visit(pass: Pass): Unit = {} + + override def visit(break: Break): Unit = {} + + override def visit(continue: Continue): Unit = {} + + override def visit(raise: RaiseP2): Unit = + push(Load) + accept(raise.typ) + accept(raise.inst) + accept(raise.tback) + pop() + + override def visit(errorStatement: ErrorStatement): Unit = {} + + override def visit(expr: iexpr): Unit = ??? 
+ + override def visit(boolOp: BoolOp): Unit = + accept(boolOp.values) + + override def visit(namedExpr: NamedExpr): Unit = + push(Store) + accept(namedExpr.target) + pop() + accept(namedExpr.value) + + override def visit(binOp: BinOp): Unit = + accept(binOp.left) + accept(binOp.right) + + override def visit(unaryOp: UnaryOp): Unit = + accept(unaryOp.operand) + + override def visit(lambda: Lambda): Unit = + push(Load) + accept(lambda.args) + pop() + accept(lambda.body) + + override def visit(ifExp: IfExp): Unit = + accept(ifExp.test) + accept(ifExp.body) + accept(ifExp.orelse) - override def visit(yieldExpr: ast.Yield): Unit = { - accept(yieldExpr.value) - } + override def visit(dict: Dict): Unit = + accept(dict.keys.collect { case Some(key) => key }) + accept(dict.values) - override def visit(yieldFrom: ast.YieldFrom): Unit = { - accept(yieldFrom.value) - } + override def visit(set: ast.Set): Unit = + accept(set.elts) - override def visit(compare: ast.Compare): Unit = { - accept(compare.left) - accept(compare.comparators) - } + override def visit(listComp: ast.ListComp): Unit = + accept(listComp.elt) + accept(listComp.generators) - override def visit(call: ast.Call): Unit = { - assert(stack.head == Load) - accept(call.func) - accept(call.args) - accept(call.keywords) - } + override def visit(setComp: ast.SetComp): Unit = + accept(setComp.elt) + accept(setComp.generators) - override def visit(formattedValue: FormattedValue): Unit = { - assert(stack.head == Load) - accept(formattedValue.value) - } + override def visit(dictComp: ast.DictComp): Unit = + accept(dictComp.key) + accept(dictComp.value) + accept(dictComp.generators) - override def visit(joinedString: JoinedString): Unit = { - assert(stack.head == Load) - accept(joinedString.values) - } + override def visit(generatorExp: ast.GeneratorExp): Unit = + accept(generatorExp.elt) + accept(generatorExp.generators) - override def visit(constant: ast.Constant): Unit = {} + override def visit(await: ast.Await): Unit = 
+ accept(await.value) - override def visit(attribute: ast.Attribute): Unit = { - push(Load) - accept(attribute.value) - pop() - astNodeToMemOp.put(attribute, stack.head) - } + override def visit(yieldExpr: ast.Yield): Unit = + accept(yieldExpr.value) - override def visit(subscript: ast.Subscript): Unit = { - push(Load) - accept(subscript.value) - accept(subscript.slice) - pop() - astNodeToMemOp.put(subscript, stack.head) - } + override def visit(yieldFrom: ast.YieldFrom): Unit = + accept(yieldFrom.value) - override def visit(starred: ast.Starred): Unit = { - accept(starred.value) - astNodeToMemOp.put(starred, stack.head) - } + override def visit(compare: ast.Compare): Unit = + accept(compare.left) + accept(compare.comparators) - override def visit(name: ast.Name): Unit = { - astNodeToMemOp.put(name, stack.head) - names.add(name.id) - } + override def visit(call: ast.Call): Unit = + assert(stack.head == Load) + accept(call.func) + accept(call.args) + accept(call.keywords) - override def visit(list: ast.List): Unit = { - accept(list.elts) - astNodeToMemOp.put(list, stack.head) - } + override def visit(formattedValue: FormattedValue): Unit = + assert(stack.head == Load) + accept(formattedValue.value) - override def visit(tuple: ast.Tuple): Unit = { - accept(tuple.elts) - astNodeToMemOp.put(tuple, stack.head) - } + override def visit(joinedString: JoinedString): Unit = + assert(stack.head == Load) + accept(joinedString.values) - override def visit(slice: ast.Slice): Unit = { - push(Load) - accept(slice.lower) - accept(slice.upper) - accept(slice.step) - pop() - } + override def visit(constant: ast.Constant): Unit = {} - override def visit(stringExpList: ast.StringExpList): Unit = { - accept(stringExpList.elts) - } + override def visit(attribute: ast.Attribute): Unit = + push(Load) + accept(attribute.value) + pop() + astNodeToMemOp.put(attribute, stack.head) - override def visit(boolop: ast.iboolop): Unit = {} + override def visit(subscript: ast.Subscript): Unit = + 
push(Load) + accept(subscript.value) + accept(subscript.slice) + pop() + astNodeToMemOp.put(subscript, stack.head) - override def visit(and: ast.And.type): Unit = {} + override def visit(starred: ast.Starred): Unit = + accept(starred.value) + astNodeToMemOp.put(starred, stack.head) - override def visit(or: ast.Or.type): Unit = {} + override def visit(name: ast.Name): Unit = + astNodeToMemOp.put(name, stack.head) + names.add(name.id) - override def visit(operator: ast.ioperator): Unit = {} + override def visit(list: ast.List): Unit = + accept(list.elts) + astNodeToMemOp.put(list, stack.head) - override def visit(add: ast.Add.type): Unit = {} + override def visit(tuple: ast.Tuple): Unit = + accept(tuple.elts) + astNodeToMemOp.put(tuple, stack.head) - override def visit(sub: ast.Sub.type): Unit = {} + override def visit(slice: ast.Slice): Unit = + push(Load) + accept(slice.lower) + accept(slice.upper) + accept(slice.step) + pop() - override def visit(mult: ast.Mult.type): Unit = {} + override def visit(stringExpList: ast.StringExpList): Unit = + accept(stringExpList.elts) - override def visit(matMult: ast.MatMult.type): Unit = {} + override def visit(boolop: ast.iboolop): Unit = {} - override def visit(div: ast.Div.type): Unit = {} + override def visit(and: ast.And.type): Unit = {} - override def visit(mod: ast.Mod.type): Unit = {} + override def visit(or: ast.Or.type): Unit = {} - override def visit(pow: ast.Pow.type): Unit = {} + override def visit(operator: ast.ioperator): Unit = {} - override def visit(lShift: ast.LShift.type): Unit = {} + override def visit(add: ast.Add.type): Unit = {} - override def visit(rShift: ast.RShift.type): Unit = {} + override def visit(sub: ast.Sub.type): Unit = {} - override def visit(bitOr: ast.BitOr.type): Unit = {} + override def visit(mult: ast.Mult.type): Unit = {} - override def visit(bitXor: ast.BitXor.type): Unit = {} + override def visit(matMult: ast.MatMult.type): Unit = {} - override def visit(bitAnd: ast.BitAnd.type): Unit 
= {} + override def visit(div: ast.Div.type): Unit = {} - override def visit(floorDiv: ast.FloorDiv.type): Unit = {} + override def visit(mod: ast.Mod.type): Unit = {} - override def visit(unaryop: ast.iunaryop): Unit = {} + override def visit(pow: ast.Pow.type): Unit = {} - override def visit(invert: ast.Invert.type): Unit = {} + override def visit(lShift: ast.LShift.type): Unit = {} - override def visit(not: ast.Not.type): Unit = {} + override def visit(rShift: ast.RShift.type): Unit = {} - override def visit(uAdd: ast.UAdd.type): Unit = {} + override def visit(bitOr: ast.BitOr.type): Unit = {} - override def visit(uSub: ast.USub.type): Unit = {} + override def visit(bitXor: ast.BitXor.type): Unit = {} - override def visit(compop: ast.icompop): Unit = {} + override def visit(bitAnd: ast.BitAnd.type): Unit = {} - override def visit(eq: ast.Eq.type): Unit = {} + override def visit(floorDiv: ast.FloorDiv.type): Unit = {} - override def visit(notEq: ast.NotEq.type): Unit = {} + override def visit(unaryop: ast.iunaryop): Unit = {} - override def visit(lt: ast.Lt.type): Unit = {} + override def visit(invert: ast.Invert.type): Unit = {} - override def visit(ltE: ast.LtE.type): Unit = {} + override def visit(not: ast.Not.type): Unit = {} - override def visit(gt: ast.Gt.type): Unit = {} + override def visit(uAdd: ast.UAdd.type): Unit = {} - override def visit(gtE: ast.GtE.type): Unit = {} + override def visit(uSub: ast.USub.type): Unit = {} - override def visit(is: ast.Is.type): Unit = {} + override def visit(compop: ast.icompop): Unit = {} - override def visit(isNot: ast.IsNot.type): Unit = {} + override def visit(eq: ast.Eq.type): Unit = {} - override def visit(in: ast.In.type): Unit = {} + override def visit(notEq: ast.NotEq.type): Unit = {} - override def visit(notIn: ast.NotIn.type): Unit = {} + override def visit(lt: ast.Lt.type): Unit = {} - override def visit(comprehension: ast.Comprehension): Unit = { - assert(stack.head == Load) - push(Store) - 
accept(comprehension.target) - pop() - accept(comprehension.iter) - accept(comprehension.ifs) - } + override def visit(ltE: ast.LtE.type): Unit = {} - override def visit(exceptHandler: ast.ExceptHandler): Unit = { - push(Load) - accept(exceptHandler.typ) - pop() - accept(exceptHandler.body) - } + override def visit(gt: ast.Gt.type): Unit = {} - override def visit(arguments: ast.Arguments): Unit = { - accept(arguments.posonlyargs) - accept(arguments.args) - accept(arguments.vararg) - accept(arguments.kwonlyargs) - accept(arguments.kw_defaults.collect { case Some(default) => default }) - accept(arguments.kw_arg) - accept(arguments.defaults) - } + override def visit(gtE: ast.GtE.type): Unit = {} - override def visit(arg: ast.Arg): Unit = { - accept(arg.annotation) - } + override def visit(is: ast.Is.type): Unit = {} - override def visit(constant: ast.iconstant): Unit = ??? + override def visit(isNot: ast.IsNot.type): Unit = {} - override def visit(stringConstant: ast.StringConstant): Unit = {} + override def visit(in: ast.In.type): Unit = {} - override def visit(joinedStringConstant: JoinedStringConstant): Unit = {} + override def visit(notIn: ast.NotIn.type): Unit = {} - override def visit(boolConstant: ast.BoolConstant): Unit = {} + override def visit(comprehension: ast.Comprehension): Unit = + assert(stack.head == Load) + push(Store) + accept(comprehension.target) + pop() + accept(comprehension.iter) + accept(comprehension.ifs) - override def visit(intConstant: ast.IntConstant): Unit = {} + override def visit(exceptHandler: ast.ExceptHandler): Unit = + push(Load) + accept(exceptHandler.typ) + pop() + accept(exceptHandler.body) - override def visit(intConstant: ast.FloatConstant): Unit = {} + override def visit(arguments: ast.Arguments): Unit = + accept(arguments.posonlyargs) + accept(arguments.args) + accept(arguments.vararg) + accept(arguments.kwonlyargs) + accept(arguments.kw_defaults.collect { case Some(default) => default }) + accept(arguments.kw_arg) + 
accept(arguments.defaults) - override def visit(imaginaryConstant: ast.ImaginaryConstant): Unit = {} + override def visit(arg: ast.Arg): Unit = + accept(arg.annotation) - override def visit(noneConstant: ast.NoneConstant.type): Unit = {} + override def visit(constant: ast.iconstant): Unit = ??? - override def visit(ellipsisConstant: ast.EllipsisConstant.type): Unit = {} + override def visit(stringConstant: ast.StringConstant): Unit = {} - override def visit(keyword: ast.Keyword): Unit = { - assert(stack.head == Load) - accept(keyword.value) - } + override def visit(joinedStringConstant: JoinedStringConstant): Unit = {} - override def visit(alias: ast.Alias): Unit = {} + override def visit(boolConstant: ast.BoolConstant): Unit = {} - override def visit(withItem: ast.Withitem): Unit = { - push(Load) - accept(withItem.context_expr) - pop() - push(Store) - accept(withItem.optional_vars) - pop() - } + override def visit(intConstant: ast.IntConstant): Unit = {} - override def visit(matchCase: MatchCase): Unit = { - accept(matchCase.pattern) - accept(matchCase.guard) - accept(matchCase.body) - } + override def visit(intConstant: ast.FloatConstant): Unit = {} - override def visit(matchValue: MatchValue): Unit = { - accept(matchValue.value) - } + override def visit(imaginaryConstant: ast.ImaginaryConstant): Unit = {} - override def visit(matchSingleton: MatchSingleton): Unit = {} + override def visit(noneConstant: ast.NoneConstant.type): Unit = {} - override def visit(matchSequence: MatchSequence): Unit = { - accept(matchSequence.patterns) - } + override def visit(ellipsisConstant: ast.EllipsisConstant.type): Unit = {} - override def visit(matchMapping: MatchMapping): Unit = { - accept(matchMapping.keys) - accept(matchMapping.patterns) - } + override def visit(keyword: ast.Keyword): Unit = + assert(stack.head == Load) + accept(keyword.value) - override def visit(matchClass: MatchClass): Unit = { - accept(matchClass.cls) - accept(matchClass.patterns) - 
accept(matchClass.kwd_patterns) - } + override def visit(alias: ast.Alias): Unit = {} - override def visit(matchStar: MatchStar): Unit = {} + override def visit(withItem: ast.Withitem): Unit = + push(Load) + accept(withItem.context_expr) + pop() + push(Store) + accept(withItem.optional_vars) + pop() - override def visit(matchAs: MatchAs): Unit = { - accept(matchAs.pattern) - } + override def visit(matchCase: MatchCase): Unit = + accept(matchCase.pattern) + accept(matchCase.guard) + accept(matchCase.body) - override def visit(matchOr: MatchOr): Unit = { - accept(matchOr.patterns) - } + override def visit(matchValue: MatchValue): Unit = + accept(matchValue.value) - override def visit(typeIgnore: ast.TypeIgnore): Unit = {} -} + override def visit(matchSingleton: MatchSingleton): Unit = {} + + override def visit(matchSequence: MatchSequence): Unit = + accept(matchSequence.patterns) + + override def visit(matchMapping: MatchMapping): Unit = + accept(matchMapping.keys) + accept(matchMapping.patterns) + + override def visit(matchClass: MatchClass): Unit = + accept(matchClass.cls) + accept(matchClass.patterns) + accept(matchClass.kwd_patterns) + + override def visit(matchStar: MatchStar): Unit = {} + + override def visit(matchAs: MatchAs): Unit = + accept(matchAs.pattern) + + override def visit(matchOr: MatchOr): Unit = + accept(matchOr.patterns) + + override def visit(typeIgnore: ast.TypeIgnore): Unit = {} +end MemoryOperationCalculator diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala index 37e88dbe..53bbf461 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala @@ -1,879 +1,748 @@ package io.appthreat.pythonparser import io.appthreat.pythonparser.ast.{ - Add, - Alias, - And, - AnnAssign, - Arg, - 
Arguments, - Assert, - Assign, - AsyncFor, - AsyncFunctionDef, - AsyncWith, - Attribute, - AugAssign, - Await, - BinOp, - BitAnd, - BitOr, - BitXor, - BoolConstant, - BoolOp, - Break, - Call, - ClassDef, - Compare, - Comprehension, - Constant, - Continue, - Delete, - Dict, - DictComp, - Div, - EllipsisConstant, - Eq, - ErrorStatement, - ExceptHandler, - Expr, - FloatConstant, - FloorDiv, - For, - FormattedValue, - FunctionDef, - GeneratorExp, - Global, - Gt, - GtE, - If, - IfExp, - ImaginaryConstant, - Import, - ImportFrom, - In, - IntConstant, - Invert, - Is, - IsNot, - JoinedString, - JoinedStringConstant, - Keyword, - LShift, - Lambda, - ListComp, - Lt, - LtE, - MatMult, - Match, - MatchAs, - MatchCase, - MatchClass, - MatchMapping, - MatchOr, - MatchSequence, - MatchSingleton, - MatchStar, - MatchValue, - Mod, - Module, - Mult, - Name, - NamedExpr, - NoneConstant, - Nonlocal, - Not, - NotEq, - NotIn, - Or, - Pass, - Pow, - RShift, - Raise, - RaiseP2, - Return, - SetComp, - Slice, - Starred, - StringConstant, - StringExpList, - Sub, - Subscript, - Try, - Tuple, - TypeIgnore, - UAdd, - USub, - UnaryOp, - While, - With, - Withitem, - Yield, - YieldFrom, - iast, - iboolop, - icompop, - iconstant, - iexpr, - imod, - ioperator, - istmt, - iunaryop + Add, + Alias, + And, + AnnAssign, + Arg, + Arguments, + Assert, + Assign, + AsyncFor, + AsyncFunctionDef, + AsyncWith, + Attribute, + AugAssign, + Await, + BinOp, + BitAnd, + BitOr, + BitXor, + BoolConstant, + BoolOp, + Break, + Call, + ClassDef, + Compare, + Comprehension, + Constant, + Continue, + Delete, + Dict, + DictComp, + Div, + EllipsisConstant, + Eq, + ErrorStatement, + ExceptHandler, + Expr, + FloatConstant, + FloorDiv, + For, + FormattedValue, + FunctionDef, + GeneratorExp, + Global, + Gt, + GtE, + If, + IfExp, + ImaginaryConstant, + Import, + ImportFrom, + In, + IntConstant, + Invert, + Is, + IsNot, + JoinedString, + JoinedStringConstant, + Keyword, + LShift, + Lambda, + ListComp, + Lt, + LtE, + MatMult, + 
Match, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue, + Mod, + Module, + Mult, + Name, + NamedExpr, + NoneConstant, + Nonlocal, + Not, + NotEq, + NotIn, + Or, + Pass, + Pow, + RShift, + Raise, + RaiseP2, + Return, + SetComp, + Slice, + Starred, + StringConstant, + StringExpList, + Sub, + Subscript, + Try, + Tuple, + TypeIgnore, + UAdd, + USub, + UnaryOp, + While, + With, + Withitem, + Yield, + YieldFrom, + iast, + iboolop, + icompop, + iconstant, + iexpr, + imod, + ioperator, + istmt, + iunaryop } import io.appthreat.pythonparser.ast.* import scala.collection.immutable -class AstPrinter(indentStr: String) extends AstVisitor[String] { - private val ls = "\n" - - def print(astNode: iast): String = { - astNode.accept(this) - } - - def printIndented(astNode: iast): String = { - val printStr = astNode.accept(this) - - indentStr + printStr.replaceAll(ls, ls + indentStr) - } - - override def visit(ast: iast): String = ??? - - override def visit(mod: imod): String = ??? - - override def visit(module: Module): String = { - module.stmts.map(print).mkString(ls) - } - - override def visit(stmt: istmt): String = ??? 
- - override def visit(functionDef: FunctionDef): String = { - functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "def " + functionDef.name + "(" + print(functionDef.args) + ")" + - functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + - ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") - - } - - override def visit(functionDef: AsyncFunctionDef): String = { - functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "async def " + functionDef.name + "(" + print(functionDef.args) + ")" + - functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + - ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") - - } - - override def visit(classDef: ClassDef): String = { - val optionArgEndComma = if (classDef.bases.nonEmpty && classDef.keywords.nonEmpty) ", " else "" - - classDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "class " + classDef.name + - "(" + - classDef.bases.map(print).mkString(", ") + - optionArgEndComma + - classDef.keywords.map(print).mkString(", ") + - ")" + ":" + - classDef.body.map(printIndented).mkString(ls, ls, "") - } - - override def visit(ret: Return): String = { - "return" + ret.value.map(v => " " + print(v)).getOrElse("") - } - - override def visit(delete: Delete): String = { - "del " + delete.targets.map(print).mkString(", ") - } - - override def visit(assign: Assign): String = { - assign.targets.map(print).mkString("", " = ", " = ") + print(assign.value) - } - - override def visit(annAssign: AnnAssign): String = { - print(annAssign.target) + - ": " + print(annAssign.annotation) + - annAssign.value.map(v => " = " + print(v)).getOrElse("") - } - - override def visit(augAssign: AugAssign): String = { - print(augAssign.target) + - " " + print(augAssign.op) + "= " + - print(augAssign.value) - } - - override def visit(forStmt: For): String = { - "for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + - 
forStmt.body.map(printIndented).mkString(ls, ls, "") + - (if (forStmt.orelse.nonEmpty) - s"${ls}else:" + - forStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - } - - override def visit(forStmt: AsyncFor): String = { - "async for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + - forStmt.body.map(printIndented).mkString(ls, ls, "") + - (if (forStmt.orelse.nonEmpty) - s"${ls}else:" + - forStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - } - - override def visit(whileStmt: While): String = { - "while " + print(whileStmt.test) + ":" + - whileStmt.body.map(printIndented).mkString(ls, ls, "") + - (if (whileStmt.orelse.nonEmpty) - s"${ls}else:" + - whileStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - } - - override def visit(ifStmt: If): String = { - val elseString = - ifStmt.orelse.size match { - case 0 => "" - case 1 if ifStmt.orelse.head.isInstanceOf[If] => - s"${ls}el" + print(ifStmt.orelse.head) - case _ => - s"${ls}else:" + - ifStmt.orelse.map(printIndented).mkString(ls, ls, "") - } - - "if " + print(ifStmt.test) + ":" + - ifStmt.body.map(printIndented).mkString(ls, ls, "") + - elseString - } - - override def visit(withStmt: With): String = { - "with " + withStmt.items.map(print).mkString(", ") + ":" + - withStmt.body.map(printIndented).mkString(ls, ls, "") - } - - override def visit(withStmt: AsyncWith): String = { - "async with " + withStmt.items.map(print).mkString(", ") + ":" + - withStmt.body.map(printIndented).mkString(ls, ls, "") - } - - override def visit(matchStmt: Match): String = { - val subjectSuffix = matchStmt.subject match { - case _: Starred => - "," - case _ => - "" - } - "match " + print(matchStmt.subject) + subjectSuffix + ":" + - matchStmt.cases.map(printIndented).mkString(ls, ls, "") - } - - override def visit(raise: Raise): String = { - "raise" + raise.exc.map(e => " " + print(e)).getOrElse("") + - raise.cause.map(c => " from " + print(c)).getOrElse("") - } - - override def 
visit(tryStmt: Try): String = { - val elseString = - if (tryStmt.orelse.nonEmpty) { - s"${ls}else:" + - tryStmt.orelse.map(printIndented).mkString(ls, ls, "") - } else { - "" - } - - val finallyString = - if (tryStmt.finalbody.nonEmpty) { - s"${ls}finally:" + - tryStmt.finalbody.map(printIndented).mkString(ls, ls, "") - } else { - "" - } - - val handlersString = { - if (tryStmt.handlers.nonEmpty) { - tryStmt.handlers.map(print).mkString(ls, ls, "") - } else { - "" - } - } - - "try:" + - tryStmt.body.map(printIndented).mkString(ls, ls, "") + - handlersString + - elseString + - finallyString - } - - override def visit(assert: Assert): String = { - "assert " + print(assert.test) + assert.msg.map(m => ", " + print(m)).getOrElse("") - } - - override def visit(importStmt: Import): String = { - "import " + importStmt.names.map(print).mkString(", ") - } - - override def visit(importFrom: ImportFrom): String = { - val relativeImportDots = - if (importFrom.level != 0) { - " " + "." * importFrom.level - } else { - "" - } - "from" + relativeImportDots + importFrom.module.map(m => " " + m).getOrElse("") + - " import " + importFrom.names.map(print).mkString(", ") - } - - override def visit(global: Global): String = { - "global " + global.names.mkString(", ") - } - - override def visit(nonlocal: Nonlocal): String = { - "nonlocal " + nonlocal.names.mkString(", ") - } - - override def visit(expr: Expr): String = { - print(expr.value) - } - - override def visit(pass: Pass): String = { - "pass" - } - - override def visit(break: Break): String = { - "break" - } - - override def visit(continue: Continue): String = { - "continue" - } - - override def visit(raise: RaiseP2): String = { - "raise" + raise.typ.map(t => " " + print(t)).getOrElse("") + - raise.inst.map(i => ", " + print(i)).getOrElse("") + - raise.tback.map(t => ", " + print(t)).getOrElse("") - } - - override def visit(errorStmt: ErrorStatement): String = { - "" - } - - override def visit(expr: iexpr): String = ??? 
- - override def visit(boolOp: BoolOp): String = { - val opString = " " + print(boolOp.op) + " " - boolOp.values.map(print).mkString(opString) - } - - override def visit(namedExpr: NamedExpr): String = { - print(namedExpr.target) + " := " + print(namedExpr.value) - } - - override def visit(binOp: BinOp): String = { - print(binOp.left) + " " + print(binOp.op) + " " + print(binOp.right) - } - - override def visit(unaryOp: UnaryOp): String = { - val opString = unaryOp.op match { - case Not => - print(unaryOp.op) + " " - case _ => - print(unaryOp.op) - } - opString + print(unaryOp.operand) - } - - override def visit(lambda: Lambda): String = { - val argStr = print(lambda.args) - if (argStr.nonEmpty) { - "lambda " + argStr + ": " + print(lambda.body) - } else { - "lambda: " + print(lambda.body) - } - } - - override def visit(ifExp: IfExp): String = { - print(ifExp.body) + " if " + print(ifExp.test) + " else " + print(ifExp.orelse) - } - - override def visit(dict: Dict): String = { - "{" + dict.keys - .zip(dict.values) - .map { case (key, value) => - key match { - case Some(k) => - print(k) + ":" + print(value) - case None => - "**" + print(value) - } - } - .mkString(", ") + "}" - } - - override def visit(set: ast.Set): String = { - "{" + set.elts.map(print).mkString(", ") + "}" - } - - override def visit(listComp: ListComp): String = { - "[" + print(listComp.elt) + listComp.generators.map(print).mkString("") + "]" - } - - override def visit(setComp: SetComp): String = { - "{" + print(setComp.elt) + setComp.generators.map(print).mkString("") + "}" - } - - override def visit(dictComp: DictComp): String = { - "{" + print(dictComp.key) + ":" + print(dictComp.value) + - dictComp.generators.map(print).mkString("") + "}" - } - - override def visit(generatorExp: GeneratorExp): String = { - "(" + print(generatorExp.elt) + generatorExp.generators.map(print).mkString("") + ")" - } - - override def visit(await: Await): String = { - "await " + print(await.value) - } - - override def 
visit(yieldExpr: Yield): String = { - "yield" + yieldExpr.value.map(v => " " + print(v)).getOrElse("") - } - - override def visit(yieldFrom: YieldFrom): String = { - "yield from " + print(yieldFrom.value) - } - - override def visit(compare: Compare): String = { - print(compare.left) + compare.ops - .zip(compare.comparators) - .map { case (op, comparator) => - " " + print(op) + " " + print(comparator) - } - .mkString("") - } - - override def visit(call: Call): String = { - if (call.args.size == 1 && call.args.head.isInstanceOf[GeneratorExp]) { - // Special case in order to avoid double parenthesis since GeneratorExp adds a - // set of parenthesis on its own. - print(call.func) + print(call.args.head) - } else { - val optionArgEndComma = if (call.args.nonEmpty && call.keywords.nonEmpty) ", " else "" - print(call.func) + "(" + call.args.map(print).mkString(", ") + optionArgEndComma + - call.keywords.map(print).mkString(", ") + ")" - } - } - - override def visit(formattedValue: FormattedValue): String = { - val equalSignStr = if (formattedValue.equalSign) "=" else "" - val conversionStr = formattedValue.conversion match { - case -1 => "" - case 115 => "!s" - case 114 => "!r" - case 97 => "!a" - } - - val formatSpecStr = formattedValue.format_spec match { - case Some(formatSpec) => ":" + formatSpec - case None => "" - } - - "{" + print(formattedValue.value) + - equalSignStr + - conversionStr + - formatSpecStr + - "}" - } - - override def visit(joinedString: JoinedString): String = { - joinedString.prefix + joinedString.quote + - joinedString.values.map(print).mkString("") + joinedString.quote - } - - override def visit(constant: Constant): String = { - print(constant.value) - } - - override def visit(attribute: Attribute): String = { - print(attribute.value) + "." 
+ attribute.attr - } - - override def visit(subscript: Subscript): String = { - print(subscript.value) + "[" + print(subscript.slice) + "]" - } - - override def visit(starred: Starred): String = { - "*" + print(starred.value) - } - - override def visit(name: Name): String = { - name.id - } - - override def visit(list: ast.List): String = { - "[" + list.elts.map(print).mkString(", ") + "]" - } - - override def visit(tuple: Tuple): String = { - if (tuple.elts.size == 1) { - "(" + print(tuple.elts.head) + ",)" - } else { - "(" + tuple.elts.map(print).mkString(",") + ")" - } - } - - override def visit(slice: Slice): String = { - slice.lower.map(print).getOrElse("") + - ":" + slice.upper.map(print).getOrElse("") + - slice.step.map(expr => ":" + print(expr)).getOrElse("") - } - - override def visit(stringExpList: StringExpList): String = { - stringExpList.elts.map(print).mkString(" ") - } - - override def visit(alias: Alias): String = { - alias.name + alias.asName.map(n => " as " + n).getOrElse("") - } - - override def visit(boolop: iboolop): String = ??? - - override def visit(and: And.type): String = { - "and" - } - - override def visit(or: Or.type): String = { - "or" - } - - override def visit(compop: icompop): String = ??? - - override def visit(eq: Eq.type): String = { - "==" - } - - override def visit(noteq: NotEq.type): String = { - "!=" - } - - override def visit(lt: Lt.type): String = { - "<" - } - - override def visit(ltE: LtE.type): String = { - "<=" - } - - override def visit(gt: Gt.type): String = { - ">" - } - - override def visit(gtE: GtE.type): String = { - ">=" - } - - override def visit(is: Is.type): String = { - "is" - } - - override def visit(isNot: IsNot.type): String = { - "is not" - } - - override def visit(in: In.type): String = { - "in" - } - - override def visit(notIn: NotIn.type): String = { - "not in" - } - - override def visit(constant: iconstant): String = ??? 
- - override def visit(stringConstant: StringConstant): String = { - stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote - } - - override def visit(joinedStringConstant: JoinedStringConstant): String = { - joinedStringConstant.value - } - - override def visit(boolConstant: BoolConstant): String = { - if (boolConstant.value) { - "True" - } else { - "False" - } - } - - override def visit(intConstant: IntConstant): String = { - intConstant.value - } - - override def visit(floatConstant: FloatConstant): String = { - floatConstant.value - } - - override def visit(imaginaryConstant: ImaginaryConstant): String = { - imaginaryConstant.value - } - - override def visit(noneConstant: NoneConstant.type): String = { - "None" - } - - override def visit(ellipsisConstant: EllipsisConstant.type): String = { - "..." - } - - override def visit(exceptHandler: ExceptHandler): String = { - "except" + - exceptHandler.typ.map(t => " " + print(t)).getOrElse("") + - exceptHandler.name.map(n => " as " + n).getOrElse("") + - ":" + - exceptHandler.body.map(printIndented).mkString(ls, ls, "") - } - - override def visit(keyword: Keyword): String = { - keyword.arg match { - case Some(argName) => - argName + " = " + print(keyword.value) - case None => - "**" + print(keyword.value) - } - } - - override def visit(operator: ioperator): String = ??? 
- - override def visit(add: Add.type): String = { - "+" - } - - override def visit(sub: Sub.type): String = { - "-" - } - - override def visit(mult: Mult.type): String = { - "*" - } - - override def visit(matMult: MatMult.type): String = { - "@" - } - - override def visit(div: Div.type): String = { - "/" - } - - override def visit(mod: Mod.type): String = { - "%" - } - - override def visit(pow: Pow.type): String = { - "**" - } - - override def visit(lShift: LShift.type): String = { - "<<" - } - - override def visit(rShift: RShift.type): String = { - ">>" - } - - override def visit(bitOr: BitOr.type): String = { - "|" - } - - override def visit(bitXor: BitXor.type): String = { - "^" - } - - override def visit(bitAnd: BitAnd.type): String = { - "&" - } - - override def visit(floorDiv: FloorDiv.type): String = { - "//" - } - - override def visit(unaryop: iunaryop): String = ??? - - override def visit(invert: Invert.type): String = { - "~" - } - - override def visit(not: Not.type): String = { - "not" - } - - override def visit(uAdd: UAdd.type): String = { - "+" - } - - override def visit(uSub: USub.type): String = { - "-" - } - - override def visit(arg: Arg): String = { - arg.arg + arg.annotation.map(a => ": " + print(a)).getOrElse("") - } - - override def visit(arguments: Arguments): String = { - var result = "" - var separatorString = "" - val combinedPosArgSize = arguments.posonlyargs.size + arguments.args.size - val defaultArgs = immutable.List.fill(combinedPosArgSize - arguments.defaults.size)(None) ++ - arguments.defaults.map(Option.apply) - - if (arguments.posonlyargs.nonEmpty) { - val posOnlyArgsString = - arguments.posonlyargs - .zip(defaultArgs) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString("", ", ", ", /") - - result += posOnlyArgsString - separatorString = ", " - } - - if (arguments.args.nonEmpty) { - val defaultsForArgs = defaultArgs.drop(arguments.posonlyargs.size) - val 
argsString = - arguments.args - .zip(defaultsForArgs) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString(separatorString, ", ", "") - - result += argsString - separatorString = ", " - } - - arguments.vararg match { - case Some(v) => - result += separatorString - result += "*" + print(v) - separatorString = ", " - case None if arguments.kwonlyargs.nonEmpty => - result += separatorString - result += "*" - separatorString = ", " - case None => - } - - if (arguments.kwonlyargs.nonEmpty) { - val kwOnlyArgsString = - arguments.kwonlyargs - .zip(arguments.kw_defaults) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString(separatorString, ", ", "") - - result += kwOnlyArgsString - separatorString = ", " - } - - arguments.kw_arg.foreach { k => - result += separatorString - result += "**" + print(k) - } - - result - } - - override def visit(withItem: Withitem): String = { - print(withItem.context_expr) + withItem.optional_vars.map(o => " as " + print(o)).getOrElse("") - } - - override def visit(matchCase: MatchCase): String = { - "case " + print(matchCase.pattern) + matchCase.guard.map(g => " if " + print(g)).getOrElse("") + ":" + - matchCase.body.map(printIndented).mkString(ls, ls, "") - } - - override def visit(matchValue: MatchValue): String = { - print(matchValue.value) - } - - override def visit(matchSingleton: MatchSingleton): String = { - print(matchSingleton.value) - } - - override def visit(matchSequence: MatchSequence): String = { - matchSequence.patterns.map(print).mkString("[", ", ", "]") - } - - override def visit(matchMapping: MatchMapping): String = { - "{" + matchMapping.keys - .zip(matchMapping.patterns) - .map { case (key, pattern) => - print(key) + ": " + print(pattern) - } - .mkString(", ") + - matchMapping.rest - .map { r => - val separatorString = - if (matchMapping.keys.nonEmpty) { - ", " - } else { - "" +class 
AstPrinter(indentStr: String) extends AstVisitor[String]: + private val ls = "\n" + + def print(astNode: iast): String = + astNode.accept(this) + + def printIndented(astNode: iast): String = + val printStr = astNode.accept(this) + + indentStr + printStr.replaceAll(ls, ls + indentStr) + + override def visit(ast: iast): String = ??? + + override def visit(mod: imod): String = ??? + + override def visit(module: Module): String = + module.stmts.map(print).mkString(ls) + + override def visit(stmt: istmt): String = ??? + + override def visit(functionDef: FunctionDef): String = + functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "def " + functionDef.name + "(" + print(functionDef.args) + ")" + + functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + + ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(functionDef: AsyncFunctionDef): String = + functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "async def " + functionDef.name + "(" + print(functionDef.args) + ")" + + functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + + ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(classDef: ClassDef): String = + val optionArgEndComma = + if classDef.bases.nonEmpty && classDef.keywords.nonEmpty then ", " else "" + + classDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "class " + classDef.name + + "(" + + classDef.bases.map(print).mkString(", ") + + optionArgEndComma + + classDef.keywords.map(print).mkString(", ") + + ")" + ":" + + classDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(ret: Return): String = + "return" + ret.value.map(v => " " + print(v)).getOrElse("") + + override def visit(delete: Delete): String = + "del " + delete.targets.map(print).mkString(", ") + + override def visit(assign: Assign): String = + assign.targets.map(print).mkString("", " = ", " = ") + print(assign.value) + + override def 
visit(annAssign: AnnAssign): String = + print(annAssign.target) + + ": " + print(annAssign.annotation) + + annAssign.value.map(v => " = " + print(v)).getOrElse("") + + override def visit(augAssign: AugAssign): String = + print(augAssign.target) + + " " + print(augAssign.op) + "= " + + print(augAssign.value) + + override def visit(forStmt: For): String = + "for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + + forStmt.body.map(printIndented).mkString(ls, ls, "") + + (if forStmt.orelse.nonEmpty then + s"${ls}else:" + + forStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(forStmt: AsyncFor): String = + "async for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + + forStmt.body.map(printIndented).mkString(ls, ls, "") + + (if forStmt.orelse.nonEmpty then + s"${ls}else:" + + forStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(whileStmt: While): String = + "while " + print(whileStmt.test) + ":" + + whileStmt.body.map(printIndented).mkString(ls, ls, "") + + (if whileStmt.orelse.nonEmpty then + s"${ls}else:" + + whileStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(ifStmt: If): String = + val elseString = + ifStmt.orelse.size match + case 0 => "" + case 1 if ifStmt.orelse.head.isInstanceOf[If] => + s"${ls}el" + print(ifStmt.orelse.head) + case _ => + s"${ls}else:" + + ifStmt.orelse.map(printIndented).mkString(ls, ls, "") + + "if " + print(ifStmt.test) + ":" + + ifStmt.body.map(printIndented).mkString(ls, ls, "") + + elseString + + override def visit(withStmt: With): String = + "with " + withStmt.items.map(print).mkString(", ") + ":" + + withStmt.body.map(printIndented).mkString(ls, ls, "") + + override def visit(withStmt: AsyncWith): String = + "async with " + withStmt.items.map(print).mkString(", ") + ":" + + withStmt.body.map(printIndented).mkString(ls, ls, "") + + override def visit(matchStmt: Match): String = + val subjectSuffix = 
matchStmt.subject match + case _: Starred => + "," + case _ => + "" + "match " + print(matchStmt.subject) + subjectSuffix + ":" + + matchStmt.cases.map(printIndented).mkString(ls, ls, "") + + override def visit(raise: Raise): String = + "raise" + raise.exc.map(e => " " + print(e)).getOrElse("") + + raise.cause.map(c => " from " + print(c)).getOrElse("") + + override def visit(tryStmt: Try): String = + val elseString = + if tryStmt.orelse.nonEmpty then + s"${ls}else:" + + tryStmt.orelse.map(printIndented).mkString(ls, ls, "") + else + "" + + val finallyString = + if tryStmt.finalbody.nonEmpty then + s"${ls}finally:" + + tryStmt.finalbody.map(printIndented).mkString(ls, ls, "") + else + "" + + val handlersString = + if tryStmt.handlers.nonEmpty then + tryStmt.handlers.map(print).mkString(ls, ls, "") + else + "" + + "try:" + + tryStmt.body.map(printIndented).mkString(ls, ls, "") + + handlersString + + elseString + + finallyString + end visit + + override def visit(assert: Assert): String = + "assert " + print(assert.test) + assert.msg.map(m => ", " + print(m)).getOrElse("") + + override def visit(importStmt: Import): String = + "import " + importStmt.names.map(print).mkString(", ") + + override def visit(importFrom: ImportFrom): String = + val relativeImportDots = + if importFrom.level != 0 then + " " + "." 
* importFrom.level + else + "" + "from" + relativeImportDots + importFrom.module.map(m => " " + m).getOrElse("") + + " import " + importFrom.names.map(print).mkString(", ") + + override def visit(global: Global): String = + "global " + global.names.mkString(", ") + + override def visit(nonlocal: Nonlocal): String = + "nonlocal " + nonlocal.names.mkString(", ") + + override def visit(expr: Expr): String = + print(expr.value) + + override def visit(pass: Pass): String = + "pass" + + override def visit(break: Break): String = + "break" + + override def visit(continue: Continue): String = + "continue" + + override def visit(raise: RaiseP2): String = + "raise" + raise.typ.map(t => " " + print(t)).getOrElse("") + + raise.inst.map(i => ", " + print(i)).getOrElse("") + + raise.tback.map(t => ", " + print(t)).getOrElse("") + + override def visit(errorStmt: ErrorStatement): String = + "" + + override def visit(expr: iexpr): String = ??? + + override def visit(boolOp: BoolOp): String = + val opString = " " + print(boolOp.op) + " " + boolOp.values.map(print).mkString(opString) + + override def visit(namedExpr: NamedExpr): String = + print(namedExpr.target) + " := " + print(namedExpr.value) + + override def visit(binOp: BinOp): String = + print(binOp.left) + " " + print(binOp.op) + " " + print(binOp.right) + + override def visit(unaryOp: UnaryOp): String = + val opString = unaryOp.op match + case Not => + print(unaryOp.op) + " " + case _ => + print(unaryOp.op) + opString + print(unaryOp.operand) + + override def visit(lambda: Lambda): String = + val argStr = print(lambda.args) + if argStr.nonEmpty then + "lambda " + argStr + ": " + print(lambda.body) + else + "lambda: " + print(lambda.body) + + override def visit(ifExp: IfExp): String = + print(ifExp.body) + " if " + print(ifExp.test) + " else " + print(ifExp.orelse) + + override def visit(dict: Dict): String = + "{" + dict.keys + .zip(dict.values) + .map { case (key, value) => + key match + case Some(k) => + print(k) + ":" + 
print(value) + case None => + "**" + print(value) } - separatorString + "**" + r - } - .getOrElse("") + "}" - } - - override def visit(matchClass: MatchClass): String = { - val separatorString = - if (matchClass.patterns.nonEmpty && matchClass.kwd_patterns.nonEmpty) { - ", " - } else { - "" - } - - print(matchClass.cls) + - "(" + - matchClass.patterns.map(print).mkString(", ") + - separatorString + - matchClass.kwd_attrs - .zip(matchClass.kwd_patterns) - .map { case (name, pattern) => - name + " = " + print(pattern) + .mkString(", ") + "}" + + override def visit(set: ast.Set): String = + "{" + set.elts.map(print).mkString(", ") + "}" + + override def visit(listComp: ListComp): String = + "[" + print(listComp.elt) + listComp.generators.map(print).mkString("") + "]" + + override def visit(setComp: SetComp): String = + "{" + print(setComp.elt) + setComp.generators.map(print).mkString("") + "}" + + override def visit(dictComp: DictComp): String = + "{" + print(dictComp.key) + ":" + print(dictComp.value) + + dictComp.generators.map(print).mkString("") + "}" + + override def visit(generatorExp: GeneratorExp): String = + "(" + print(generatorExp.elt) + generatorExp.generators.map(print).mkString("") + ")" + + override def visit(await: Await): String = + "await " + print(await.value) + + override def visit(yieldExpr: Yield): String = + "yield" + yieldExpr.value.map(v => " " + print(v)).getOrElse("") + + override def visit(yieldFrom: YieldFrom): String = + "yield from " + print(yieldFrom.value) + + override def visit(compare: Compare): String = + print(compare.left) + compare.ops + .zip(compare.comparators) + .map { case (op, comparator) => + " " + print(op) + " " + print(comparator) + } + .mkString("") + + override def visit(call: Call): String = + if call.args.size == 1 && call.args.head.isInstanceOf[GeneratorExp] then + // Special case in order to avoid double parenthesis since GeneratorExp adds a + // set of parenthesis on its own. 
+ print(call.func) + print(call.args.head) + else + val optionArgEndComma = + if call.args.nonEmpty && call.keywords.nonEmpty then ", " else "" + print(call.func) + "(" + call.args.map(print).mkString(", ") + optionArgEndComma + + call.keywords.map(print).mkString(", ") + ")" + + override def visit(formattedValue: FormattedValue): String = + val equalSignStr = if formattedValue.equalSign then "=" else "" + val conversionStr = formattedValue.conversion match + case -1 => "" + case 115 => "!s" + case 114 => "!r" + case 97 => "!a" + + val formatSpecStr = formattedValue.format_spec match + case Some(formatSpec) => ":" + formatSpec + case None => "" + + "{" + print(formattedValue.value) + + equalSignStr + + conversionStr + + formatSpecStr + + "}" + + override def visit(joinedString: JoinedString): String = + joinedString.prefix + joinedString.quote + + joinedString.values.map(print).mkString("") + joinedString.quote + + override def visit(constant: Constant): String = + print(constant.value) + + override def visit(attribute: Attribute): String = + print(attribute.value) + "." 
+ attribute.attr + + override def visit(subscript: Subscript): String = + print(subscript.value) + "[" + print(subscript.slice) + "]" + + override def visit(starred: Starred): String = + "*" + print(starred.value) + + override def visit(name: Name): String = + name.id + + override def visit(list: ast.List): String = + "[" + list.elts.map(print).mkString(", ") + "]" + + override def visit(tuple: Tuple): String = + if tuple.elts.size == 1 then + "(" + print(tuple.elts.head) + ",)" + else + "(" + tuple.elts.map(print).mkString(",") + ")" + + override def visit(slice: Slice): String = + slice.lower.map(print).getOrElse("") + + ":" + slice.upper.map(print).getOrElse("") + + slice.step.map(expr => ":" + print(expr)).getOrElse("") + + override def visit(stringExpList: StringExpList): String = + stringExpList.elts.map(print).mkString(" ") + + override def visit(alias: Alias): String = + alias.name + alias.asName.map(n => " as " + n).getOrElse("") + + override def visit(boolop: iboolop): String = ??? + + override def visit(and: And.type): String = + "and" + + override def visit(or: Or.type): String = + "or" + + override def visit(compop: icompop): String = ??? + + override def visit(eq: Eq.type): String = + "==" + + override def visit(noteq: NotEq.type): String = + "!=" + + override def visit(lt: Lt.type): String = + "<" + + override def visit(ltE: LtE.type): String = + "<=" + + override def visit(gt: Gt.type): String = + ">" + + override def visit(gtE: GtE.type): String = + ">=" + + override def visit(is: Is.type): String = + "is" + + override def visit(isNot: IsNot.type): String = + "is not" + + override def visit(in: In.type): String = + "in" + + override def visit(notIn: NotIn.type): String = + "not in" + + override def visit(constant: iconstant): String = ??? 
+ + override def visit(stringConstant: StringConstant): String = + stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote + + override def visit(joinedStringConstant: JoinedStringConstant): String = + joinedStringConstant.value + + override def visit(boolConstant: BoolConstant): String = + if boolConstant.value then + "True" + else + "False" + + override def visit(intConstant: IntConstant): String = + intConstant.value + + override def visit(floatConstant: FloatConstant): String = + floatConstant.value + + override def visit(imaginaryConstant: ImaginaryConstant): String = + imaginaryConstant.value + + override def visit(noneConstant: NoneConstant.type): String = + "None" + + override def visit(ellipsisConstant: EllipsisConstant.type): String = + "..." + + override def visit(exceptHandler: ExceptHandler): String = + "except" + + exceptHandler.typ.map(t => " " + print(t)).getOrElse("") + + exceptHandler.name.map(n => " as " + n).getOrElse("") + + ":" + + exceptHandler.body.map(printIndented).mkString(ls, ls, "") + + override def visit(keyword: Keyword): String = + keyword.arg match + case Some(argName) => + argName + " = " + print(keyword.value) + case None => + "**" + print(keyword.value) + + override def visit(operator: ioperator): String = ??? 
+ + override def visit(add: Add.type): String = + "+" + + override def visit(sub: Sub.type): String = + "-" + + override def visit(mult: Mult.type): String = + "*" + + override def visit(matMult: MatMult.type): String = + "@" + + override def visit(div: Div.type): String = + "/" + + override def visit(mod: Mod.type): String = + "%" + + override def visit(pow: Pow.type): String = + "**" + + override def visit(lShift: LShift.type): String = + "<<" + + override def visit(rShift: RShift.type): String = + ">>" + + override def visit(bitOr: BitOr.type): String = + "|" + + override def visit(bitXor: BitXor.type): String = + "^" + + override def visit(bitAnd: BitAnd.type): String = + "&" + + override def visit(floorDiv: FloorDiv.type): String = + "//" + + override def visit(unaryop: iunaryop): String = ??? + + override def visit(invert: Invert.type): String = + "~" + + override def visit(not: Not.type): String = + "not" + + override def visit(uAdd: UAdd.type): String = + "+" + + override def visit(uSub: USub.type): String = + "-" + + override def visit(arg: Arg): String = + arg.arg + arg.annotation.map(a => ": " + print(a)).getOrElse("") + + override def visit(arguments: Arguments): String = + var result = "" + var separatorString = "" + val combinedPosArgSize = arguments.posonlyargs.size + arguments.args.size + val defaultArgs = immutable.List.fill(combinedPosArgSize - arguments.defaults.size)(None) ++ + arguments.defaults.map(Option.apply) + + if arguments.posonlyargs.nonEmpty then + val posOnlyArgsString = + arguments.posonlyargs + .zip(defaultArgs) + .map { case (arg, defaultOption) => + print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString("", ", ", ", /") + + result += posOnlyArgsString + separatorString = ", " + + if arguments.args.nonEmpty then + val defaultsForArgs = defaultArgs.drop(arguments.posonlyargs.size) + val argsString = + arguments.args + .zip(defaultsForArgs) + .map { case (arg, defaultOption) => + print(arg) + 
defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString(separatorString, ", ", "") + + result += argsString + separatorString = ", " + + arguments.vararg match + case Some(v) => + result += separatorString + result += "*" + print(v) + separatorString = ", " + case None if arguments.kwonlyargs.nonEmpty => + result += separatorString + result += "*" + separatorString = ", " + case None => + + if arguments.kwonlyargs.nonEmpty then + val kwOnlyArgsString = + arguments.kwonlyargs + .zip(arguments.kw_defaults) + .map { case (arg, defaultOption) => + print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString(separatorString, ", ", "") + + result += kwOnlyArgsString + separatorString = ", " + + arguments.kw_arg.foreach { k => + result += separatorString + result += "**" + print(k) } - .mkString(", ") + - ")" - } - - override def visit(matchStar: MatchStar): String = { - "*" + matchStar.name.getOrElse("_") - } - - override def visit(matchAs: MatchAs): String = { - matchAs.pattern match { - case Some(pattern) => - print(pattern) + matchAs.name.map(name => " as " + name).getOrElse("") - case None => - matchAs.name.getOrElse("_") - } - } - - override def visit(matchOr: MatchOr): String = { - matchOr.patterns.map(print).mkString(" | ") - } - - override def visit(comprehension: Comprehension): String = { - val prefix = - if (comprehension.is_async) { - " async for " - } else { - " for " - } - - prefix + print(comprehension.target) + " in " + print(comprehension.iter) + - comprehension.ifs.map(i => " if " + print(i)).mkString("") - } - - override def visit(typeIgnore: TypeIgnore): String = { - typeIgnore.tag - } -} + + result + end visit + + override def visit(withItem: Withitem): String = + print(withItem.context_expr) + withItem.optional_vars.map(o => " as " + print(o)).getOrElse( + "" + ) + + override def visit(matchCase: MatchCase): String = + "case " + print(matchCase.pattern) + matchCase.guard.map(g => " if " + print(g)).getOrElse( + 
"" + ) + ":" + + matchCase.body.map(printIndented).mkString(ls, ls, "") + + override def visit(matchValue: MatchValue): String = + print(matchValue.value) + + override def visit(matchSingleton: MatchSingleton): String = + print(matchSingleton.value) + + override def visit(matchSequence: MatchSequence): String = + matchSequence.patterns.map(print).mkString("[", ", ", "]") + + override def visit(matchMapping: MatchMapping): String = + "{" + matchMapping.keys + .zip(matchMapping.patterns) + .map { case (key, pattern) => + print(key) + ": " + print(pattern) + } + .mkString(", ") + + matchMapping.rest + .map { r => + val separatorString = + if matchMapping.keys.nonEmpty then + ", " + else + "" + separatorString + "**" + r + } + .getOrElse("") + "}" + + override def visit(matchClass: MatchClass): String = + val separatorString = + if matchClass.patterns.nonEmpty && matchClass.kwd_patterns.nonEmpty then + ", " + else + "" + + print(matchClass.cls) + + "(" + + matchClass.patterns.map(print).mkString(", ") + + separatorString + + matchClass.kwd_attrs + .zip(matchClass.kwd_patterns) + .map { case (name, pattern) => + name + " = " + print(pattern) + } + .mkString(", ") + + ")" + + override def visit(matchStar: MatchStar): String = + "*" + matchStar.name.getOrElse("_") + + override def visit(matchAs: MatchAs): String = + matchAs.pattern match + case Some(pattern) => + print(pattern) + matchAs.name.map(name => " as " + name).getOrElse("") + case None => + matchAs.name.getOrElse("_") + + override def visit(matchOr: MatchOr): String = + matchOr.patterns.map(print).mkString(" | ") + + override def visit(comprehension: Comprehension): String = + val prefix = + if comprehension.is_async then + " async for " + else + " for " + + prefix + print(comprehension.target) + " in " + print(comprehension.iter) + + comprehension.ifs.map(i => " if " + print(i)).mkString("") + + override def visit(typeIgnore: TypeIgnore): String = + typeIgnore.tag +end AstPrinter diff --git 
a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala index 15bc84af..0b89902a 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala @@ -1,271 +1,271 @@ package io.appthreat.pythonparser import io.appthreat.pythonparser.ast.{ - Add, - Alias, - And, - AnnAssign, - Arg, - Arguments, - Assert, - Assign, - AsyncFor, - AsyncFunctionDef, - AsyncWith, - Attribute, - AugAssign, - Await, - BinOp, - BitAnd, - BitOr, - BitXor, - BoolConstant, - BoolOp, - Break, - Call, - ClassDef, - Compare, - Comprehension, - Constant, - Continue, - Delete, - Dict, - DictComp, - Div, - EllipsisConstant, - Eq, - ErrorStatement, - ExceptHandler, - Expr, - FloatConstant, - FloorDiv, - For, - FormattedValue, - FunctionDef, - GeneratorExp, - Global, - Gt, - GtE, - If, - IfExp, - ImaginaryConstant, - Import, - ImportFrom, - In, - IntConstant, - Invert, - Is, - IsNot, - JoinedString, - JoinedStringConstant, - Keyword, - LShift, - Lambda, - ListComp, - Lt, - LtE, - MatMult, - Match, - MatchAs, - MatchCase, - MatchClass, - MatchMapping, - MatchOr, - MatchSequence, - MatchSingleton, - MatchStar, - MatchValue, - Mod, - Module, - Mult, - Name, - NamedExpr, - NoneConstant, - Nonlocal, - Not, - NotEq, - NotIn, - Or, - Pass, - Pow, - RShift, - Raise, - RaiseP2, - Return, - SetComp, - Slice, - Starred, - StringConstant, - StringExpList, - Sub, - Subscript, - Try, - Tuple, - TypeIgnore, - UAdd, - USub, - UnaryOp, - While, - With, - Withitem, - Yield, - YieldFrom, - iast, - iboolop, - icompop, - iconstant, - iexpr, - imod, - ioperator, - istmt, - iunaryop + Add, + Alias, + And, + AnnAssign, + Arg, + Arguments, + Assert, + Assign, + AsyncFor, + AsyncFunctionDef, + AsyncWith, + Attribute, + AugAssign, + Await, + BinOp, + BitAnd, + BitOr, + BitXor, + 
BoolConstant, + BoolOp, + Break, + Call, + ClassDef, + Compare, + Comprehension, + Constant, + Continue, + Delete, + Dict, + DictComp, + Div, + EllipsisConstant, + Eq, + ErrorStatement, + ExceptHandler, + Expr, + FloatConstant, + FloorDiv, + For, + FormattedValue, + FunctionDef, + GeneratorExp, + Global, + Gt, + GtE, + If, + IfExp, + ImaginaryConstant, + Import, + ImportFrom, + In, + IntConstant, + Invert, + Is, + IsNot, + JoinedString, + JoinedStringConstant, + Keyword, + LShift, + Lambda, + ListComp, + Lt, + LtE, + MatMult, + Match, + MatchAs, + MatchCase, + MatchClass, + MatchMapping, + MatchOr, + MatchSequence, + MatchSingleton, + MatchStar, + MatchValue, + Mod, + Module, + Mult, + Name, + NamedExpr, + NoneConstant, + Nonlocal, + Not, + NotEq, + NotIn, + Or, + Pass, + Pow, + RShift, + Raise, + RaiseP2, + Return, + SetComp, + Slice, + Starred, + StringConstant, + StringExpList, + Sub, + Subscript, + Try, + Tuple, + TypeIgnore, + UAdd, + USub, + UnaryOp, + While, + With, + Withitem, + Yield, + YieldFrom, + iast, + iboolop, + icompop, + iconstant, + iexpr, + imod, + ioperator, + istmt, + iunaryop } import io.appthreat.pythonparser.ast.* -trait AstVisitor[T] { - def visit(ast: iast): T +trait AstVisitor[T]: + def visit(ast: iast): T - def visit(mod: imod): T - def visit(module: Module): T + def visit(mod: imod): T + def visit(module: Module): T - def visit(stmt: istmt): T - def visit(functionDef: FunctionDef): T - def visit(functionDef: AsyncFunctionDef): T - def visit(classDef: ClassDef): T - def visit(ret: Return): T - def visit(delete: Delete): T - def visit(assign: Assign): T - def visit(annAssign: AnnAssign): T - def visit(augAssign: AugAssign): T - def visit(forStmt: For): T - def visit(forStmt: AsyncFor): T - def visit(whileStmt: While): T - def visit(ifStmt: If): T - def visit(withStmt: With): T - def visit(withStmt: AsyncWith): T - def visit(matchStmt: Match): T - def visit(raise: Raise): T - def visit(tryStmt: Try): T - def visit(assert: Assert): T - def 
visit(importStmt: Import): T - def visit(importFrom: ImportFrom): T - def visit(global: Global): T - def visit(nonlocal: Nonlocal): T - def visit(expr: Expr): T - def visit(pass: Pass): T - def visit(break: Break): T - def visit(continue: Continue): T - def visit(raise: RaiseP2): T - def visit(errorStatement: ErrorStatement): T + def visit(stmt: istmt): T + def visit(functionDef: FunctionDef): T + def visit(functionDef: AsyncFunctionDef): T + def visit(classDef: ClassDef): T + def visit(ret: Return): T + def visit(delete: Delete): T + def visit(assign: Assign): T + def visit(annAssign: AnnAssign): T + def visit(augAssign: AugAssign): T + def visit(forStmt: For): T + def visit(forStmt: AsyncFor): T + def visit(whileStmt: While): T + def visit(ifStmt: If): T + def visit(withStmt: With): T + def visit(withStmt: AsyncWith): T + def visit(matchStmt: Match): T + def visit(raise: Raise): T + def visit(tryStmt: Try): T + def visit(assert: Assert): T + def visit(importStmt: Import): T + def visit(importFrom: ImportFrom): T + def visit(global: Global): T + def visit(nonlocal: Nonlocal): T + def visit(expr: Expr): T + def visit(pass: Pass): T + def visit(break: Break): T + def visit(continue: Continue): T + def visit(raise: RaiseP2): T + def visit(errorStatement: ErrorStatement): T - def visit(expr: iexpr): T - def visit(boolOp: BoolOp): T - def visit(namedExpr: NamedExpr): T - def visit(binOp: BinOp): T - def visit(unaryOp: UnaryOp): T - def visit(lambda: Lambda): T - def visit(ifExp: IfExp): T - def visit(dict: Dict): T - def visit(set: ast.Set): T - def visit(listComp: ListComp): T - def visit(setComp: SetComp): T - def visit(dictComp: DictComp): T - def visit(generatorExp: GeneratorExp): T - def visit(await: Await): T - def visit(yieldExpr: Yield): T - def visit(yieldFrom: YieldFrom): T - def visit(compare: Compare): T - def visit(call: Call): T - def visit(formattedValue: FormattedValue): T - def visit(joinedString: JoinedString): T - def visit(constant: Constant): T - 
def visit(attribute: Attribute): T - def visit(subscript: Subscript): T - def visit(starred: Starred): T - def visit(name: Name): T - def visit(list: ast.List): T - def visit(tuple: Tuple): T - def visit(slice: Slice): T - def visit(stringExpList: StringExpList): T + def visit(expr: iexpr): T + def visit(boolOp: BoolOp): T + def visit(namedExpr: NamedExpr): T + def visit(binOp: BinOp): T + def visit(unaryOp: UnaryOp): T + def visit(lambda: Lambda): T + def visit(ifExp: IfExp): T + def visit(dict: Dict): T + def visit(set: ast.Set): T + def visit(listComp: ListComp): T + def visit(setComp: SetComp): T + def visit(dictComp: DictComp): T + def visit(generatorExp: GeneratorExp): T + def visit(await: Await): T + def visit(yieldExpr: Yield): T + def visit(yieldFrom: YieldFrom): T + def visit(compare: Compare): T + def visit(call: Call): T + def visit(formattedValue: FormattedValue): T + def visit(joinedString: JoinedString): T + def visit(constant: Constant): T + def visit(attribute: Attribute): T + def visit(subscript: Subscript): T + def visit(starred: Starred): T + def visit(name: Name): T + def visit(list: ast.List): T + def visit(tuple: Tuple): T + def visit(slice: Slice): T + def visit(stringExpList: StringExpList): T - def visit(boolop: iboolop): T - def visit(and: And.type): T - def visit(or: Or.type): T + def visit(boolop: iboolop): T + def visit(and: And.type): T + def visit(or: Or.type): T - def visit(operator: ioperator): T - def visit(add: Add.type): T - def visit(sub: Sub.type): T - def visit(mult: Mult.type): T - def visit(matMult: MatMult.type): T - def visit(div: Div.type): T - def visit(mod: Mod.type): T - def visit(pow: Pow.type): T - def visit(lShift: LShift.type): T - def visit(rShift: RShift.type): T - def visit(bitOr: BitOr.type): T - def visit(bitXor: BitXor.type): T - def visit(bitAnd: BitAnd.type): T - def visit(floorDiv: FloorDiv.type): T + def visit(operator: ioperator): T + def visit(add: Add.type): T + def visit(sub: Sub.type): T + def 
visit(mult: Mult.type): T + def visit(matMult: MatMult.type): T + def visit(div: Div.type): T + def visit(mod: Mod.type): T + def visit(pow: Pow.type): T + def visit(lShift: LShift.type): T + def visit(rShift: RShift.type): T + def visit(bitOr: BitOr.type): T + def visit(bitXor: BitXor.type): T + def visit(bitAnd: BitAnd.type): T + def visit(floorDiv: FloorDiv.type): T - def visit(unaryop: iunaryop): T - def visit(invert: Invert.type): T - def visit(not: Not.type): T - def visit(uAdd: UAdd.type): T - def visit(uSub: USub.type): T + def visit(unaryop: iunaryop): T + def visit(invert: Invert.type): T + def visit(not: Not.type): T + def visit(uAdd: UAdd.type): T + def visit(uSub: USub.type): T - def visit(compop: icompop): T - def visit(eq: Eq.type): T - def visit(notEq: NotEq.type): T - def visit(lt: Lt.type): T - def visit(ltE: LtE.type): T - def visit(gt: Gt.type): T - def visit(gtE: GtE.type): T - def visit(is: Is.type): T - def visit(isNot: IsNot.type): T - def visit(in: In.type): T - def visit(notIn: NotIn.type): T + def visit(compop: icompop): T + def visit(eq: Eq.type): T + def visit(notEq: NotEq.type): T + def visit(lt: Lt.type): T + def visit(ltE: LtE.type): T + def visit(gt: Gt.type): T + def visit(gtE: GtE.type): T + def visit(is: Is.type): T + def visit(isNot: IsNot.type): T + def visit(in: In.type): T + def visit(notIn: NotIn.type): T - def visit(comprehension: Comprehension): T + def visit(comprehension: Comprehension): T - def visit(exceptHandler: ExceptHandler): T + def visit(exceptHandler: ExceptHandler): T - def visit(arguments: Arguments): T + def visit(arguments: Arguments): T - def visit(arg: Arg): T + def visit(arg: Arg): T - def visit(constant: iconstant): T - def visit(stringConstant: StringConstant): T - def visit(joinedStringConstant: JoinedStringConstant): T - def visit(boolConstant: BoolConstant): T - def visit(intConstant: IntConstant): T - def visit(intConstant: FloatConstant): T - def visit(imaginaryConstant: ImaginaryConstant): T - def 
visit(noneConstant: NoneConstant.type): T - def visit(ellipsisConstant: EllipsisConstant.type): T + def visit(constant: iconstant): T + def visit(stringConstant: StringConstant): T + def visit(joinedStringConstant: JoinedStringConstant): T + def visit(boolConstant: BoolConstant): T + def visit(intConstant: IntConstant): T + def visit(intConstant: FloatConstant): T + def visit(imaginaryConstant: ImaginaryConstant): T + def visit(noneConstant: NoneConstant.type): T + def visit(ellipsisConstant: EllipsisConstant.type): T - def visit(keyword: Keyword): T + def visit(keyword: Keyword): T - def visit(alias: Alias): T + def visit(alias: Alias): T - def visit(withItem: Withitem): T + def visit(withItem: Withitem): T - def visit(matchCase: MatchCase): T + def visit(matchCase: MatchCase): T - def visit(matchValue: MatchValue): T + def visit(matchValue: MatchValue): T - def visit(matchSingleton: MatchSingleton): T + def visit(matchSingleton: MatchSingleton): T - def visit(matchSequence: MatchSequence): T + def visit(matchSequence: MatchSequence): T - def visit(matchMapping: MatchMapping): T + def visit(matchMapping: MatchMapping): T - def visit(matchClass: MatchClass): T + def visit(matchClass: MatchClass): T - def visit(matchStar: MatchStar): T + def visit(matchStar: MatchStar): T - def visit(matchAs: MatchAs): T + def visit(matchAs: MatchAs): T - def visit(matchOr: MatchOr): T + def visit(matchOr: MatchOr): T - def visit(typeIgnore: TypeIgnore): T -} + def visit(typeIgnore: TypeIgnore): T +end AstVisitor diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala index 454fc3ac..b15c84e8 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala @@ -4,221 +4,204 @@ import 
CharStreamImpl.{defaultInputBufferSize, defaultMinimumReadSize} import java.io.{IOException, InputStream, InputStreamReader} -object CharStreamImpl { - private val defaultInputBufferSize = 4096 - private val defaultMinimumReadSize = 2048 -} - -class CharStreamImpl(inputStream: InputStream, inputBufferSize: Int, minimumReadSize: Int) extends CharStream { - private val inputReader = new InputStreamReader(inputStream) - - private var inputBuffer = new Array[Char](inputBufferSize) - private var posToLine = new Array[Int](inputBufferSize) - private var posToColumn = new Array[Int](inputBufferSize) - private var readPos = 1 - private var writePos = 1 - private var tokenBeginPos = 1 - private var inputBufferOffset = 0 // From start of inputStream - private var tabSize = 1 - - // The first slot of inputBuffer, posToLine, posToColumn represents the last value - // from the previous chunk red from inputStream. For the very first chunk we - // initialise the values by hand so that requires no special cases in the implementation. - inputBuffer(0) = 'a' // We could have picked any char that is not '\n' or '\r'. - posToLine(0) = 1 - posToColumn(0) = 0 - - def this(inputStream: InputStream) = { - this(inputStream, defaultInputBufferSize, defaultMinimumReadSize) - } - - def getBeginPos: Int = inputBufferOffset + tokenBeginPos - 1 - - private def fillBuffer(): Unit = { - if (readPos == writePos) { - // No more data to read - if (writePos == inputBuffer.length) { - // No more space in inputBuffer - - val keepStartPos = tokenBeginPos - 1 - val charsToKeep = writePos - keepStartPos - if (inputBuffer.length - charsToKeep < minimumReadSize) { - // Resize buffer and move content to the front. 
- val newBufferLen = charsToKeep + minimumReadSize - - val newInputBuffer = new Array[Char](newBufferLen) - Array.copy(inputBuffer, keepStartPos, newInputBuffer, 0, charsToKeep) - inputBuffer = newInputBuffer - - val newPosToLine = new Array[Int](newBufferLen) - Array.copy(posToLine, keepStartPos, newPosToLine, 0, charsToKeep) - posToLine = newPosToLine - - val newPosToColumn = new Array[Int](newBufferLen) - Array.copy(posToColumn, keepStartPos, newPosToColumn, 0, charsToKeep) - posToColumn = newPosToColumn - } else { - // Enough space left to just move content to the front. - Array.copy(inputBuffer, keepStartPos, inputBuffer, 0, charsToKeep) - Array.copy(posToLine, keepStartPos, posToLine, 0, charsToKeep) - Array.copy(posToColumn, keepStartPos, posToColumn, 0, charsToKeep) - } - writePos = charsToKeep - readPos = readPos - keepStartPos - inputBufferOffset += keepStartPos - tokenBeginPos = 1 - } - - val charsRed = inputReader.read(inputBuffer, writePos, inputBuffer.length - writePos) - if (charsRed != -1) { - writePos += charsRed - } else { - throw new IOException() - } - } - } - - private def lastRedPos: Int = { - readPos - 1 - } - - private def updateLineAndColumn(pos: Int, char: Char, prevChar: Char): Unit = { - - val newLine = - prevChar match { - case '\n' => - true - case '\r' => - if (char == '\n') { - false - } else { - true - } - case _ => - false - } - - if (newLine) { - posToLine(pos) = posToLine(pos - 1) + 1 - posToColumn(pos) = 1 - } else { - posToLine(pos) = posToLine(pos - 1) - posToColumn(pos) = posToColumn(pos - 1) + 1 - } - - if (char == '\t') { - posToColumn(pos) += -1 + (tabSize - (posToColumn(pos) % tabSize)) - } - } - - /** Returns the next character from the selected input. The method of selecting the input is the responsibility of the - * class implementing this interface. Can throw any java.io.IOException. 
- */ - override def readChar(): Char = { - fillBuffer() - val char = inputBuffer(readPos) - val prevChar = inputBuffer(lastRedPos) - updateLineAndColumn(readPos, char, prevChar) - readPos += 1 - char - } - - /** Returns the column position of the character last read. - * - * @deprecated - * @see - * #getEndColumn - */ - override def getColumn: Int = ??? - - /** Returns the line number of the character last read. - * - * @deprecated - * @see - * #getEndLine - */ - override def getLine: Int = ??? - - /** Returns the column number of the last character for current token (being matched after the last call to - * BeginTOken). - */ - override def getEndColumn: Int = { - posToColumn(lastRedPos) - } - - /** Returns the line number of the last character for current token (being matched after the last call to BeginTOken). - */ - override def getEndLine: Int = { - posToLine(lastRedPos) - } - - /** Returns the column number of the first character for current token (being matched after the last call to - * BeginTOken). - */ - override def getBeginColumn: Int = { - posToColumn(tokenBeginPos) - } - - /** Returns the line number of the first character for current token (being matched after the last call to - * BeginTOken). - */ - override def getBeginLine: Int = { - posToLine(tokenBeginPos) - } - - /** Backs up the input stream by amount steps. Lexer calls this method if it had already read some characters, but - * could not use them to match a (longer) token. So, they will be used again as the prefix of the next token and it - * is the implementation's responsibility to do this right. - */ - override def backup(amount: Int): Unit = { - readPos -= amount - } - - /** Returns the next character that marks the beginning of the next token. All characters must remain in the buffer - * between two successive calls to this method to implement backup correctly. 
- */ - override def BeginToken(): Char = { - tokenBeginPos = readPos - readChar() - } - - /** Returns a string made up of characters from the marked token beginning to the current buffer position. - * Implementations have the choice of returning anything that they want to. For example, for efficiency, one might - * decide to just return null, which is a valid implementation. - */ - override def GetImage(): String = { - new String(inputBuffer, tokenBeginPos, readPos - tokenBeginPos) - } - - /** Returns an array of characters that make up the suffix of length 'len' for the currently matched token. This is - * used to build up the matched string for use in actions in the case of MORE. A simple and inefficient - * implementation of this is as follows : - * - * { String t = GetImage(); return t.substring(t.length() - len, t.length()).toCharArray(); } - */ - override def GetSuffix(len: Int): Array[Char] = { - val suffix = new Array[Char](len) - Array.copy(inputBuffer, readPos - len, suffix, 0, len) - suffix - } - - /** The lexer calls this function to indicate that it is done with the stream and hence implementations can free any - * resources held by this class. Again, the body of this function can be just empty and it will not affect the - * lexer's operation. - */ - override def Done(): Unit = { - inputReader.close() - } - - override def setTabSize(i: Int): Unit = { - tabSize = i - } - - override def getTabSize: Int = { - tabSize - } - - override def getTrackLineColumn: Boolean = ??? - - override def setTrackLineColumn(trackLineColumn: Boolean): Unit = ??? 
-} +object CharStreamImpl: + private val defaultInputBufferSize = 4096 + private val defaultMinimumReadSize = 2048 + +class CharStreamImpl(inputStream: InputStream, inputBufferSize: Int, minimumReadSize: Int) + extends CharStream: + private val inputReader = new InputStreamReader(inputStream) + + private var inputBuffer = new Array[Char](inputBufferSize) + private var posToLine = new Array[Int](inputBufferSize) + private var posToColumn = new Array[Int](inputBufferSize) + private var readPos = 1 + private var writePos = 1 + private var tokenBeginPos = 1 + private var inputBufferOffset = 0 // From start of inputStream + private var tabSize = 1 + + // The first slot of inputBuffer, posToLine, posToColumn represents the last value + // from the previous chunk red from inputStream. For the very first chunk we + // initialise the values by hand so that requires no special cases in the implementation. + inputBuffer(0) = 'a' // We could have picked any char that is not '\n' or '\r'. + posToLine(0) = 1 + posToColumn(0) = 0 + + def this(inputStream: InputStream) = + this(inputStream, defaultInputBufferSize, defaultMinimumReadSize) + + def getBeginPos: Int = inputBufferOffset + tokenBeginPos - 1 + + private def fillBuffer(): Unit = + if readPos == writePos then + // No more data to read + if writePos == inputBuffer.length then + // No more space in inputBuffer + + val keepStartPos = tokenBeginPos - 1 + val charsToKeep = writePos - keepStartPos + if inputBuffer.length - charsToKeep < minimumReadSize then + // Resize buffer and move content to the front. 
+ val newBufferLen = charsToKeep + minimumReadSize + + val newInputBuffer = new Array[Char](newBufferLen) + Array.copy(inputBuffer, keepStartPos, newInputBuffer, 0, charsToKeep) + inputBuffer = newInputBuffer + + val newPosToLine = new Array[Int](newBufferLen) + Array.copy(posToLine, keepStartPos, newPosToLine, 0, charsToKeep) + posToLine = newPosToLine + + val newPosToColumn = new Array[Int](newBufferLen) + Array.copy(posToColumn, keepStartPos, newPosToColumn, 0, charsToKeep) + posToColumn = newPosToColumn + else + // Enough space left to just move content to the front. + Array.copy(inputBuffer, keepStartPos, inputBuffer, 0, charsToKeep) + Array.copy(posToLine, keepStartPos, posToLine, 0, charsToKeep) + Array.copy(posToColumn, keepStartPos, posToColumn, 0, charsToKeep) + end if + writePos = charsToKeep + readPos = readPos - keepStartPos + inputBufferOffset += keepStartPos + tokenBeginPos = 1 + end if + + val charsRed = inputReader.read(inputBuffer, writePos, inputBuffer.length - writePos) + if charsRed != -1 then + writePos += charsRed + else + throw new IOException() + + private def lastRedPos: Int = + readPos - 1 + + private def updateLineAndColumn(pos: Int, char: Char, prevChar: Char): Unit = + + val newLine = + prevChar match + case '\n' => + true + case '\r' => + if char == '\n' then + false + else + true + case _ => + false + + if newLine then + posToLine(pos) = posToLine(pos - 1) + 1 + posToColumn(pos) = 1 + else + posToLine(pos) = posToLine(pos - 1) + posToColumn(pos) = posToColumn(pos - 1) + 1 + + if char == '\t' then + posToColumn(pos) += -1 + (tabSize - (posToColumn(pos) % tabSize)) + end updateLineAndColumn + + /** Returns the next character from the selected input. The method of selecting the input is the + * responsibility of the class implementing this interface. Can throw any java.io.IOException. 
+ */ + override def readChar(): Char = + fillBuffer() + val char = inputBuffer(readPos) + val prevChar = inputBuffer(lastRedPos) + updateLineAndColumn(readPos, char, prevChar) + readPos += 1 + char + + /** Returns the column position of the character last read. + * + * @deprecated + * @see + * #getEndColumn + */ + override def getColumn: Int = ??? + + /** Returns the line number of the character last read. + * + * @deprecated + * @see + * #getEndLine + */ + override def getLine: Int = ??? + + /** Returns the column number of the last character for current token (being matched after the + * last call to BeginTOken). + */ + override def getEndColumn: Int = + posToColumn(lastRedPos) + + /** Returns the line number of the last character for current token (being matched after the + * last call to BeginTOken). + */ + override def getEndLine: Int = + posToLine(lastRedPos) + + /** Returns the column number of the first character for current token (being matched after the + * last call to BeginTOken). + */ + override def getBeginColumn: Int = + posToColumn(tokenBeginPos) + + /** Returns the line number of the first character for current token (being matched after the + * last call to BeginTOken). + */ + override def getBeginLine: Int = + posToLine(tokenBeginPos) + + /** Backs up the input stream by amount steps. Lexer calls this method if it had already read + * some characters, but could not use them to match a (longer) token. So, they will be used + * again as the prefix of the next token and it is the implementation's responsibility to do + * this right. + */ + override def backup(amount: Int): Unit = + readPos -= amount + + /** Returns the next character that marks the beginning of the next token. All characters must + * remain in the buffer between two successive calls to this method to implement backup + * correctly. 
+ */ + override def BeginToken(): Char = + tokenBeginPos = readPos + readChar() + + /** Returns a string made up of characters from the marked token beginning to the current buffer + * position. Implementations have the choice of returning anything that they want to. For + * example, for efficiency, one might decide to just return null, which is a valid + * implementation. + */ + override def GetImage(): String = + new String(inputBuffer, tokenBeginPos, readPos - tokenBeginPos) + + /** Returns an array of characters that make up the suffix of length 'len' for the currently + * matched token. This is used to build up the matched string for use in actions in the case of + * MORE. A simple and inefficient implementation of this is as follows : + * + * { String t = GetImage(); return t.substring(t.length() - len, t.length()).toCharArray(); } + */ + override def GetSuffix(len: Int): Array[Char] = + val suffix = new Array[Char](len) + Array.copy(inputBuffer, readPos - len, suffix, 0, len) + suffix + + /** The lexer calls this function to indicate that it is done with the stream and hence + * implementations can free any resources held by this class. Again, the body of this function + * can be just empty and it will not affect the lexer's operation. + */ + override def Done(): Unit = + inputReader.close() + + override def setTabSize(i: Int): Unit = + tabSize = i + + override def getTabSize: Int = + tabSize + + override def getTrackLineColumn: Boolean = ??? + + override def setTrackLineColumn(trackLineColumn: Boolean): Unit = ??? 
+end CharStreamImpl diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala index 9d84e131..89a06d80 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala @@ -8,23 +8,19 @@ import java.io.{BufferedReader, ByteArrayInputStream, InputStream, Reader} import java.nio.charset.StandardCharsets import scala.jdk.CollectionConverters.* -class PyParser { - private var pythonParser: PythonParser = _ +class PyParser: + private var pythonParser: PythonParser = _ - def parse(code: String): iast = { - parse(new ByteArrayInputStream(code.getBytes(StandardCharsets.UTF_8))) - } + def parse(code: String): iast = + parse(new ByteArrayInputStream(code.getBytes(StandardCharsets.UTF_8))) - def parse(inputStream: InputStream): iast = { - pythonParser = new PythonParser(new CharStreamImpl(inputStream)) - // We start in INDENT_CHECK lexer state because we want to detect indentations - // also for the first line. - pythonParser.token_source.SwitchTo(PythonParserConstants.INDENT_CHECK) - val module = pythonParser.module() - module - } + def parse(inputStream: InputStream): iast = + pythonParser = new PythonParser(new CharStreamImpl(inputStream)) + // We start in INDENT_CHECK lexer state because we want to detect indentations + // also for the first line. 
+ pythonParser.token_source.SwitchTo(PythonParserConstants.INDENT_CHECK) + val module = pythonParser.module() + module - def errors: Iterable[ErrorStatement] = { - pythonParser.getErrors.asScala - } -} + def errors: Iterable[ErrorStatement] = + pythonParser.getErrors.asScala diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala index e97172d0..122d84af 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala @@ -2,7 +2,7 @@ package io.appthreat.pythonparser.ast import io.appthreat.pythonparser.AstVisitor import java.util -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* // This file describes the AST classes. // It tries to stay as close as possible to the AST defined by CPython at @@ -36,23 +36,19 @@ import scala.jdk.CollectionConverters._ /////////////////////////////////////////////////////////////////////////////////////////////////// // AST root trait /////////////////////////////////////////////////////////////////////////////////////////////////// -trait iast { - def accept[T](visitor: AstVisitor[T]): T -} +trait iast: + def accept[T](visitor: AstVisitor[T]): T /////////////////////////////////////////////////////////////////////////////////////////////////// // AST module classes /////////////////////////////////////////////////////////////////////////////////////////////////// trait imod extends iast -case class Module(stmts: CollType[istmt], type_ignores: CollType[TypeIgnore]) extends imod { - def this(stmts: util.ArrayList[istmt], type_ignores: util.ArrayList[TypeIgnore]) = { - this(stmts.asScala, type_ignores.asScala) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class Module(stmts: CollType[istmt], type_ignores: 
CollType[TypeIgnore]) extends imod: + def this(stmts: util.ArrayList[istmt], type_ignores: util.ArrayList[TypeIgnore]) = + this(stmts.asScala, type_ignores.asScala) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST statement classes @@ -68,22 +64,28 @@ case class FunctionDef( returns: Option[iexpr], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - name: String, - args: Arguments, - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - returns: iexpr, - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(name, args, body.asScala, decorator_list.asScala, Option(returns), Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + name: String, + args: Arguments, + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + returns: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this( + name, + args, + body.asScala, + decorator_list.asScala, + Option(returns), + Option(type_comment), + attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end FunctionDef case class AsyncFunctionDef( name: String, @@ -93,22 +95,28 @@ case class AsyncFunctionDef( returns: Option[iexpr], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - name: String, - args: Arguments, - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - returns: iexpr, - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(name, args, body.asScala, decorator_list.asScala, Option(returns), Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) 
extends istmt: + def this( + name: String, + args: Arguments, + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + returns: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this( + name, + args, + body.asScala, + decorator_list.asScala, + Option(returns), + Option(type_comment), + attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end AsyncFunctionDef case class ClassDef( name: String, @@ -117,59 +125,58 @@ case class ClassDef( body: CollType[istmt], decorator_list: CollType[iexpr], attributeProvider: AttributeProvider -) extends istmt { - def this( - name: String, - bases: util.ArrayList[iexpr], - keywords: util.ArrayList[Keyword], - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - attributeProvider: AttributeProvider - ) = { - this(name, bases.asScala, keywords.asScala, body.asScala, decorator_list.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Return(value: Option[iexpr], attributeProvider: AttributeProvider) extends istmt { - def this(value: iexpr, attributeProvider: AttributeProvider) = { - this(Option(value), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Delete(targets: CollType[iexpr], attributeProvider: AttributeProvider) extends istmt { - def this(targets: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(targets.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + name: String, + bases: util.ArrayList[iexpr], + keywords: util.ArrayList[Keyword], + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this( + name, + bases.asScala, + keywords.asScala, + body.asScala, + decorator_list.asScala, + 
attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end ClassDef + +case class Return(value: Option[iexpr], attributeProvider: AttributeProvider) extends istmt: + def this(value: iexpr, attributeProvider: AttributeProvider) = + this(Option(value), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Delete(targets: CollType[iexpr], attributeProvider: AttributeProvider) extends istmt: + def this(targets: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(targets.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Assign( targets: CollType[iexpr], value: iexpr, typeComment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this(targets: util.ArrayList[iexpr], value: iexpr, attributeProvider: AttributeProvider) = { - this(targets.asScala, value, None, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class AugAssign(target: iexpr, op: ioperator, value: iexpr, attributeProvider: AttributeProvider) extends istmt { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this(targets: util.ArrayList[iexpr], value: iexpr, attributeProvider: AttributeProvider) = + this(targets.asScala, value, None, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class AugAssign( + target: iexpr, + op: ioperator, + value: iexpr, + attributeProvider: AttributeProvider +) extends istmt: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class AnnAssign( target: iexpr, @@ -177,14 +184,17 @@ case class AnnAssign( value: Option[iexpr], simple: Boolean, attributeProvider: AttributeProvider -) extends istmt { - def this(target: iexpr, annotation: iexpr, value: iexpr, simple: Boolean, 
attributeProvider: AttributeProvider) = { - this(target, annotation, Option(value), simple, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + target: iexpr, + annotation: iexpr, + value: iexpr, + simple: Boolean, + attributeProvider: AttributeProvider + ) = + this(target, annotation, Option(value), simple, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class For( target: iexpr, @@ -193,21 +203,19 @@ case class For( orelse: CollType[istmt], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - target: iexpr, - iter: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + target: iexpr, + iter: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end For case class AsyncFor( target: iexpr, @@ -216,107 +224,97 @@ case class AsyncFor( orelse: CollType[istmt], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - target: iexpr, - iter: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class While(test: iexpr, body: 
CollType[istmt], orelse: CollType[istmt], attributeProvider: AttributeProvider) - extends istmt { - def this( - test: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = { - this(test, body.asScala, orelse.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class If(test: iexpr, body: CollType[istmt], orelse: CollType[istmt], attributeProvider: AttributeProvider) - extends istmt { - def this( - test: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = { - this(test, body.asScala, orelse.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + target: iexpr, + iter: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end AsyncFor + +case class While( + test: iexpr, + body: CollType[istmt], + orelse: CollType[istmt], + attributeProvider: AttributeProvider +) extends istmt: + def this( + test: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(test, body.asScala, orelse.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class If( + test: iexpr, + body: CollType[istmt], + orelse: CollType[istmt], + attributeProvider: AttributeProvider +) extends istmt: + def this( + test: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(test, body.asScala, orelse.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = 
+ visitor.visit(this) case class With( items: CollType[Withitem], body: CollType[istmt], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - items: util.ArrayList[Withitem], - body: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(items.asScala, body.asScala, Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + items: util.ArrayList[Withitem], + body: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(items.asScala, body.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class AsyncWith( items: CollType[Withitem], body: CollType[istmt], type_comment: Option[String], attributeProvider: AttributeProvider -) extends istmt { - def this( - items: util.ArrayList[Withitem], - body: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = { - this(items.asScala, body.asScala, Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Match(subject: iexpr, cases: CollType[MatchCase], attributeProvider: AttributeProvider) extends istmt { - def this(subject: iexpr, cases: util.List[MatchCase], attributeProvider: AttributeProvider) = { - this(subject, cases.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Raise(exc: Option[iexpr], cause: Option[iexpr], attributeProvider: AttributeProvider) extends istmt { - def this(exc: iexpr, cause: iexpr, attributeProvider: AttributeProvider) = { - this(Option(exc), Option(cause), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def 
this( + items: util.ArrayList[Withitem], + body: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(items.asScala, body.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Match(subject: iexpr, cases: CollType[MatchCase], attributeProvider: AttributeProvider) + extends istmt: + def this(subject: iexpr, cases: util.List[MatchCase], attributeProvider: AttributeProvider) = + this(subject, cases.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Raise(exc: Option[iexpr], cause: Option[iexpr], attributeProvider: AttributeProvider) + extends istmt: + def this(exc: iexpr, cause: iexpr, attributeProvider: AttributeProvider) = + this(Option(exc), Option(cause), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Try( body: CollType[istmt], @@ -324,93 +322,76 @@ case class Try( orelse: CollType[istmt], finalbody: CollType[istmt], attributeProvider: AttributeProvider -) extends istmt { - def this( - body: util.ArrayList[istmt], - handlers: util.ArrayList[ExceptHandler], - orelse: util.ArrayList[istmt], - finalbody: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = { - this(body.asScala, handlers.asScala, orelse.asScala, finalbody.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Assert(test: iexpr, msg: Option[iexpr], attributeProvider: AttributeProvider) extends istmt { - def this(test: iexpr, msg: iexpr, attributeProvider: AttributeProvider) = { - this(test, Option(msg), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Import(names: CollType[Alias], attributeProvider: AttributeProvider) extends istmt { - def this(names: util.ArrayList[Alias], 
attributeProvider: AttributeProvider) = { - this(names.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class ImportFrom(module: Option[String], names: CollType[Alias], level: Int, attributeProvider: AttributeProvider) - extends istmt { - def this(module: String, names: util.ArrayList[Alias], level: Int, attributeProvider: AttributeProvider) = { - this(Option(module), names.asScala, level, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Global(names: CollType[String], attributeProvider: AttributeProvider) extends istmt { - def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = { - this(names.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Nonlocal(names: CollType[String], attributeProvider: AttributeProvider) extends istmt { - def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = { - this(names.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Expr(value: iexpr, attributeProvider: AttributeProvider) extends istmt { - def this(value: iexpr) = { - this(value, value.attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Pass(attributeProvider: AttributeProvider) extends istmt { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Break(attributeProvider: AttributeProvider) extends istmt { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Continue(attributeProvider: AttributeProvider) extends istmt { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends istmt: + def this( + body: util.ArrayList[istmt], + 
handlers: util.ArrayList[ExceptHandler], + orelse: util.ArrayList[istmt], + finalbody: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(body.asScala, handlers.asScala, orelse.asScala, finalbody.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Assert(test: iexpr, msg: Option[iexpr], attributeProvider: AttributeProvider) + extends istmt: + def this(test: iexpr, msg: iexpr, attributeProvider: AttributeProvider) = + this(test, Option(msg), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Import(names: CollType[Alias], attributeProvider: AttributeProvider) extends istmt: + def this(names: util.ArrayList[Alias], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class ImportFrom( + module: Option[String], + names: CollType[Alias], + level: Int, + attributeProvider: AttributeProvider +) extends istmt: + def this( + module: String, + names: util.ArrayList[Alias], + level: Int, + attributeProvider: AttributeProvider + ) = + this(Option(module), names.asScala, level, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Global(names: CollType[String], attributeProvider: AttributeProvider) extends istmt: + def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Nonlocal(names: CollType[String], attributeProvider: AttributeProvider) extends istmt: + def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Expr(value: iexpr, attributeProvider: 
AttributeProvider) extends istmt: + def this(value: iexpr) = + this(value, value.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Pass(attributeProvider: AttributeProvider) extends istmt: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Break(attributeProvider: AttributeProvider) extends istmt: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Continue(attributeProvider: AttributeProvider) extends istmt: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This is the python2 raise statement. // It is different enough from the python3 version to justify an @@ -419,192 +400,187 @@ case class Continue(attributeProvider: AttributeProvider) extends istmt { // represented as and instance of this class. Only if the syntax // of a raise does not match the python3 syntax we represent it // as such a class. -case class RaiseP2(typ: Option[iexpr], inst: Option[iexpr], tback: Option[iexpr], attributeProvider: AttributeProvider) - extends istmt { - def this(typ: iexpr, inst: iexpr, tback: iexpr, attributeProvider: AttributeProvider) = { - this(Option(typ), Option(inst), Option(tback), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class RaiseP2( + typ: Option[iexpr], + inst: Option[iexpr], + tback: Option[iexpr], + attributeProvider: AttributeProvider +) extends istmt: + def this(typ: iexpr, inst: iexpr, tback: iexpr, attributeProvider: AttributeProvider) = + this(Option(typ), Option(inst), Option(tback), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This statement is not part of the CPython AST definition and // was added to represent parse errors inline with valid AST // statements. 
-case class ErrorStatement(exception: Exception, attributeProvider: AttributeProvider) extends istmt { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class ErrorStatement(exception: Exception, attributeProvider: AttributeProvider) extends istmt: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST expression classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait iexpr extends iast with iattributes -case class BoolOp(op: iboolop, values: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr { - def this(op: iboolop, values: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(op, values.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class NamedExpr(target: iexpr, value: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class BinOp(left: iexpr, op: ioperator, right: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class UnaryOp(op: iunaryop, operand: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Lambda(args: Arguments, body: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class IfExp(test: iexpr, body: iexpr, orelse: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Dict(keys: 
CollType[Option[iexpr]], values: CollType[iexpr], attributeProvider: AttributeProvider) - extends iexpr { - def this(keys: util.ArrayList[iexpr], values: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(keys.asScala.map(Option.apply), values.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Set(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr { - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(elts.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class ListComp(elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider) - extends iexpr { - def this(elt: iexpr, generators: util.ArrayList[Comprehension], attributeProvider: AttributeProvider) = { - this(elt, generators.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class SetComp(elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider) - extends iexpr { - def this(elt: iexpr, generators: util.ArrayList[Comprehension], attributeProvider: AttributeProvider) = { - this(elt, generators.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class DictComp(key: iexpr, value: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider) - extends iexpr { - def this( - key: iexpr, - value: iexpr, - generators: util.ArrayList[Comprehension], - attributeProvider: AttributeProvider - ) = { - this(key, value, generators.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class GeneratorExp(elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider) - extends iexpr { - def 
this(elt: iexpr, generators: util.ArrayList[Comprehension], attributeProvider: AttributeProvider) = { - this(elt, generators.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Await(value: iexpr, attributeProvider: AttributeProvider) extends iexpr { - def this(value: iexpr) = { - this(value, value.attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Yield(value: Option[iexpr], attributeProvider: AttributeProvider) extends iexpr { - def this(value: iexpr, attributeProvider: AttributeProvider) = { - this(Option(value), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class YieldFrom(value: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class BoolOp(op: iboolop, values: CollType[iexpr], attributeProvider: AttributeProvider) + extends iexpr: + def this(op: iboolop, values: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(op, values.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class NamedExpr(target: iexpr, value: iexpr, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class BinOp(left: iexpr, op: ioperator, right: iexpr, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class UnaryOp(op: iunaryop, operand: iexpr, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Lambda(args: Arguments, body: iexpr, attributeProvider: AttributeProvider) extends iexpr: + override def accept[T](visitor: AstVisitor[T]): 
T = + visitor.visit(this) + +case class IfExp(test: iexpr, body: iexpr, orelse: iexpr, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Dict( + keys: CollType[Option[iexpr]], + values: CollType[iexpr], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + keys: util.ArrayList[iexpr], + values: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this(keys.asScala.map(Option.apply), values.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Set(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: + def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class ListComp( + elt: iexpr, + generators: CollType[Comprehension], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class SetComp( + elt: iexpr, + generators: CollType[Comprehension], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class DictComp( + key: iexpr, + value: iexpr, + generators: CollType[Comprehension], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + key: iexpr, + value: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(key, value, generators.asScala, 
attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class GeneratorExp( + elt: iexpr, + generators: CollType[Comprehension], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Await(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: + def this(value: iexpr) = + this(value, value.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Yield(value: Option[iexpr], attributeProvider: AttributeProvider) extends iexpr: + def this(value: iexpr, attributeProvider: AttributeProvider) = + this(Option(value), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class YieldFrom(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Compare( left: iexpr, ops: CollType[icompop], comparators: CollType[iexpr], attributeProvider: AttributeProvider -) extends iexpr { - def this( - left: iexpr, - ops: util.ArrayList[icompop], - comparators: util.ArrayList[iexpr], - attributeProvider: AttributeProvider - ) = { - this(left, ops.asScala, comparators.asScala, attributeProvider) - } - - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Call(func: iexpr, args: CollType[iexpr], keywords: CollType[Keyword], attributeProvider: AttributeProvider) - extends iexpr { - def this( - func: iexpr, - args: util.ArrayList[iexpr], - keywords: util.ArrayList[Keyword], - attributeProvider: AttributeProvider - ) = { - this(func, args.asScala, keywords.asScala, attributeProvider) - } - - override def accept[T](visitor: AstVisitor[T]): T = 
{ - visitor.visit(this) - } -} +) extends iexpr: + def this( + left: iexpr, + ops: util.ArrayList[icompop], + comparators: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this(left, ops.asScala, comparators.asScala, attributeProvider) + + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Call( + func: iexpr, + args: CollType[iexpr], + keywords: CollType[Keyword], + attributeProvider: AttributeProvider +) extends iexpr: + def this( + func: iexpr, + args: util.ArrayList[iexpr], + keywords: util.ArrayList[Keyword], + attributeProvider: AttributeProvider + ) = + this(func, args.asScala, keywords.asScala, attributeProvider) + + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // In addition to the CPython version of this class we also stored // whether the value expression was followed by "=" in "equalSign". @@ -617,91 +593,81 @@ case class FormattedValue( format_spec: Option[String], equalSign: Boolean, attributeProvider: AttributeProvider -) extends iexpr { - def this( - value: iexpr, - conversion: Int, - format_spec: String, - equalSign: Boolean, - attributeProvider: AttributeProvider - ) = { - this(value, conversion, Option(format_spec), equalSign, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends iexpr: + def this( + value: iexpr, + conversion: Int, + format_spec: String, + equalSign: Boolean, + attributeProvider: AttributeProvider + ) = + this(value, conversion, Option(format_spec), equalSign, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // In addition to the CPython version of this class we have the fields // "quote" which stores the kind of quote used and "prefix" // which stores the exact prefix used with the string. 
-case class JoinedString(values: CollType[iexpr], quote: String, prefix: String, attributeProvider: AttributeProvider) - extends iexpr { - def this(values: util.List[iexpr], quote: String, prefix: String, attributeProvider: AttributeProvider) = { - this(values.asScala, quote, prefix, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Constant(value: iconstant, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Attribute(value: iexpr, attr: String, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Subscript(value: iexpr, slice: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Starred(value: iexpr, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Name(id: String, attributeProvider: AttributeProvider) extends iexpr { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class List(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr { - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(elts.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Tuple(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr { - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = { - this(elts.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class Slice(lower: Option[iexpr], upper: Option[iexpr], step: Option[iexpr], 
attributeProvider: AttributeProvider) - extends iexpr { - def this(lower: iexpr, upper: iexpr, step: iexpr, attributeProvider: AttributeProvider) = { - this(Option(lower), Option(upper), Option(step), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class JoinedString( + values: CollType[iexpr], + quote: String, + prefix: String, + attributeProvider: AttributeProvider +) extends iexpr: + def this( + values: util.List[iexpr], + quote: String, + prefix: String, + attributeProvider: AttributeProvider + ) = + this(values.asScala, quote, prefix, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Constant(value: iconstant, attributeProvider: AttributeProvider) extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Attribute(value: iexpr, attr: String, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Subscript(value: iexpr, slice: iexpr, attributeProvider: AttributeProvider) + extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Starred(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Name(id: String, attributeProvider: AttributeProvider) extends iexpr: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class List(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: + def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Tuple(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: + def this(elts: util.ArrayList[iexpr], 
attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class Slice( + lower: Option[iexpr], + upper: Option[iexpr], + step: Option[iexpr], + attributeProvider: AttributeProvider +) extends iexpr: + def this(lower: iexpr, upper: iexpr, step: iexpr, attributeProvider: AttributeProvider) = + this(Option(lower), Option(upper), Option(step), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This class is not part of the CPython AST definition at // https://docs.python.org/3/library/ast.html @@ -709,200 +675,137 @@ case class Slice(lower: Option[iexpr], upper: Option[iexpr], step: Option[iexpr] // this extra kind of expression. // A StringExpList must always have at least 2 elements and its elements must // be either a Constant which contains a StringConstant or a JoinedString. -case class StringExpList(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr { - assert(elts.size >= 2) - def this(elts: util.ArrayList[iexpr]) = { - this(elts.asScala, elts.asScala.head.attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class StringExpList(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: + assert(elts.size >= 2) + def this(elts: util.ArrayList[iexpr]) = + this(elts.asScala, elts.asScala.head.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST boolop classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait iboolop extends iast -object And extends iboolop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +object And extends iboolop: + override def accept[T](visitor: 
AstVisitor[T]): T = + visitor.visit(this) -case object Or extends iboolop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case object Or extends iboolop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST operator classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait ioperator extends iast -case object Add extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Sub extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Mult extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object MatMult extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Div extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Mod extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Pow extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object LShift extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object RShift extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object BitOr extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object BitXor extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object BitAnd extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - 
visitor.visit(this) - } -} -case object FloorDiv extends ioperator { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case object Add extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Sub extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Mult extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object MatMult extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Div extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Mod extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Pow extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object LShift extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object RShift extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object BitOr extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object BitXor extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object BitAnd extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object FloorDiv extends ioperator: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST unaryop classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait iunaryop extends iast -case object Invert extends iunaryop { - override def accept[T](visitor: AstVisitor[T]): T = { - 
visitor.visit(this) - } -} - -case object Not extends iunaryop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case object UAdd extends iunaryop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case object USub extends iunaryop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case object Invert extends iunaryop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case object Not extends iunaryop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case object UAdd extends iunaryop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case object USub extends iunaryop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST compop classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait icompop extends iast -case object Eq extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object NotEq extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Lt extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object LtE extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Gt extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object GtE extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object Is extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object IsNot extends icompop { - override def 
accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object In extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object NotIn extends icompop { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case object Eq extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object NotEq extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Lt extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object LtE extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Gt extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object GtE extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object Is extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object IsNot extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object In extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object NotIn extends icompop: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST comprehension classes /////////////////////////////////////////////////////////////////////////////////////////////////// -case class Comprehension(target: iexpr, iter: iexpr, ifs: CollType[iexpr], is_async: Boolean) extends iast { - def this(target: iexpr, iter: iexpr, ifs: util.ArrayList[iexpr], is_async: Boolean) = { - this(target, iter, ifs.asScala, is_async) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class 
Comprehension(target: iexpr, iter: iexpr, ifs: CollType[iexpr], is_async: Boolean) + extends iast: + def this(target: iexpr, iter: iexpr, ifs: util.ArrayList[iexpr], is_async: Boolean) = + this(target, iter, ifs.asScala, is_async) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST exceptHandler classes @@ -913,14 +816,16 @@ case class ExceptHandler( body: CollType[istmt], attributeProvider: AttributeProvider ) extends iast - with iattributes { - def this(typ: iexpr, name: String, body: util.ArrayList[istmt], attributeProvider: AttributeProvider) = { - this(Option(typ), Option(name), body.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} + with iattributes: + def this( + typ: iexpr, + name: String, + body: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(Option(typ), Option(name), body.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST arguments classes @@ -940,30 +845,28 @@ case class Arguments( kw_defaults: CollType[Option[iexpr]], kw_arg: Option[Arg], defaults: CollType[iexpr] -) extends iast { - def this( - posonlyargs: util.List[Arg], - args: util.List[Arg], - vararg: Arg, - kwonlyargs: util.List[Arg], - kw_defaults: util.List[iexpr], - kw_arg: Arg, - defaults: util.List[iexpr] - ) = { - this( - posonlyargs.asScala, - args.asScala, - Option(vararg), - kwonlyargs.asScala, - kw_defaults.asScala.map(Option.apply), - Option(kw_arg), - defaults.asScala - ) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends iast: + def this( + posonlyargs: util.List[Arg], + args: util.List[Arg], + vararg: Arg, + kwonlyargs: util.List[Arg], + kw_defaults: 
util.List[iexpr], + kw_arg: Arg, + defaults: util.List[iexpr] + ) = + this( + posonlyargs.asScala, + args.asScala, + Option(vararg), + kwonlyargs.asScala, + kw_defaults.asScala.map(Option.apply), + Option(kw_arg), + defaults.asScala + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +end Arguments /////////////////////////////////////////////////////////////////////////////////////////////////// // AST arg classes @@ -974,110 +877,91 @@ case class Arg( type_comment: Option[String], attributeProvider: AttributeProvider ) extends iast - with iattributes { - def this(arg: String, annotation: iexpr, type_comment: String, attributeProvider: AttributeProvider) = { - this(arg, Option(annotation), Option(type_comment), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} + with iattributes: + def this( + arg: String, + annotation: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(arg, Option(annotation), Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST keyword classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class Keyword(arg: Option[String], value: iexpr, attributeProvider: AttributeProvider) extends iast - with iattributes { - def this(arg: String, value: iexpr, attributeProvider: AttributeProvider) = { - this(Option(arg), value, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} + with iattributes: + def this(arg: String, value: iexpr, attributeProvider: AttributeProvider) = + this(Option(arg), value, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) 
/////////////////////////////////////////////////////////////////////////////////////////////////// // AST alias classes /////////////////////////////////////////////////////////////////////////////////////////////////// -case class Alias(name: String, asName: Option[String]) extends iast { - def this(name: String, asName: String) = { - this(name, Option(asName)) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class Alias(name: String, asName: Option[String]) extends iast: + def this(name: String, asName: String) = + this(name, Option(asName)) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST withitem classes /////////////////////////////////////////////////////////////////////////////////////////////////// -case class Withitem(context_expr: iexpr, optional_vars: Option[iexpr]) extends iast { - def this(context_expr: iexpr, optional_vars: iexpr) = { - this(context_expr, Option(optional_vars)) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class Withitem(context_expr: iexpr, optional_vars: Option[iexpr]) extends iast: + def this(context_expr: iexpr, optional_vars: iexpr) = + this(context_expr, Option(optional_vars)) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST match_case classes /////////////////////////////////////////////////////////////////////////////////////////////////// -case class MatchCase(pattern: ipattern, guard: Option[iexpr], body: CollType[istmt]) extends iast { - def this(pattern: ipattern, guard: iexpr, body: util.List[istmt]) = { - this(pattern, Option(guard), body.asScala) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class 
MatchCase(pattern: ipattern, guard: Option[iexpr], body: CollType[istmt]) extends iast: + def this(pattern: ipattern, guard: iexpr, body: util.List[istmt]) = + this(pattern, Option(guard), body.asScala) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST pattern classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait ipattern extends iast with iattributes -case class MatchValue(value: iexpr, attributeProvider: AttributeProvider) extends ipattern { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class MatchSingleton(value: iconstant, attributeProvider: AttributeProvider) extends ipattern { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class MatchSequence(patterns: CollType[ipattern], attributeProvider: AttributeProvider) extends ipattern { - def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = { - this(patterns.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class MatchValue(value: iexpr, attributeProvider: AttributeProvider) extends ipattern: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class MatchSingleton(value: iconstant, attributeProvider: AttributeProvider) extends ipattern: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class MatchSequence(patterns: CollType[ipattern], attributeProvider: AttributeProvider) + extends ipattern: + def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = + this(patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchMapping( keys: CollType[iexpr], patterns: CollType[ipattern], 
rest: Option[String], attributeProvider: AttributeProvider -) extends ipattern { - def this( - keys: util.List[iexpr], - patterns: util.List[ipattern], - rest: String, - attributeProvider: AttributeProvider - ) = { - this(keys.asScala, patterns.asScala, Option(rest), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends ipattern: + def this( + keys: util.List[iexpr], + patterns: util.List[ipattern], + rest: String, + attributeProvider: AttributeProvider + ) = + this(keys.asScala, patterns.asScala, Option(rest), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchClass( cls: iexpr, @@ -1085,113 +969,86 @@ case class MatchClass( kwd_attrs: CollType[String], kwd_patterns: CollType[ipattern], attributeProvider: AttributeProvider -) extends ipattern { - def this( - cls: iexpr, - patterns: util.List[ipattern], - kwd_attrs: util.List[String], - kwd_patterns: util.List[ipattern], - attributeProvider: AttributeProvider - ) = { - this(cls, patterns.asScala, kwd_attrs.asScala, kwd_patterns.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class MatchStar(name: Option[String], attributeProvider: AttributeProvider) extends ipattern { - def this(name: String, attributeProvider: AttributeProvider) = { - this(Option(name), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class MatchAs(pattern: Option[ipattern], name: Option[String], attributeProvider: AttributeProvider) - extends ipattern { - def this(pattern: ipattern, name: String, attributeProvider: AttributeProvider) = { - this(Option(pattern), Option(name), attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} - -case class MatchOr(patterns: CollType[ipattern], attributeProvider: AttributeProvider) extends 
ipattern { - def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = { - this(patterns.asScala, attributeProvider) - } - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +) extends ipattern: + def this( + cls: iexpr, + patterns: util.List[ipattern], + kwd_attrs: util.List[String], + kwd_patterns: util.List[ipattern], + attributeProvider: AttributeProvider + ) = + this(cls, patterns.asScala, kwd_attrs.asScala, kwd_patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class MatchStar(name: Option[String], attributeProvider: AttributeProvider) extends ipattern: + def this(name: String, attributeProvider: AttributeProvider) = + this(Option(name), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class MatchAs( + pattern: Option[ipattern], + name: Option[String], + attributeProvider: AttributeProvider +) extends ipattern: + def this(pattern: ipattern, name: String, attributeProvider: AttributeProvider) = + this(Option(pattern), Option(name), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) + +case class MatchOr(patterns: CollType[ipattern], attributeProvider: AttributeProvider) + extends ipattern: + def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = + this(patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST type_ignore classes /////////////////////////////////////////////////////////////////////////////////////////////////// -case class TypeIgnore(lineno: Int, tag: String) extends iast { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class TypeIgnore(lineno: Int, tag: String) extends iast: + override def 
accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST attributes classes /////////////////////////////////////////////////////////////////////////////////////////////////// -trait iattributes { - val attributeProvider: AttributeProvider - def lineno: Int = attributeProvider.lineno - def col_offset: Int = attributeProvider.col_offset - def input_offset: Int = attributeProvider.input_offset - def end_lineno: Int = attributeProvider.end_lineno - def end_col_offset: Int = attributeProvider.end_col_offset - def end_input_offset: Int = attributeProvider.end_input_offset -} +trait iattributes: + val attributeProvider: AttributeProvider + def lineno: Int = attributeProvider.lineno + def col_offset: Int = attributeProvider.col_offset + def input_offset: Int = attributeProvider.input_offset + def end_lineno: Int = attributeProvider.end_lineno + def end_col_offset: Int = attributeProvider.end_col_offset + def end_input_offset: Int = attributeProvider.end_input_offset /////////////////////////////////////////////////////////////////////////////////////////////////// // AST constant classes /////////////////////////////////////////////////////////////////////////////////////////////////// sealed trait iconstant extends iast -case class StringConstant(value: String, quote: String, prefix: String) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case class JoinedStringConstant(value: String) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case class BoolConstant(value: Boolean) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case class IntConstant(value: String) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case class FloatConstant(value: 
String) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case class ImaginaryConstant(value: String) extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object NoneConstant extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} -case object EllipsisConstant extends iconstant { - override def accept[T](visitor: AstVisitor[T]): T = { - visitor.visit(this) - } -} +case class StringConstant(value: String, quote: String, prefix: String) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case class JoinedStringConstant(value: String) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case class BoolConstant(value: Boolean) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case class IntConstant(value: String) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case class FloatConstant(value: String) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case class ImaginaryConstant(value: String) extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object NoneConstant extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) +case object EllipsisConstant extends iconstant: + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala index 4d7dda56..f8b63e19 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala +++ 
b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala @@ -2,67 +2,51 @@ package io.appthreat.pythonparser.ast import io.appthreat.pythonparser.Token -trait AttributeProvider { - def lineno: Int - def col_offset: Int - def input_offset: Int - def end_lineno: Int - def end_col_offset: Int - def end_input_offset: Int +trait AttributeProvider: + def lineno: Int + def col_offset: Int + def input_offset: Int + def end_lineno: Int + def end_col_offset: Int + def end_input_offset: Int - override def toString: String = { - s"$lineno,$col_offset,$end_lineno,$end_col_offset" - } -} + override def toString: String = + s"$lineno,$col_offset,$end_lineno,$end_col_offset" -class TokenAttributeProvider(startToken: Token, endToken: Token) extends AttributeProvider { - override def lineno: Int = { - startToken.beginLine - } +class TokenAttributeProvider(startToken: Token, endToken: Token) extends AttributeProvider: + override def lineno: Int = + startToken.beginLine - override def col_offset: Int = { - startToken.beginColumn - } + override def col_offset: Int = + startToken.beginColumn - override def input_offset: Int = { - startToken.startPos - } + override def input_offset: Int = + startToken.startPos - override def end_lineno: Int = { - endToken.endLine - } + override def end_lineno: Int = + endToken.endLine - override def end_col_offset: Int = { - endToken.endColumn - } + override def end_col_offset: Int = + endToken.endColumn - override def end_input_offset: Int = { - endToken.endPos - } -} + override def end_input_offset: Int = + endToken.endPos -class NodeAttributeProvider(astNode: iattributes, endToken: Token) extends AttributeProvider { - override def lineno: Int = { - astNode.lineno - } +class NodeAttributeProvider(astNode: iattributes, endToken: Token) extends AttributeProvider: + override def lineno: Int = + astNode.lineno - override def col_offset: Int = { - astNode.col_offset - } + override def col_offset: Int = + 
astNode.col_offset - override def input_offset: Int = { - astNode.input_offset - } + override def input_offset: Int = + astNode.input_offset - override def end_lineno: Int = { - endToken.endLine - } + override def end_lineno: Int = + endToken.endLine - override def end_col_offset: Int = { - endToken.endColumn - } + override def end_col_offset: Int = + endToken.endColumn - override def end_input_offset: Int = { - endToken.endPos - } -} + override def end_input_offset: Int = + endToken.endPos diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala index 77a61799..d2b5152d 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala @@ -2,6 +2,5 @@ package io.appthreat.pythonparser import scala.collection.mutable -package object ast { - type CollType[T] = mutable.Seq[T] -} +package object ast: + type CollType[T] = mutable.Seq[T] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala index 85dbd8e5..0e566733 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala @@ -9,79 +9,77 @@ import overflowdb.SchemaViolationException case class AstEdge(src: NewNode, dst: NewNode) -enum ValidationMode { - case Enabled, Disabled -} - -object Ast { - - private val logger = LoggerFactory.getLogger(getClass) - - def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = Ast(Vector.empty :+ node) - def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) - - /** Copy nodes/edges of given `AST` into the given `diffGraph`. 
- */ - def storeInDiffGraph(ast: Ast, diffGraph: DiffGraphBuilder): Unit = { - - setOrderWhereNotSet(ast) - - ast.nodes.foreach { node => - diffGraph.addNode(node) - } - ast.edges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.AST) - } - ast.conditionEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CONDITION) - } - ast.receiverEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.RECEIVER) - } - ast.refEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.REF) - } - - ast.argEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.ARGUMENT) - } - - ast.bindsEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) - } - } - - def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit - withSchemaValidation: ValidationMode - ): Unit = if ( - withSchemaValidation == ValidationMode.Enabled && - !(src.isValidOutNeighbor(edge, dst) && dst.isValidInNeighbor(edge, src)) - ) { - throw new SchemaViolationException( - s"Malformed AST detected: (${src.label()}) -[$edge]-> (${dst.label()}) violates the schema." - ) - } - - /** For all `order` fields that are unset, derive the `order` field automatically by determining the position of the - * child among its siblings. 
- */ - private def setOrderWhereNotSet(ast: Ast): Unit = { - ast.root.collect { case r: AstNodeNew => - if (r.order == PropertyDefaults.Order) { - r.order = 1 - } - } - val siblings = ast.edges.groupBy(_.src).map { case (_, edgeToChild) => edgeToChild.map(_.dst) } - siblings.foreach { children => - children.zipWithIndex.collect { case (c: AstNodeNew, i) => - if (c.order == PropertyDefaults.Order) { - c.order = i + 1 +enum ValidationMode: + case Enabled, Disabled + +object Ast: + + private val logger = LoggerFactory.getLogger(getClass) + + def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = + Ast(Vector.empty :+ node) + def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) + + /** Copy nodes/edges of given `AST` into the given `diffGraph`. + */ + def storeInDiffGraph(ast: Ast, diffGraph: DiffGraphBuilder): Unit = + + setOrderWhereNotSet(ast) + + ast.nodes.foreach { node => + diffGraph.addNode(node) + } + ast.edges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.AST) + } + ast.conditionEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CONDITION) + } + ast.receiverEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.RECEIVER) + } + ast.refEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.REF) } - } - } - } -} + ast.argEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.ARGUMENT) + } + + ast.bindsEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) + } + end storeInDiffGraph + + def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit + withSchemaValidation: ValidationMode + ): Unit = + if + withSchemaValidation == ValidationMode.Enabled && + !(src.isValidOutNeighbor(edge, dst) && dst.isValidInNeighbor(edge, src)) + then + throw new SchemaViolationException( + s"Malformed AST detected: (${src.label()}) -[$edge]-> (${dst.label()}) violates the schema." 
+ ) + + /** For all `order` fields that are unset, derive the `order` field automatically by determining + * the position of the child among its siblings. + */ + private def setOrderWhereNotSet(ast: Ast): Unit = + ast.root.collect { case r: AstNodeNew => + if r.order == PropertyDefaults.Order then + r.order = 1 + } + val siblings = ast.edges.groupBy(_.src).map { case (_, edgeToChild) => + edgeToChild.map(_.dst) + } + siblings.foreach { children => + children.zipWithIndex.collect { case (c: AstNodeNew, i) => + if c.order == PropertyDefaults.Order then + c.order = i + 1 + } + } +end Ast case class Ast( nodes: collection.Seq[ @@ -93,172 +91,162 @@ case class Ast( bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, argEdges: collection.Seq[AstEdge] = Vector.empty -)(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) { - - def root: Option[NewNode] = nodes.headOption - - def rightMostLeaf: Option[NewNode] = nodes.lastOption - - /** AST that results when adding `other` as a child to this AST. `other` is connected to this AST's root node. - */ - def withChild(other: Ast): Ast = { - Ast( - nodes ++ other.nodes, - edges = edges ++ other.edges ++ root.toList.flatMap(r => - other.root.toList.map { rc => - Ast.neighbourValidation(r, rc, EdgeTypes.AST) - AstEdge(r, rc) +)(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled): + + def root: Option[NewNode] = nodes.headOption + + def rightMostLeaf: Option[NewNode] = nodes.lastOption + + /** AST that results when adding `other` as a child to this AST. `other` is connected to this + * AST's root node. 
+ */ + def withChild(other: Ast): Ast = + Ast( + nodes ++ other.nodes, + edges = edges ++ other.edges ++ root.toList.flatMap(r => + other.root.toList.map { rc => + Ast.neighbourValidation(r, rc, EdgeTypes.AST) + AstEdge(r, rc) + } + ), + conditionEdges = conditionEdges ++ other.conditionEdges, + argEdges = argEdges ++ other.argEdges, + receiverEdges = receiverEdges ++ other.receiverEdges, + refEdges = refEdges ++ other.refEdges, + bindsEdges = bindsEdges ++ other.bindsEdges + ) + + def merge(other: Ast): Ast = + Ast( + nodes ++ other.nodes, + edges = edges ++ other.edges, + conditionEdges = conditionEdges ++ other.conditionEdges, + argEdges = argEdges ++ other.argEdges, + receiverEdges = receiverEdges ++ other.receiverEdges, + refEdges = refEdges ++ other.refEdges, + bindsEdges = bindsEdges ++ other.bindsEdges + ) + + /** AST that results when adding all ASTs in `asts` as children, that is, connecting them to the + * root node of this AST. + */ + def withChildren(asts: collection.Seq[Ast]): Ast = + if asts.isEmpty then + this + else + // we do this iteratively as a recursive solution which will fail with + // a StackOverflowException if there are too many elements in .tail. 
+ var ast = withChild(asts.head) + asts.tail.foreach(c => ast = ast.withChild(c)) + ast + + def withConditionEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION) + this.copy(conditionEdges = conditionEdges ++ List(AstEdge(src, dst))) + + def withRefEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.REF) + this.copy(refEdges = refEdges ++ List(AstEdge(src, dst))) + + def withBindsEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.BINDS) + this.copy(bindsEdges = bindsEdges ++ List(AstEdge(src, dst))) + + def withReceiverEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER) + this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) + + def withArgEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) + this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) + + def withArgEdges(src: NewNode, dsts: Seq[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT)) + this.copy(argEdges = argEdges ++ dsts.map(AstEdge(src, _))) + + def withArgEdges(src: NewNode, dsts: Seq[NewNode], argIndexStart: Int): Ast = + var index = argIndexStart + this.copy(argEdges = argEdges ++ dsts.map { dst => + addArgumentIndex(dst, index) + index += 1 + Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) + AstEdge(src, dst) + }) + private def addArgumentIndex(node: NewNode, argIndex: Int): Unit = node match + case n: NewBlock => n.argumentIndex = argIndex + case n: NewCall => n.argumentIndex = argIndex + case n: NewFieldIdentifier => n.argumentIndex = argIndex + case n: NewIdentifier => n.argumentIndex = argIndex + case n: NewMethodRef => n.argumentIndex = argIndex + case n: NewTypeRef => n.argumentIndex = argIndex + case n: NewUnknown => n.argumentIndex = argIndex + case n: NewControlStructure => n.argumentIndex = argIndex + case n: NewLiteral => 
n.argumentIndex = argIndex + case n: NewReturn => n.argumentIndex = argIndex + + def withConditionEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION)) + this.copy(conditionEdges = conditionEdges ++ dsts.map(AstEdge(src, _))) + + def withRefEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.REF)) + this.copy(refEdges = refEdges ++ dsts.map(AstEdge(src, _))) + + def withBindsEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.BINDS)) + this.copy(bindsEdges = bindsEdges ++ dsts.map(AstEdge(src, _))) + + def withReceiverEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER)) + this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) + + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` + * and `argumentIndex` fields of the new root node are set to `order`. If `replacementNode` is + * set, then this replaces `node` in the new copy. 
+ */ + def subTreeCopy( + node: AstNodeNew, + argIndex: Int = -1, + replacementNode: Option[AstNodeNew] = None + ): Ast = + val newNode = replacementNode match + case Some(n) => n + case None => node.copy + if argIndex != -1 then + // newNode.order = argIndex + newNode match + case expr: ExpressionNew => + expr.argumentIndex = argIndex + case _ => + + val astChildren = edges.filter(_.src == node).map(_.dst) + val newChildren = astChildren.map { x => + this.subTreeCopy(x.asInstanceOf[AstNodeNew]) } - ), - conditionEdges = conditionEdges ++ other.conditionEdges, - argEdges = argEdges ++ other.argEdges, - receiverEdges = receiverEdges ++ other.receiverEdges, - refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges - ) - } - - def merge(other: Ast): Ast = { - Ast( - nodes ++ other.nodes, - edges = edges ++ other.edges, - conditionEdges = conditionEdges ++ other.conditionEdges, - argEdges = argEdges ++ other.argEdges, - receiverEdges = receiverEdges ++ other.receiverEdges, - refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges - ) - } - - /** AST that results when adding all ASTs in `asts` as children, that is, connecting them to the root node of this - * AST. - */ - def withChildren(asts: collection.Seq[Ast]): Ast = { - if (asts.isEmpty) { - this - } else { - // we do this iteratively as a recursive solution which will fail with - // a StackOverflowException if there are too many elements in .tail. 
- var ast = withChild(asts.head) - asts.tail.foreach(c => ast = ast.withChild(c)) - ast - } - } - - def withConditionEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION) - this.copy(conditionEdges = conditionEdges ++ List(AstEdge(src, dst))) - } - - def withRefEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.REF) - this.copy(refEdges = refEdges ++ List(AstEdge(src, dst))) - } - - def withBindsEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.BINDS) - this.copy(bindsEdges = bindsEdges ++ List(AstEdge(src, dst))) - } - - def withReceiverEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER) - this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) - } - - def withArgEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) - this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) - } - - def withArgEdges(src: NewNode, dsts: Seq[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT)) - this.copy(argEdges = argEdges ++ dsts.map(AstEdge(src, _))) - } - - def withArgEdges(src: NewNode, dsts: Seq[NewNode], argIndexStart: Int): Ast = { - var index = argIndexStart - this.copy(argEdges = argEdges ++ dsts.map { dst => - addArgumentIndex(dst, index) - index += 1 - Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) - AstEdge(src, dst) - }) - } - private def addArgumentIndex(node: NewNode, argIndex: Int): Unit = node match { - case n: NewBlock => n.argumentIndex = argIndex - case n: NewCall => n.argumentIndex = argIndex - case n: NewFieldIdentifier => n.argumentIndex = argIndex - case n: NewIdentifier => n.argumentIndex = argIndex - case n: NewMethodRef => n.argumentIndex = argIndex - case n: NewTypeRef => n.argumentIndex = argIndex - case n: NewUnknown => n.argumentIndex = argIndex - case n: NewControlStructure => 
n.argumentIndex = argIndex - case n: NewLiteral => n.argumentIndex = argIndex - case n: NewReturn => n.argumentIndex = argIndex - } - - def withConditionEdges(src: NewNode, dsts: List[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION)) - this.copy(conditionEdges = conditionEdges ++ dsts.map(AstEdge(src, _))) - } - - def withRefEdges(src: NewNode, dsts: List[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.REF)) - this.copy(refEdges = refEdges ++ dsts.map(AstEdge(src, _))) - } - - def withBindsEdges(src: NewNode, dsts: List[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.BINDS)) - this.copy(bindsEdges = bindsEdges ++ dsts.map(AstEdge(src, _))) - } - - def withReceiverEdges(src: NewNode, dsts: List[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER)) - this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) - } - - /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` - * fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new - * copy. 
- */ - def subTreeCopy(node: AstNodeNew, argIndex: Int = -1, replacementNode: Option[AstNodeNew] = None): Ast = { - val newNode = replacementNode match - case Some(n) => n - case None => node.copy - if (argIndex != -1) { - // newNode.order = argIndex - newNode match { - case expr: ExpressionNew => - expr.argumentIndex = argIndex - case _ => - } - } - - val astChildren = edges.filter(_.src == node).map(_.dst) - val newChildren = astChildren.map { x => - this.subTreeCopy(x.asInstanceOf[AstNodeNew]) - } - - val oldToNew = astChildren.zip(newChildren).map { case (old, n) => old -> n.root.get }.toMap - def newIfExists(x: NewNode) = { - oldToNew.getOrElse(x, x) - } - - val newArgEdges = argEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newConditionEdges = conditionEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - - Ast(newNode) - .copy( - argEdges = newArgEdges, - conditionEdges = newConditionEdges, - refEdges = newRefEdges, - bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges - ) - .withChildren(newChildren) - } - -} + + val oldToNew = astChildren.zip(newChildren).map { case (old, n) => old -> n.root.get }.toMap + def newIfExists(x: NewNode) = + oldToNew.getOrElse(x, x) + + val newArgEdges = + argEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newConditionEdges = + conditionEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newRefEdges = + refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newBindsEdges = + bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val 
newReceiverEdges = + receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + + Ast(newNode) + .copy( + argEdges = newArgEdges, + conditionEdges = newConditionEdges, + refEdges = newRefEdges, + bindsEdges = newBindsEdges, + receiverEdges = newReceiverEdges + ) + .withChildren(newChildren) + end subTreeCopy +end Ast diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala index 2cc93edd..e26454c5 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala @@ -2,333 +2,328 @@ package io.appthreat.x2cpg import io.appthreat.x2cpg.passes.frontend.MetaDataPass import io.appthreat.x2cpg.utils.NodeBuilders.newMethodReturnNode -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, ModifierTypes} import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import overflowdb.BatchedUpdate.DiffGraphBuilder -abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: ValidationMode) { - val diffGraph: DiffGraphBuilder = new DiffGraphBuilder - - def createAst(): DiffGraphBuilder - - /** Create a global namespace block for the given `filename` - */ - def globalNamespaceBlock(): NewNamespaceBlock = { - val name = NamespaceTraversal.globalNamespaceName - val fullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) - NewNamespaceBlock() - .name(name) - .fullName(fullName) - .filename(filename) - .order(1) - } - - /** Creates an AST that represents an annotation, including its content (annotation parameter assignments). 
- */ - def annotationAst(annotation: NewAnnotation, children: Seq[Ast]): Ast = { - val annotationAst = Ast(annotation) - annotationAst.withChildren(children) - } - - /** Creates an AST that represents an annotation assignment with a name for the assigned value, its overall code, and - * the respective assignment AST. - */ - def annotationAssignmentAst(assignmentValueName: String, code: String, assignmentAst: Ast): Ast = { - val parameter = NewAnnotationParameter().code(assignmentValueName) - val assign = NewAnnotationParameterAssign().code(code) - val assignChildren = List(Ast(parameter), assignmentAst) - setArgumentIndices(assignChildren) - Ast(assign) - .withChild(Ast(parameter)) - .withChild(assignmentAst) - } - - /** Creates an AST that represents an entire method, including its content. - */ - def methodAst( - method: NewMethod, - parameters: Seq[Ast], - body: Ast, - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil - ): Ast = - methodAstWithAnnotations(method, parameters, body, methodReturn, modifiers, annotations = Nil) - - /** Creates an AST that represents an entire method, including its content and with support for both method and - * parameter annotations. - */ - def methodAstWithAnnotations( - method: NewMethod, - parameters: Seq[Ast], - body: Ast, - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil, - annotations: Seq[Ast] = Nil - ): Ast = - Ast(method) - .withChildren(parameters) - .withChild(body) - .withChildren(modifiers.map(Ast(_))) - .withChildren(annotations) - .withChild(Ast(methodReturn)) - - /** Creates an AST that represents a method stub, containing information about the method, its parameters, and the - * return type. 
- */ - def methodStubAst( - method: NewMethod, - parameters: Seq[NewMethodParameterIn], - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil - ): Ast = - Ast(method) - .withChildren(parameters.map(Ast(_))) - .withChild(Ast(NewBlock())) - .withChildren(modifiers.map(Ast(_))) - .withChild(Ast(methodReturn)) - - def staticInitMethodAst( - initAsts: List[Ast], - fullName: String, - signature: Option[String], - returnType: String, - fileName: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = { - val methodNode = NewMethod() - .name(Defines.StaticInitMethodName) - .fullName(fullName) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if (signature.isDefined) { - methodNode.signature(signature.get) - } - if (fileName.isDefined) { - methodNode.filename(fileName.get) - } - val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC) - val body = blockAst(NewBlock(), initAsts) - val methodReturn = newMethodReturnNode(returnType, None, None, None) - methodAst(methodNode, Nil, body, methodReturn, List(staticModifier)) - } - - /** For a given return node and arguments, create an AST that represents the return instruction. The main purpose of - * this method is to automatically assign the correct argument indices. - */ - def returnAst(returnNode: NewReturn, arguments: Seq[Ast] = List()): Ast = { - setArgumentIndices(arguments) - Ast(returnNode) - .withChildren(arguments) - .withArgEdges(returnNode, arguments.flatMap(_.root)) - } - - /** For a given node, condition AST and children ASTs, create an AST that represents the control structure. The main - * purpose of this method is to automatically assign the correct condition edges. 
- */ - def controlStructureAst( - controlStructureNode: NewControlStructure, - condition: Option[Ast], - children: Seq[Ast] = Seq(), - placeConditionLast: Boolean = false - ): Ast = { - condition match { - case Some(conditionAst) => - Ast(controlStructureNode) - .withChildren(if (placeConditionLast) children :+ conditionAst else conditionAst +: children) - .withConditionEdges(controlStructureNode, List(conditionAst.root).flatten) - case _ => - Ast(controlStructureNode) - .withChildren(children) - } - } - - def wrapMultipleInBlock(asts: Seq[Ast], lineNumber: Option[Integer]): Ast = { - asts.toList match { - case Nil => blockAst(NewBlock().lineNumber(lineNumber)) - case ast :: Nil => ast - case astList => blockAst(NewBlock().lineNumber(lineNumber), astList) - } - } - - def whileAst( - condition: Option[Ast], - body: Seq[Ast], - code: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = { - var whileNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.WHILE) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if (code.isDefined) { - whileNode = whileNode.code(code.get) - } - controlStructureAst(whileNode, condition, body) - } - - def doWhileAst( - condition: Option[Ast], - body: Seq[Ast], - code: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = { - var doWhileNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.DO) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if (code.isDefined) { - doWhileNode = doWhileNode.code(code.get) - } - controlStructureAst(doWhileNode, condition, body, placeConditionLast = true) - } - - def forAst( - forNode: NewControlStructure, - locals: Seq[Ast], - initAsts: Seq[Ast], - conditionAsts: Seq[Ast], - updateAsts: Seq[Ast], - bodyAst: Ast - ): Ast = - forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst)) - - def forAst( - forNode: 
NewControlStructure, - locals: Seq[Ast], - initAsts: Seq[Ast], - conditionAsts: Seq[Ast], - updateAsts: Seq[Ast], - bodyAsts: Seq[Ast] - ): Ast = { - val lineNumber = forNode.lineNumber - Ast(forNode) - .withChildren(locals) - .withChild(wrapMultipleInBlock(initAsts, lineNumber)) - .withChild(wrapMultipleInBlock(conditionAsts, lineNumber)) - .withChild(wrapMultipleInBlock(updateAsts, lineNumber)) - .withChildren(bodyAsts) - .withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList) - } - - /** For the given try body, catch ASTs and finally AST, create a try-catch-finally AST with orders set correctly for - * the ossdataflow engine. - */ - def tryCatchAst(tryNode: NewControlStructure, tryBodyAst: Ast, catchAsts: Seq[Ast], finallyAst: Option[Ast]): Ast = { - tryBodyAst.root.collect { case x: ExpressionNew => x }.foreach(_.order = 1) - catchAsts.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 2) - finallyAst.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 3) - Ast(tryNode) - .withChild(tryBodyAst) - .withChildren(catchAsts) - .withChildren(finallyAst.toList) - } - - /** For a given block node and statement ASTs, create an AST that represents the block. The main purpose of this - * method is to increase the readability of the code which creates block asts. - */ - def blockAst(blockNode: NewBlock, statements: List[Ast] = List()): Ast = { - Ast(blockNode).withChildren(statements) - } - - /** Create an abstract syntax tree for a call, including CPG-specific edges required for arguments and the receiver. - * - * Our call representation is inspired by ECMAScript, that is, in addition to arguments, a call has a base and a - * receiver. For languages other than Javascript, leave `receiver` empty for now. - * - * @param callNode - * the node that represents the entire call - * @param arguments - * arguments (without the base argument (instance)) - * @param base - * the value to use as `this` in the method call. 
- * @param receiver - * the object in which the property lookup is performed - */ - def callAst( - callNode: NewCall, - arguments: Seq[Ast] = List(), - base: Option[Ast] = None, - receiver: Option[Ast] = None - ): Ast = { - - setArgumentIndices(arguments) - - val baseRoot = base.flatMap(_.root).toList - val bse = base.getOrElse(Ast()) - baseRoot match { - case List(x: ExpressionNew) => - x.argumentIndex = 0 - case _ => - } - - val receiverRoot = if (receiver.isEmpty && base.nonEmpty) { - baseRoot - } else { - val r = receiver.flatMap(_.root).toList - r match { - case List(x: ExpressionNew) => - x.argumentIndex = -1 - case _ => - } - r - } - - val rcvAst = receiver.getOrElse(Ast()) - - Ast(callNode) - .withChild(rcvAst) - .withChild(bse) - .withChildren(arguments) - .withArgEdges(callNode, baseRoot) - .withArgEdges(callNode, arguments.flatMap(_.root)) - .withReceiverEdges(callNode, receiverRoot) - } - - def setArgumentIndices(arguments: Seq[Ast]): Unit = { - var currIndex = 1 - arguments.foreach { a => - a.root match { - case Some(x: ExpressionNew) => - x.argumentIndex = currIndex - currIndex = currIndex + 1 - case None => // do nothing - case _ => - currIndex = currIndex + 1 - } - } - } - - def withIndex[T, X](nodes: Seq[T])(f: (T, Int) => X): Seq[X] = - nodes.zipWithIndex.map { case (x, i) => - f(x, i + 1) - } - - def withIndex[T, X](nodes: Array[T])(f: (T, Int) => X): Seq[X] = - nodes.toIndexedSeq.zipWithIndex.map { case (x, i) => - f(x, i + 1) - } - - def withArgumentIndex[T <: ExpressionNew](node: T, argIdxOpt: Option[Int]): T = { - argIdxOpt match { - case Some(argIdx) => - node.argumentIndex = argIdx +abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: ValidationMode): + val diffGraph: DiffGraphBuilder = new DiffGraphBuilder + + def createAst(): DiffGraphBuilder + + /** Create a global namespace block for the given `filename` + */ + def globalNamespaceBlock(): NewNamespaceBlock = + val name = NamespaceTraversal.globalNamespaceName + 
val fullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) + NewNamespaceBlock() + .name(name) + .fullName(fullName) + .filename(filename) + .order(1) + + /** Creates an AST that represents an annotation, including its content (annotation parameter + * assignments). + */ + def annotationAst(annotation: NewAnnotation, children: Seq[Ast]): Ast = + val annotationAst = Ast(annotation) + annotationAst.withChildren(children) + + /** Creates an AST that represents an annotation assignment with a name for the assigned value, + * its overall code, and the respective assignment AST. + */ + def annotationAssignmentAst( + assignmentValueName: String, + code: String, + assignmentAst: Ast + ): Ast = + val parameter = NewAnnotationParameter().code(assignmentValueName) + val assign = NewAnnotationParameterAssign().code(code) + val assignChildren = List(Ast(parameter), assignmentAst) + setArgumentIndices(assignChildren) + Ast(assign) + .withChild(Ast(parameter)) + .withChild(assignmentAst) + + /** Creates an AST that represents an entire method, including its content. + */ + def methodAst( + method: NewMethod, + parameters: Seq[Ast], + body: Ast, + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil + ): Ast = + methodAstWithAnnotations( + method, + parameters, + body, + methodReturn, + modifiers, + annotations = Nil + ) + + /** Creates an AST that represents an entire method, including its content and with support for + * both method and parameter annotations. 
+ */ + def methodAstWithAnnotations( + method: NewMethod, + parameters: Seq[Ast], + body: Ast, + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil, + annotations: Seq[Ast] = Nil + ): Ast = + Ast(method) + .withChildren(parameters) + .withChild(body) + .withChildren(modifiers.map(Ast(_))) + .withChildren(annotations) + .withChild(Ast(methodReturn)) + + /** Creates an AST that represents a method stub, containing information about the method, its + * parameters, and the return type. + */ + def methodStubAst( + method: NewMethod, + parameters: Seq[NewMethodParameterIn], + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil + ): Ast = + Ast(method) + .withChildren(parameters.map(Ast(_))) + .withChild(Ast(NewBlock())) + .withChildren(modifiers.map(Ast(_))) + .withChild(Ast(methodReturn)) + + def staticInitMethodAst( + initAsts: List[Ast], + fullName: String, + signature: Option[String], + returnType: String, + fileName: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + val methodNode = NewMethod() + .name(Defines.StaticInitMethodName) + .fullName(fullName) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if signature.isDefined then + methodNode.signature(signature.get) + if fileName.isDefined then + methodNode.filename(fileName.get) + val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC) + val body = blockAst(NewBlock(), initAsts) + val methodReturn = newMethodReturnNode(returnType, None, None, None) + methodAst(methodNode, Nil, body, methodReturn, List(staticModifier)) + end staticInitMethodAst + + /** For a given return node and arguments, create an AST that represents the return instruction. + * The main purpose of this method is to automatically assign the correct argument indices. 
+ */ + def returnAst(returnNode: NewReturn, arguments: Seq[Ast] = List()): Ast = + setArgumentIndices(arguments) + Ast(returnNode) + .withChildren(arguments) + .withArgEdges(returnNode, arguments.flatMap(_.root)) + + /** For a given node, condition AST and children ASTs, create an AST that represents the control + * structure. The main purpose of this method is to automatically assign the correct condition + * edges. + */ + def controlStructureAst( + controlStructureNode: NewControlStructure, + condition: Option[Ast], + children: Seq[Ast] = Seq(), + placeConditionLast: Boolean = false + ): Ast = + condition match + case Some(conditionAst) => + Ast(controlStructureNode) + .withChildren(if placeConditionLast then children :+ conditionAst + else conditionAst +: children) + .withConditionEdges(controlStructureNode, List(conditionAst.root).flatten) + case _ => + Ast(controlStructureNode) + .withChildren(children) + + def wrapMultipleInBlock(asts: Seq[Ast], lineNumber: Option[Integer]): Ast = + asts.toList match + case Nil => blockAst(NewBlock().lineNumber(lineNumber)) + case ast :: Nil => ast + case astList => blockAst(NewBlock().lineNumber(lineNumber), astList) + + def whileAst( + condition: Option[Ast], + body: Seq[Ast], + code: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + var whileNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.WHILE) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if code.isDefined then + whileNode = whileNode.code(code.get) + controlStructureAst(whileNode, condition, body) + + def doWhileAst( + condition: Option[Ast], + body: Seq[Ast], + code: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + var doWhileNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.DO) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if code.isDefined then + doWhileNode = 
doWhileNode.code(code.get) + controlStructureAst(doWhileNode, condition, body, placeConditionLast = true) + + def forAst( + forNode: NewControlStructure, + locals: Seq[Ast], + initAsts: Seq[Ast], + conditionAsts: Seq[Ast], + updateAsts: Seq[Ast], + bodyAst: Ast + ): Ast = + forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst)) + + def forAst( + forNode: NewControlStructure, + locals: Seq[Ast], + initAsts: Seq[Ast], + conditionAsts: Seq[Ast], + updateAsts: Seq[Ast], + bodyAsts: Seq[Ast] + ): Ast = + val lineNumber = forNode.lineNumber + Ast(forNode) + .withChildren(locals) + .withChild(wrapMultipleInBlock(initAsts, lineNumber)) + .withChild(wrapMultipleInBlock(conditionAsts, lineNumber)) + .withChild(wrapMultipleInBlock(updateAsts, lineNumber)) + .withChildren(bodyAsts) + .withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList) + + /** For the given try body, catch ASTs and finally AST, create a try-catch-finally AST with + * orders set correctly for the ossdataflow engine. + */ + def tryCatchAst( + tryNode: NewControlStructure, + tryBodyAst: Ast, + catchAsts: Seq[Ast], + finallyAst: Option[Ast] + ): Ast = + tryBodyAst.root.collect { case x: ExpressionNew => x }.foreach(_.order = 1) + catchAsts.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 2) + finallyAst.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 3) + Ast(tryNode) + .withChild(tryBodyAst) + .withChildren(catchAsts) + .withChildren(finallyAst.toList) + + /** For a given block node and statement ASTs, create an AST that represents the block. The main + * purpose of this method is to increase the readability of the code which creates block asts. + */ + def blockAst(blockNode: NewBlock, statements: List[Ast] = List()): Ast = + Ast(blockNode).withChildren(statements) + + /** Create an abstract syntax tree for a call, including CPG-specific edges required for + * arguments and the receiver. 
+ * + * Our call representation is inspired by ECMAScript, that is, in addition to arguments, a call + * has a base and a receiver. For languages other than Javascript, leave `receiver` empty for + * now. + * + * @param callNode + * the node that represents the entire call + * @param arguments + * arguments (without the base argument (instance)) + * @param base + * the value to use as `this` in the method call. + * @param receiver + * the object in which the property lookup is performed + */ + def callAst( + callNode: NewCall, + arguments: Seq[Ast] = List(), + base: Option[Ast] = None, + receiver: Option[Ast] = None + ): Ast = + + setArgumentIndices(arguments) + + val baseRoot = base.flatMap(_.root).toList + val bse = base.getOrElse(Ast()) + baseRoot match + case List(x: ExpressionNew) => + x.argumentIndex = 0 + case _ => + + val receiverRoot = if receiver.isEmpty && base.nonEmpty then + baseRoot + else + val r = receiver.flatMap(_.root).toList + r match + case List(x: ExpressionNew) => + x.argumentIndex = -1 + case _ => + r + + val rcvAst = receiver.getOrElse(Ast()) + + Ast(callNode) + .withChild(rcvAst) + .withChild(bse) + .withChildren(arguments) + .withArgEdges(callNode, baseRoot) + .withArgEdges(callNode, arguments.flatMap(_.root)) + .withReceiverEdges(callNode, receiverRoot) + end callAst + + def setArgumentIndices(arguments: Seq[Ast]): Unit = + var currIndex = 1 + arguments.foreach { a => + a.root match + case Some(x: ExpressionNew) => + x.argumentIndex = currIndex + currIndex = currIndex + 1 + case None => // do nothing + case _ => + currIndex = currIndex + 1 + } + + def withIndex[T, X](nodes: Seq[T])(f: (T, Int) => X): Seq[X] = + nodes.zipWithIndex.map { case (x, i) => + f(x, i + 1) + } + + def withIndex[T, X](nodes: Array[T])(f: (T, Int) => X): Seq[X] = + nodes.toIndexedSeq.zipWithIndex.map { case (x, i) => + f(x, i + 1) + } + + def withArgumentIndex[T <: ExpressionNew](node: T, argIdxOpt: Option[Int]): T = + argIdxOpt match + case Some(argIdx) => + 
node.argumentIndex = argIdx + node + case None => node + + def withArgumentName[T <: ExpressionNew](node: T, argNameOpt: Option[String]): T = + node.argumentName = argNameOpt node - case None => node - } - } - - def withArgumentName[T <: ExpressionNew](node: T, argNameOpt: Option[String]): T = { - node.argumentName = argNameOpt - node - } - - /** Absolute path for the given file name - */ - def absolutePath(filename: String): String = - better.files.File(filename).path.toAbsolutePath.normalize().toString -} + /** Absolute path for the given file name + */ + def absolutePath(filename: String): String = + better.files.File(filename).path.toAbsolutePath.normalize().toString +end AstCreatorBase diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala index 39ba4c45..f81ba4f0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala @@ -1,316 +1,335 @@ package io.appthreat.x2cpg import io.shiftleft.codepropertygraph.generated.nodes.{ - NewAnnotation, - NewBlock, - NewCall, - NewControlStructure, - NewFieldIdentifier, - NewIdentifier, - NewImport, - NewJumpTarget, - NewLiteral, - NewLocal, - NewMember, - NewMethod, - NewMethodParameterIn, - NewMethodRef, - NewMethodReturn, - NewReturn, - NewTypeDecl, - NewTypeRef, - NewUnknown + NewAnnotation, + NewBlock, + NewCall, + NewControlStructure, + NewFieldIdentifier, + NewIdentifier, + NewImport, + NewJumpTarget, + NewLiteral, + NewLocal, + NewMember, + NewMethod, + NewMethodParameterIn, + NewMethodRef, + NewMethodReturn, + NewReturn, + NewTypeDecl, + NewTypeRef, + NewUnknown } -import io.shiftleft.codepropertygraph.generated.nodes.Block.{PropertyDefaults => BlockDefaults} +import io.shiftleft.codepropertygraph.generated.nodes.Block.{PropertyDefaults as BlockDefaults} import 
org.apache.commons.lang.StringUtils import io.appthreat.x2cpg.utils.NodeBuilders.newMethodReturnNode -trait AstNodeBuilder[Node, NodeProcessor] { this: NodeProcessor => - protected def line(node: Node): Option[Integer] - protected def column(node: Node): Option[Integer] - protected def lineEnd(node: Node): Option[Integer] - protected def columnEnd(element: Node): Option[Integer] +trait AstNodeBuilder[Node, NodeProcessor]: + this: NodeProcessor => + protected def line(node: Node): Option[Integer] + protected def column(node: Node): Option[Integer] + protected def lineEnd(node: Node): Option[Integer] + protected def columnEnd(element: Node): Option[Integer] - protected def unknownNode(node: Node, code: String): NewUnknown = { - NewUnknown() - .parserTypeName(node.getClass.getSimpleName) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def unknownNode(node: Node, code: String): NewUnknown = + NewUnknown() + .parserTypeName(node.getClass.getSimpleName) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def annotationNode(node: Node, code: String, name: String, fullName: String): NewAnnotation = { - NewAnnotation() - .code(code) - .name(name) - .fullName(fullName) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def annotationNode( + node: Node, + code: String, + name: String, + fullName: String + ): NewAnnotation = + NewAnnotation() + .code(code) + .name(name) + .fullName(fullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def methodRefNode(node: Node, code: String, methodFullName: String, typeFullName: String): NewMethodRef = { - NewMethodRef() - .code(code) - .methodFullName(methodFullName) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def methodRefNode( + node: Node, + code: String, + methodFullName: String, + typeFullName: String + ): NewMethodRef = + NewMethodRef() + .code(code) + 
.methodFullName(methodFullName) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def memberNode(node: Node, name: String, code: String, typeFullName: String): NewMember = - memberNode(node, name, code, typeFullName, Seq()) + protected def memberNode( + node: Node, + name: String, + code: String, + typeFullName: String + ): NewMember = + memberNode(node, name, code, typeFullName, Seq()) - protected def memberNode( - node: Node, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewMember = { - NewMember() - .code(code) - .name(name) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def memberNode( + node: Node, + name: String, + code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewMember = + NewMember() + .code(code) + .name(name) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def newImportNode(code: String, importedEntity: String, importedAs: String, include: Node): NewImport = { - NewImport() - .code(code) - .importedEntity(importedEntity) - .importedAs(importedAs) - .lineNumber(line(include)) - .columnNumber(column(include)) - } + protected def newImportNode( + code: String, + importedEntity: String, + importedAs: String, + include: Node + ): NewImport = + NewImport() + .code(code) + .importedEntity(importedEntity) + .importedAs(importedAs) + .lineNumber(line(include)) + .columnNumber(column(include)) - protected def literalNode( - node: Node, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewLiteral = { - NewLiteral() - .code(code) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def literalNode( + node: Node, 
+ code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewLiteral = + NewLiteral() + .code(code) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def typeRefNode(node: Node, code: String, typeFullName: String): NewTypeRef = { - NewTypeRef() - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def typeRefNode(node: Node, code: String, typeFullName: String): NewTypeRef = + NewTypeRef() + .code(code) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - def typeDeclNode( - node: Node, - name: String, - fullName: String, - fileName: String, - inheritsFrom: Seq[String], - alias: Option[String] - ): NewTypeDecl = - typeDeclNode(node, name, fullName, fileName, name, "", "", inheritsFrom, alias) + def typeDeclNode( + node: Node, + name: String, + fullName: String, + fileName: String, + inheritsFrom: Seq[String], + alias: Option[String] + ): NewTypeDecl = + typeDeclNode(node, name, fullName, fileName, name, "", "", inheritsFrom, alias) - protected def typeDeclNode( - node: Node, - name: String, - fullName: String, - filename: String, - code: String, - astParentType: String = "", - astParentFullName: String = "", - inherits: Seq[String] = Seq.empty, - alias: Option[String] = None - ): NewTypeDecl = { - NewTypeDecl() - .name(name) - .fullName(fullName) - .code(code) - .isExternal(false) - .filename(filename) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .inheritsFromTypeFullName(inherits) - .aliasTypeFullName(alias) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def typeDeclNode( + node: Node, + name: String, + fullName: String, + filename: String, + code: String, + astParentType: String = "", + astParentFullName: String = "", + inherits: Seq[String] = Seq.empty, + alias: Option[String] = None + ): 
NewTypeDecl = + NewTypeDecl() + .name(name) + .fullName(fullName) + .code(code) + .isExternal(false) + .filename(filename) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .inheritsFromTypeFullName(inherits) + .aliasTypeFullName(alias) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def parameterInNode( - node: Node, - name: String, - code: String, - index: Int, - isVariadic: Boolean, - evaluationStrategy: String, - typeFullName: String - ): NewMethodParameterIn = - parameterInNode(node, name, code, index, isVariadic, evaluationStrategy, Some(typeFullName)) + protected def parameterInNode( + node: Node, + name: String, + code: String, + index: Int, + isVariadic: Boolean, + evaluationStrategy: String, + typeFullName: String + ): NewMethodParameterIn = + parameterInNode(node, name, code, index, isVariadic, evaluationStrategy, Some(typeFullName)) - protected def parameterInNode( - node: Node, - name: String, - code: String, - index: Int, - isVariadic: Boolean, - evaluationStrategy: String, - typeFullName: Option[String] = None - ): NewMethodParameterIn = { - NewMethodParameterIn() - .name(name) - .code(code) - .index(index) - .order(index) - .isVariadic(isVariadic) - .evaluationStrategy(evaluationStrategy) - .lineNumber(line(node)) - .columnNumber(column(node)) - .typeFullName(typeFullName.getOrElse("ANY")) - } + protected def parameterInNode( + node: Node, + name: String, + code: String, + index: Int, + isVariadic: Boolean, + evaluationStrategy: String, + typeFullName: Option[String] = None + ): NewMethodParameterIn = + NewMethodParameterIn() + .name(name) + .code(code) + .index(index) + .order(index) + .isVariadic(isVariadic) + .evaluationStrategy(evaluationStrategy) + .lineNumber(line(node)) + .columnNumber(column(node)) + .typeFullName(typeFullName.getOrElse("ANY")) - def callNode(node: Node, code: String, name: String, methodFullName: String, dispatchType: String): NewCall = - callNode(node, code, name, methodFullName, 
dispatchType, None, None) + def callNode( + node: Node, + code: String, + name: String, + methodFullName: String, + dispatchType: String + ): NewCall = + callNode(node, code, name, methodFullName, dispatchType, None, None) - def callNode( - node: Node, - code: String, - name: String, - methodFullName: String, - dispatchType: String, - signature: Option[String], - typeFullName: Option[String] - ): NewCall = { - val out = - NewCall() - .code(code) - .name(name) - .methodFullName(methodFullName) - .dispatchType(dispatchType) - .lineNumber(line(node)) - .columnNumber(column(node)) - signature.foreach { s => out.signature(s) } - typeFullName.foreach { t => out.typeFullName(t) } - out - } + def callNode( + node: Node, + code: String, + name: String, + methodFullName: String, + dispatchType: String, + signature: Option[String], + typeFullName: Option[String] + ): NewCall = + val out = + NewCall() + .code(code) + .name(name) + .methodFullName(methodFullName) + .dispatchType(dispatchType) + .lineNumber(line(node)) + .columnNumber(column(node)) + signature.foreach { s => out.signature(s) } + typeFullName.foreach { t => out.typeFullName(t) } + out + end callNode - protected def returnNode(node: Node, code: String): NewReturn = { - NewReturn() - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def returnNode(node: Node, code: String): NewReturn = + NewReturn() + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def controlStructureNode(node: Node, controlStructureType: String, code: String): NewControlStructure = { - NewControlStructure() - .parserTypeName(node.getClass.getSimpleName) - .controlStructureType(controlStructureType) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def controlStructureNode( + node: Node, + controlStructureType: String, + code: String + ): NewControlStructure = + NewControlStructure() + .parserTypeName(node.getClass.getSimpleName) + 
.controlStructureType(controlStructureType) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def blockNode(node: Node): NewBlock = { - blockNode(node, BlockDefaults.Code, BlockDefaults.TypeFullName) - } + protected def blockNode(node: Node): NewBlock = + blockNode(node, BlockDefaults.Code, BlockDefaults.TypeFullName) - protected def blockNode(node: Node, code: String, typeFullName: String): NewBlock = { - NewBlock() - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def blockNode(node: Node, code: String, typeFullName: String): NewBlock = + NewBlock() + .code(code) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def fieldIdentifierNode(node: Node, name: String, code: String): NewFieldIdentifier = { - NewFieldIdentifier() - .canonicalName(name) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def fieldIdentifierNode(node: Node, name: String, code: String): NewFieldIdentifier = + NewFieldIdentifier() + .canonicalName(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def localNode( - node: Node, - name: String, - code: String, - typeFullName: String, - closureBindingId: Option[String] = None - ): NewLocal = - NewLocal() - .name(name) - .code(code) - .typeFullName(typeFullName) - .closureBindingId(closureBindingId) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def localNode( + node: Node, + name: String, + code: String, + typeFullName: String, + closureBindingId: Option[String] = None + ): NewLocal = + NewLocal() + .name(name) + .code(code) + .typeFullName(typeFullName) + .closureBindingId(closureBindingId) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def identifierNode( - node: Node, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewIdentifier = { - 
NewIdentifier() - .name(name) - .typeFullName(typeFullName) - .code(code) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) - } + protected def identifierNode( + node: Node, + name: String, + code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewIdentifier = + NewIdentifier() + .name(name) + .typeFullName(typeFullName) + .code(code) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - def methodNode(node: Node, name: String, fullName: String, signature: String, fileName: String): NewMethod = { - methodNode(node, name, name, fullName, Some(signature), fileName) - } + def methodNode( + node: Node, + name: String, + fullName: String, + signature: String, + fileName: String + ): NewMethod = + methodNode(node, name, name, fullName, Some(signature), fileName) - protected def methodNode( - node: Node, - name: String, - code: String, - fullName: String, - signature: Option[String], - fileName: String, - astParentType: Option[String] = None, - astParentFullName: Option[String] = None - ): NewMethod = { - val node_ = - NewMethod() - .name(StringUtils.normalizeSpace(name)) - .code(code) - .fullName(StringUtils.normalizeSpace(fullName)) - .filename(fileName) - .astParentType(astParentType.getOrElse("")) - .astParentFullName(astParentFullName.getOrElse("")) - .isExternal(false) - .lineNumber(line(node)) - .columnNumber(column(node)) - .lineNumberEnd(lineEnd(node)) - .columnNumberEnd(columnEnd(node)) - signature.foreach { s => node_.signature(StringUtils.normalizeSpace(s)) } - node_ - } + protected def methodNode( + node: Node, + name: String, + code: String, + fullName: String, + signature: Option[String], + fileName: String, + astParentType: Option[String] = None, + astParentFullName: Option[String] = None + ): NewMethod = + val node_ = + NewMethod() + .name(StringUtils.normalizeSpace(name)) + .code(code) + 
.fullName(StringUtils.normalizeSpace(fullName)) + .filename(fileName) + .astParentType(astParentType.getOrElse("")) + .astParentFullName(astParentFullName.getOrElse("")) + .isExternal(false) + .lineNumber(line(node)) + .columnNumber(column(node)) + .lineNumberEnd(lineEnd(node)) + .columnNumberEnd(columnEnd(node)) + signature.foreach { s => node_.signature(StringUtils.normalizeSpace(s)) } + node_ + end methodNode - protected def methodReturnNode(node: Node, typeFullName: String): NewMethodReturn = { - newMethodReturnNode(typeFullName, None, line(node), column(node)) - } + protected def methodReturnNode(node: Node, typeFullName: String): NewMethodReturn = + newMethodReturnNode(typeFullName, None, line(node), column(node)) - protected def jumpTargetNode( - node: Node, - name: String, - code: String, - parserTypeName: Option[String] = None - ): NewJumpTarget = { - NewJumpTarget() - .parserTypeName(parserTypeName.getOrElse(node.getClass.getSimpleName)) - .name(name) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - } -} + protected def jumpTargetNode( + node: Node, + name: String, + code: String, + parserTypeName: Option[String] = None + ): NewJumpTarget = + NewJumpTarget() + .parserTypeName(parserTypeName.getOrElse(node.getClass.getSimpleName)) + .name(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) +end AstNodeBuilder diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala index eb291522..e628afc5 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala @@ -1,31 +1,31 @@ package io.appthreat.x2cpg -object Defines { - // The following two defines should be used for type and method full names to - // indicate unresolved static type information. 
Using them enables - // the closed source backend to apply policies in a less strict fashion. - // The most notable case is the METHOD_FULL_NAME property of a CALL node. - // As example consider a call to a method `foo(someArg)` which cannot be - // resolved. The METHOD_FULL_NAME should be given as - // ".foo:(1)". If the namespace is known - // the METHOD_FULL_NAME should be given as - // "some.namespace.foo:(1)". Thereby the number in parenthesis - // is the number of call arguments. - // Note that this schema and thus the defines only makes sense for statically - // typed languages with a package/namespace structure like Java, CSharp, etc.. - val UnresolvedNamespace = "" - val UnresolvedSignature = "" +object Defines: + // The following two defines should be used for type and method full names to + // indicate unresolved static type information. Using them enables + // the closed source backend to apply policies in a less strict fashion. + // The most notable case is the METHOD_FULL_NAME property of a CALL node. + // As example consider a call to a method `foo(someArg)` which cannot be + // resolved. The METHOD_FULL_NAME should be given as + // ".foo:(1)". If the namespace is known + // the METHOD_FULL_NAME should be given as + // "some.namespace.foo:(1)". Thereby the number in parenthesis + // is the number of call arguments. + // Note that this schema and thus the defines only makes sense for statically + // typed languages with a package/namespace structure like Java, CSharp, etc.. + val UnresolvedNamespace = "" + val UnresolvedSignature = "" - // Name of the synthetic, static method that contains the initialization of member variables. - val StaticInitMethodName = "" + // Name of the synthetic, static method that contains the initialization of member variables. + val StaticInitMethodName = "" - // Name of the constructor. - val ConstructorMethodName = "" + // Name of the constructor. 
+ val ConstructorMethodName = "" - // In some languages like Javascript dynamic calls do not provide any statically known - // method/function interface information. In those cases please use this value. - val DynamicCallUnknownFullName = "" + // In some languages like Javascript dynamic calls do not provide any statically known + // method/function interface information. In those cases please use this value. + val DynamicCallUnknownFullName = "" - val LeftAngularBracket = "<" - val Unknown = "" -} + val LeftAngularBracket = "<" + val Unknown = "" +end Defines diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala index 33ea3b5e..de16e86c 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala @@ -4,20 +4,17 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{CallBase, NewImport} import overflowdb.BatchedUpdate.DiffGraphBuilder -object Imports { +object Imports: - def createImportNodeAndLink( - importedEntity: String, - importedAs: String, - call: Option[CallBase], - diffGraph: DiffGraphBuilder - ): NewImport = { - val importNode = NewImport() - .importedAs(importedAs) - .importedEntity(importedEntity) - diffGraph.addNode(importNode) - call.foreach { c => diffGraph.addEdge(c, importNode, EdgeTypes.IS_CALL_FOR_IMPORT) } - importNode - } - -} + def createImportNodeAndLink( + importedEntity: String, + importedAs: String, + call: Option[CallBase], + diffGraph: DiffGraphBuilder + ): NewImport = + val importNode = NewImport() + .importedAs(importedAs) + .importedEntity(importedEntity) + diffGraph.addNode(importNode) + call.foreach { c => diffGraph.addEdge(c, importNode, EdgeTypes.IS_CALL_FOR_IMPORT) } + importNode diff --git 
a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala index db7a73c2..ae8eb912 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala @@ -1,153 +1,149 @@ package io.appthreat.x2cpg import better.files.File.VisitOptions -import better.files._ +import better.files.* import org.slf4j.LoggerFactory import java.io.FileNotFoundException import java.nio.file.Paths -object SourceFiles { - - private val logger = LoggerFactory.getLogger(getClass) - - private def isIgnoredByFileList(filePath: String, config: X2CpgConfig[_]): Boolean = { - val isInIgnoredFiles = config.ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath - } - if (isInIgnoredFiles) { - logger.debug(s"'$filePath' ignored (--exclude)") - true - } else { - false - } - } - - private def isIgnoredByDefault(filePath: String, config: X2CpgConfig[_]): Boolean = { - val relPath = toRelativePath(filePath, config.inputPath) - if (config.defaultIgnoredFilesRegex.exists(_.matches(relPath)) || File(filePath).isSymbolicLink) { - logger.debug(s"'$relPath' ignored by default") - true - } else { - false - } - } - - private def isIgnoredByRegex(filePath: String, config: X2CpgConfig[_]): Boolean = { - val relPath = toRelativePath(filePath, config.inputPath) - val isInIgnoredFilesRegex = config.ignoredFilesRegex.matches(relPath) - if (isInIgnoredFilesRegex) { - logger.debug(s"'$relPath' ignored (--exclude-regex)") - true - } else { - false - } - } - - private def filterFiles(files: List[String], config: X2CpgConfig[_]): List[String] = files.filter { - case filePath if isIgnoredByDefault(filePath, config) => false - case filePath if isIgnoredByFileList(filePath, config) => false - case filePath if 
isIgnoredByRegex(filePath, config) => false - case _ => true - } - - /** For a given input path, determine all source files by inspecting filename extensions. - */ - def determine(inputPath: String, sourceFileExtensions: Set[String]): List[String] = { - determine(Set(inputPath), sourceFileExtensions) - } - - /** For a given input path, determine all source files by inspecting filename extensions and filter the result - * according to the given config (by its ignoredFilesRegex and ignoredFiles). - */ - def determine(inputPath: String, sourceFileExtensions: Set[String], config: X2CpgConfig[_]): List[String] = { - determine(Set(inputPath), sourceFileExtensions, config) - } - - /** For given input paths, determine all source files by inspecting filename extensions and filter the result - * according to the given config (by its ignoredFilesRegex and ignoredFiles). - */ - def determine(inputPaths: Set[String], sourceFileExtensions: Set[String], config: X2CpgConfig[_]): List[String] = { - filterFiles(determine(inputPaths, sourceFileExtensions), config) - } - - /** For a given array of input paths, determine all source files by inspecting filename extensions. - */ - def determine(inputPaths: Set[String], sourceFileExtensions: Set[String]): List[String] = { - def hasSourceFileExtension(file: File): Boolean = - file.extension.exists(sourceFileExtensions.contains) - - val inputFiles = inputPaths.map(File(_)) - assertAllExist(inputFiles) - - val (dirs, files) = inputFiles.partition(_.isDirectory) - - val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) - val matchingFilesFromDirs = dirs - .flatMap(_.listRecursively(VisitOptions.default)) - .filter(hasSourceFileExtension) - .map(_.pathAsString) - - (matchingFiles ++ matchingFilesFromDirs).toList.sorted - } - - /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution early to avoid - * unexpected and hard-to-debug issues in the results. 
- */ - private def assertAllExist(files: Set[File]): Unit = { - val (existant, nonExistant) = files.partition(_.isReadable) - val nonReadable = existant.filterNot(_.isReadable) - - if (nonExistant.nonEmpty || nonReadable.nonEmpty) { - logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) - - logErrorWithPaths("Source input paths exist, but are not readable", nonReadable.map(_.canonicalPath)) - - throw FileNotFoundException("Invalid source paths provided") - } - } - - private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = { - val pathsArray = paths.toArray.sorted - - pathsArray.lengthCompare(1) match { - case cmp if cmp < 0 => // pathsArray is empty, so don't log anything - case cmp if cmp == 0 => logger.debug(s"$message: ${paths.head}") - - case cmp => - val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") - logger.debug(errorMessage) - } - } - - /** Constructs an absolute path against rootPath. If the given path is already absolute this path is returned - * unaltered. Otherwise, "rootPath / path" is returned. - */ - def toAbsolutePath(path: String, rootPath: String): String = { - val absolutePath = Paths.get(path) match { - case p if p.isAbsolute => p - case _ if rootPath.endsWith(path) => Paths.get(rootPath) - case p => Paths.get(rootPath, p.toString) - } - absolutePath.normalize().toString - } - - /** Constructs a relative path against rootPath. If the given path is not inside rootPath, path is returned unaltered. - * Otherwise, the path relative to rootPath is returned. 
- */ - def toRelativePath(path: String, rootPath: String): String = { - if (path.startsWith(rootPath)) { - val absolutePath = Paths.get(path).toAbsolutePath - val projectPath = Paths.get(rootPath).toAbsolutePath - if (absolutePath.compareTo(projectPath) == 0) { - absolutePath.getFileName.toString - } else { - projectPath.relativize(absolutePath).toString - } - } else { - path - } - } - -} +object SourceFiles: + + private val logger = LoggerFactory.getLogger(getClass) + + private def isIgnoredByFileList(filePath: String, config: X2CpgConfig[?]): Boolean = + val isInIgnoredFiles = config.ignoredFiles.exists { + case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) + case ignorePath => filePath == ignorePath + } + if isInIgnoredFiles then + logger.debug(s"'$filePath' ignored (--exclude)") + true + else + false + + private def isIgnoredByDefault(filePath: String, config: X2CpgConfig[?]): Boolean = + val relPath = toRelativePath(filePath, config.inputPath) + if config.defaultIgnoredFilesRegex.exists(_.matches(relPath)) || File( + filePath + ).isSymbolicLink + then + logger.debug(s"'$relPath' ignored by default") + true + else + false + + private def isIgnoredByRegex(filePath: String, config: X2CpgConfig[?]): Boolean = + val relPath = toRelativePath(filePath, config.inputPath) + val isInIgnoredFilesRegex = config.ignoredFilesRegex.matches(relPath) + if isInIgnoredFilesRegex then + logger.debug(s"'$relPath' ignored (--exclude-regex)") + true + else + false + + private def filterFiles(files: List[String], config: X2CpgConfig[?]): List[String] = + files.filter { + case filePath if isIgnoredByDefault(filePath, config) => false + case filePath if isIgnoredByFileList(filePath, config) => false + case filePath if isIgnoredByRegex(filePath, config) => false + case _ => true + } + + /** For a given input path, determine all source files by inspecting filename extensions. 
+ */ + def determine(inputPath: String, sourceFileExtensions: Set[String]): List[String] = + determine(Set(inputPath), sourceFileExtensions) + + /** For a given input path, determine all source files by inspecting filename extensions and + * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). + */ + def determine( + inputPath: String, + sourceFileExtensions: Set[String], + config: X2CpgConfig[?] + ): List[String] = + determine(Set(inputPath), sourceFileExtensions, config) + + /** For given input paths, determine all source files by inspecting filename extensions and + * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). + */ + def determine( + inputPaths: Set[String], + sourceFileExtensions: Set[String], + config: X2CpgConfig[?] + ): List[String] = + filterFiles(determine(inputPaths, sourceFileExtensions), config) + + /** For a given array of input paths, determine all source files by inspecting filename + * extensions. + */ + def determine(inputPaths: Set[String], sourceFileExtensions: Set[String]): List[String] = + def hasSourceFileExtension(file: File): Boolean = + file.extension.exists(sourceFileExtensions.contains) + + val inputFiles = inputPaths.map(File(_)) + assertAllExist(inputFiles) + + val (dirs, files) = inputFiles.partition(_.isDirectory) + + val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) + val matchingFilesFromDirs = dirs + .flatMap(_.listRecursively(VisitOptions.default)) + .filter(hasSourceFileExtension) + .map(_.pathAsString) + + (matchingFiles ++ matchingFilesFromDirs).toList.sorted + + /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution + * early to avoid unexpected and hard-to-debug issues in the results. 
+ */ + private def assertAllExist(files: Set[File]): Unit = + val (existant, nonExistant) = files.partition(_.isReadable) + val nonReadable = existant.filterNot(_.isReadable) + + if nonExistant.nonEmpty || nonReadable.nonEmpty then + logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) + + logErrorWithPaths( + "Source input paths exist, but are not readable", + nonReadable.map(_.canonicalPath) + ) + + throw FileNotFoundException("Invalid source paths provided") + + private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = + val pathsArray = paths.toArray.sorted + + pathsArray.lengthCompare(1) match + case cmp if cmp < 0 => // pathsArray is empty, so don't log anything + case cmp if cmp == 0 => logger.debug(s"$message: ${paths.head}") + + case cmp => + val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") + logger.debug(errorMessage) + + /** Constructs an absolute path against rootPath. If the given path is already absolute this + * path is returned unaltered. Otherwise, "rootPath / path" is returned. + */ + def toAbsolutePath(path: String, rootPath: String): String = + val absolutePath = Paths.get(path) match + case p if p.isAbsolute => p + case _ if rootPath.endsWith(path) => Paths.get(rootPath) + case p => Paths.get(rootPath, p.toString) + absolutePath.normalize().toString + + /** Constructs a relative path against rootPath. If the given path is not inside rootPath, path + * is returned unaltered. Otherwise, the path relative to rootPath is returned. 
+ */ + def toRelativePath(path: String, rootPath: String): String = + if path.startsWith(rootPath) then + val absolutePath = Paths.get(path).toAbsolutePath + val projectPath = Paths.get(rootPath).toAbsolutePath + if absolutePath.compareTo(projectPath) == 0 then + absolutePath.getFileName.toString + else + projectPath.relativize(absolutePath).toString + else + path +end SourceFiles diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala index 845d48fa..d7c06b47 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala @@ -15,307 +15,294 @@ import java.nio.file.{Files, Paths} import scala.util.matching.Regex import scala.util.{Failure, Success, Try} -object X2CpgConfig { - def defaultOutputPath: String = "cpg.bin" -} - -trait X2CpgConfig[R <: X2CpgConfig[R]] { - var inputPath: String = "" - var outputPath: String = X2CpgConfig.defaultOutputPath - - def withInputPath(inputPath: String): R = { - this.inputPath = Paths.get(inputPath).toAbsolutePath.normalize().toString - this.asInstanceOf[R] - } - - def withOutputPath(x: String): R = { - this.outputPath = x - this.asInstanceOf[R] - } - - var defaultIgnoredFilesRegex: Seq[Regex] = Seq.empty - var ignoredFilesRegex: Regex = "".r - var ignoredFiles: Seq[String] = Seq.empty - - def withDefaultIgnoredFilesRegex(x: Seq[Regex]): R = { - this.defaultIgnoredFilesRegex = x - this.asInstanceOf[R] - } - - def withIgnoredFilesRegex(x: String): R = { - this.ignoredFilesRegex = x.r - this.asInstanceOf[R] - } - - def withIgnoredFiles(x: Seq[String]): R = { - this.ignoredFiles = x.map(createPathForIgnore) - this.asInstanceOf[R] - } - - def createPathForIgnore(ignore: String): String = { - val path = Paths.get(ignore) - if (path.isAbsolute) { path.toString } - else { Paths.get(inputPath, ignore).toAbsolutePath.normalize().toString } - } 
- - var schemaValidation: ValidationMode = ValidationMode.Disabled - - def withSchemaValidation(value: ValidationMode): R = { - this.schemaValidation = value - this.asInstanceOf[R] - } - - def withInheritedFields(config: R): R = { - this.inputPath = config.inputPath - this.outputPath = config.outputPath - this.defaultIgnoredFilesRegex = config.defaultIgnoredFilesRegex - this.ignoredFilesRegex = config.ignoredFilesRegex - this.ignoredFiles = config.ignoredFiles - this.asInstanceOf[R] - } -} +object X2CpgConfig: + def defaultOutputPath: String = "cpg.bin" + +trait X2CpgConfig[R <: X2CpgConfig[R]]: + var inputPath: String = "" + var outputPath: String = X2CpgConfig.defaultOutputPath + + def withInputPath(inputPath: String): R = + this.inputPath = Paths.get(inputPath).toAbsolutePath.normalize().toString + this.asInstanceOf[R] + + def withOutputPath(x: String): R = + this.outputPath = x + this.asInstanceOf[R] + + var defaultIgnoredFilesRegex: Seq[Regex] = Seq.empty + var ignoredFilesRegex: Regex = "".r + var ignoredFiles: Seq[String] = Seq.empty + + def withDefaultIgnoredFilesRegex(x: Seq[Regex]): R = + this.defaultIgnoredFilesRegex = x + this.asInstanceOf[R] + + def withIgnoredFilesRegex(x: String): R = + this.ignoredFilesRegex = x.r + this.asInstanceOf[R] + + def withIgnoredFiles(x: Seq[String]): R = + this.ignoredFiles = x.map(createPathForIgnore) + this.asInstanceOf[R] + + def createPathForIgnore(ignore: String): String = + val path = Paths.get(ignore) + if path.isAbsolute then path.toString + else Paths.get(inputPath, ignore).toAbsolutePath.normalize().toString + + var schemaValidation: ValidationMode = ValidationMode.Disabled + + def withSchemaValidation(value: ValidationMode): R = + this.schemaValidation = value + this.asInstanceOf[R] + + def withInheritedFields(config: R): R = + this.inputPath = config.inputPath + this.outputPath = config.outputPath + this.defaultIgnoredFilesRegex = config.defaultIgnoredFilesRegex + this.ignoredFilesRegex = 
config.ignoredFilesRegex + this.ignoredFiles = config.ignoredFiles + this.asInstanceOf[R] +end X2CpgConfig /** Base class for `Main` classes of CPG frontends. * - * Main classes that inherit from this base class parse the command line, exiting with an error code if this does not - * succeed. On success, the method `run` is called, which evaluates, given a frontend and a configuration, creates the - * CPG and stores it on disk. + * Main classes that inherit from this base class parse the command line, exiting with an error + * code if this does not succeed. On success, the method `run` is called, which evaluates, given a + * frontend and a configuration, creates the CPG and stores it on disk. * * @param cmdLineParser * parser for command line arguments * @param frontend * the frontend to use for CPG creation */ -abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[_]](val cmdLineParser: OParser[Unit, T], frontend: X)( +abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[_]]( + val cmdLineParser: OParser[Unit, T], + frontend: X +)( implicit defaultConfig: T -) { - - /** method that evaluates frontend with configuration - */ - def run(config: T, frontend: X): Unit - - def main(args: Array[String]): Unit = { - X2Cpg.parseCommandLine(args, cmdLineParser, defaultConfig) match { - case Some(config) => - try { - run(config, frontend) - } catch { - case ex: Throwable => - println(ex.getMessage) - ex.printStackTrace() - System.exit(1) - } - case None => - println("Error parsing the command line") - System.exit(1) - } - } - -} +): + + /** method that evaluates frontend with configuration + */ + def run(config: T, frontend: X): Unit + + def main(args: Array[String]): Unit = + X2Cpg.parseCommandLine(args, cmdLineParser, defaultConfig) match + case Some(config) => + try + run(config, frontend) + catch + case ex: Throwable => + println(ex.getMessage) + ex.printStackTrace() + System.exit(1) + case None => + println("Error parsing the command line") + 
System.exit(1) +end X2CpgMain /** Trait that represents a CPG generator, where T is the frontend configuration class. */ -trait X2CpgFrontend[T <: X2CpgConfig[_]] { - - /** Create a CPG according to given configuration. Returns CPG wrapped in a `Try`, making it possible to detect and - * inspect exceptions in CPG generation. To be provided by the frontend. - */ - def createCpg(config: T): Try[Cpg] - - /** Create CPG according to given configuration, printing errors to the console if they occur. The CPG is closed and - * not returned. - */ - def run(config: T): Unit = { - withErrorsToConsole(config) { _ => - createCpg(config) match { - case Success(cpg) => - cpg.close() - Success(cpg) - case Failure(exception) => - Failure(exception) - } - } - } - - /** Create a CPG with default overlays according to given configuration - */ - def createCpgWithOverlays(config: T): Try[Cpg] = { - val maybeCpg = createCpg(config) - maybeCpg.map { cpg => - applyDefaultOverlays(cpg) - cpg - } - } - - /** Create a CPG for code at `inputPath` and apply default overlays. - */ - def createCpgWithOverlays(inputName: String)(implicit defaultConfig: T): Try[Cpg] = { - val maybeCpg = createCpg(inputName) - maybeCpg.map { cpg => - applyDefaultOverlays(cpg) - cpg - } - } - - /** Create a CPG for code at `inputName` (a single location) with default frontend configuration. If `outputName` - * exists, it is the file name of the resulting CPG. Otherwise, the CPG is held in memory. 
- */ - def createCpg(inputName: String, outputName: Option[String])(implicit defaultConfig: T): Try[Cpg] = { - val defaultWithInputPath = defaultConfig.withInputPath(inputName).asInstanceOf[T] - val config = if (!outputName.contains(X2CpgConfig.defaultOutputPath)) { - if (outputName.isEmpty) { - defaultWithInputPath.withOutputPath("").asInstanceOf[T] - } else { - defaultWithInputPath.withOutputPath(outputName.get).asInstanceOf[T] - } - } else { - defaultWithInputPath - } - createCpg(config) - } - - /** Create a CPG in memory for file at `inputName` with default configuration. - */ - def createCpg(inputName: String)(implicit defaultConfig: T): Try[Cpg] = createCpg(inputName, None)(defaultConfig) - -} - -object X2Cpg { - - private val logger = LoggerFactory.getLogger(X2Cpg.getClass) - - /** Parse commands line arguments in `args` using an X2Cpg command line parser, extended with the frontend specific - * options in `frontendSpecific` with the initial configuration set to `initialConf`. On success, the configuration - * is returned wrapped into an Option. On failure, error messages are printed and, `None` is returned. - */ - def parseCommandLine[R <: X2CpgConfig[R]]( - args: Array[String], - frontendSpecific: OParser[_, R], - initialConf: R - ): Option[R] = { - val parser = commandLineParser(frontendSpecific) - OParser.parse(parser, args, initialConf) - } - - /** Create a command line parser that can be extended to add options specific for the frontend. 
- */ - private def commandLineParser[R <: X2CpgConfig[R]](frontendSpecific: OParser[_, R]): OParser[_, R] = { - val builder = OParser.builder[R] - import builder.* - OParser.sequence( - arg[String]("input-dir") - .text("source directory") - .action((x, c) => c.withInputPath(x)), - opt[String]("output") - .abbr("o") - .text("output filename") - .action { (x, c) => - c.withOutputPath(x) - }, - opt[Seq[String]]("exclude") - .valueName(",,...") - .action { (x, c) => - c.ignoredFiles = c.ignoredFiles ++ x.map(c.createPathForIgnore) - c +trait X2CpgFrontend[T <: X2CpgConfig[_]]: + + /** Create a CPG according to given configuration. Returns CPG wrapped in a `Try`, making it + * possible to detect and inspect exceptions in CPG generation. To be provided by the frontend. + */ + def createCpg(config: T): Try[Cpg] + + /** Create CPG according to given configuration, printing errors to the console if they occur. + * The CPG is closed and not returned. + */ + def run(config: T): Unit = + withErrorsToConsole(config) { _ => + createCpg(config) match + case Success(cpg) => + cpg.close() + Success(cpg) + case Failure(exception) => + Failure(exception) + } + + /** Create a CPG with default overlays according to given configuration + */ + def createCpgWithOverlays(config: T): Try[Cpg] = + val maybeCpg = createCpg(config) + maybeCpg.map { cpg => + applyDefaultOverlays(cpg) + cpg } - .text("files or folders to exclude during CPG generation (paths relative to or absolute paths)"), - opt[String]("exclude-regex") - .action { (x, c) => - c.ignoredFilesRegex = x.r - c + + /** Create a CPG for code at `inputPath` and apply default overlays. 
+ */ + def createCpgWithOverlays(inputName: String)(implicit defaultConfig: T): Try[Cpg] = + val maybeCpg = createCpg(inputName) + maybeCpg.map { cpg => + applyDefaultOverlays(cpg) + cpg } - .text("a regex specifying files to exclude during CPG generation (paths relative to are matched)"), - opt[Unit]("enable-early-schema-checking") - .action((_, c) => c.withSchemaValidation(ValidationMode.Enabled)) - .text("enables early schema validation during AST creation (disabled by default)"), - help("help").text("display this help message"), - frontendSpecific - ) - } - - /** Create an empty CPG, backed by the file at `optionalOutputPath` or in-memory if `optionalOutputPath` is empty. - */ - def newEmptyCpg(optionalOutputPath: Option[String] = None): Cpg = { - val odbConfig = optionalOutputPath - .map { outputPath => - val outFile = File(outputPath) - if (outputPath != "" && outFile.exists) { - logger.debug("Output file exists, removing: " + outputPath) - outFile.delete() + + /** Create a CPG for code at `inputName` (a single location) with default frontend + * configuration. If `outputName` exists, it is the file name of the resulting CPG. Otherwise, + * the CPG is held in memory. + */ + def createCpg(inputName: String, outputName: Option[String])(implicit + defaultConfig: T + ): Try[Cpg] = + val defaultWithInputPath = defaultConfig.withInputPath(inputName).asInstanceOf[T] + val config = if !outputName.contains(X2CpgConfig.defaultOutputPath) then + if outputName.isEmpty then + defaultWithInputPath.withOutputPath("").asInstanceOf[T] + else + defaultWithInputPath.withOutputPath(outputName.get).asInstanceOf[T] + else + defaultWithInputPath + createCpg(config) + + /** Create a CPG in memory for file at `inputName` with default configuration. 
+ */ + def createCpg(inputName: String)(implicit defaultConfig: T): Try[Cpg] = + createCpg(inputName, None)(defaultConfig) +end X2CpgFrontend + +object X2Cpg: + + private val logger = LoggerFactory.getLogger(X2Cpg.getClass) + + /** Parse commands line arguments in `args` using an X2Cpg command line parser, extended with + * the frontend specific options in `frontendSpecific` with the initial configuration set to + * `initialConf`. On success, the configuration is returned wrapped into an Option. On failure, + * error messages are printed and, `None` is returned. + */ + def parseCommandLine[R <: X2CpgConfig[R]]( + args: Array[String], + frontendSpecific: OParser[?, R], + initialConf: R + ): Option[R] = + val parser = commandLineParser(frontendSpecific) + OParser.parse(parser, args, initialConf) + + /** Create a command line parser that can be extended to add options specific for the frontend. + */ + private def commandLineParser[R <: X2CpgConfig[R]](frontendSpecific: OParser[?, R]) + : OParser[?, R] = + val builder = OParser.builder[R] + import builder.* + OParser.sequence( + arg[String]("input-dir") + .text("source directory") + .action((x, c) => c.withInputPath(x)), + opt[String]("output") + .abbr("o") + .text("output filename") + .action { (x, c) => + c.withOutputPath(x) + }, + opt[Seq[String]]("exclude") + .valueName(",,...") + .action { (x, c) => + c.ignoredFiles = c.ignoredFiles ++ x.map(c.createPathForIgnore) + c + } + .text( + "files or folders to exclude during CPG generation (paths relative to or absolute paths)" + ), + opt[String]("exclude-regex") + .action { (x, c) => + c.ignoredFilesRegex = x.r + c + } + .text( + "a regex specifying files to exclude during CPG generation (paths relative to are matched)" + ), + opt[Unit]("enable-early-schema-checking") + .action((_, c) => c.withSchemaValidation(ValidationMode.Enabled)) + .text("enables early schema validation during AST creation (disabled by default)"), + help("help").text("display this help message"), + 
frontendSpecific + ) + end commandLineParser + + /** Create an empty CPG, backed by the file at `optionalOutputPath` or in-memory if + * `optionalOutputPath` is empty. + */ + def newEmptyCpg(optionalOutputPath: Option[String] = None): Cpg = + val odbConfig = optionalOutputPath + .map { outputPath => + val outFile = File(outputPath) + if outputPath != "" && outFile.exists then + logger.debug("Output file exists, removing: " + outputPath) + outFile.delete() + Config.withDefaults.withStorageLocation(outputPath) + } + .getOrElse { + Config.withDefaults() + } + Cpg.withConfig(odbConfig) + + /** Apply function `applyPasses` to a newly created CPG. The CPG is wrapped in a `Try` and + * returned. On failure, the CPG is ensured to be closed. + */ + def withNewEmptyCpg[T <: X2CpgConfig[_]]( + outPath: String, + config: T + )(applyPasses: (Cpg, T) => Unit): Try[Cpg] = + val outputPath = if outPath != "" then Some(outPath) else None + Try { + val cpg = newEmptyCpg(outputPath) + Try { + applyPasses(cpg, config) + } match + case Success(_) => cpg + case Failure(exception) => + cpg.close() + throw exception } - Config.withDefaults.withStorageLocation(outputPath) - } - .getOrElse { - Config.withDefaults() - } - Cpg.withConfig(odbConfig) - } - - /** Apply function `applyPasses` to a newly created CPG. The CPG is wrapped in a `Try` and returned. On failure, the - * CPG is ensured to be closed. - */ - def withNewEmptyCpg[T <: X2CpgConfig[_]](outPath: String, config: T)(applyPasses: (Cpg, T) => Unit): Try[Cpg] = { - val outputPath = if (outPath != "") Some(outPath) else None - Try { - val cpg = newEmptyCpg(outputPath) - Try { - applyPasses(cpg, config) - } match { - case Success(_) => cpg - case Failure(exception) => - cpg.close() - throw exception - } - } - } - - /** Given a function that receives a configuration and returns an arbitrary result type wrapped in a `Try`, evaluate - * the function, printing errors to the console. 
- */ - def withErrorsToConsole[T <: X2CpgConfig[_]](config: T)(f: T => Try[_]): Try[_] = { - f(config) match { - case Failure(exception) => - exception.printStackTrace() - Failure(exception) - case Success(v) => - Success(v) - } - } - - /** For a CPG generated by a frontend, run the default passes that turn a frontend-CPG into a complete CPG. - */ - def applyDefaultOverlays(cpg: Cpg): Unit = { - val context = new LayerCreatorContext(cpg) - defaultOverlayCreators().foreach { creator => - creator.run(context) - } - } - - /** This should be the only place where we define the list of default overlays. - */ - def defaultOverlayCreators(): List[LayerCreator] = { - List(new Base(), new ControlFlow(), new TypeRelations(), new CallGraph()) - } - - /** Write `sourceCode` to a temporary file inside a temporary directory. The prefix for the temporary directory is - * given by `tmpDirPrefix`. The suffix for the temporary file is given by `suffix`. Both file and directory are - * deleted on exit. - */ - def writeCodeToFile(sourceCode: String, tmpDirPrefix: String, suffix: String): java.io.File = { - val tmpDir = Files.createTempDirectory(tmpDirPrefix).toFile - tmpDir.deleteOnExit() - val codeFile = java.io.File.createTempFile("Test", suffix, tmpDir) - codeFile.deleteOnExit() - new PrintWriter(codeFile) { write(sourceCode); close() } - tmpDir - } - - /** Strips surrounding quotation characters from a string. - * @param s - * the target string. - * @return - * the stripped string. - */ - def stripQuotes(str: String): String = str.replaceAll("^(\"|')|(\"|')$", "") - -} + + /** Given a function that receives a configuration and returns an arbitrary result type wrapped + * in a `Try`, evaluate the function, printing errors to the console. + */ + def withErrorsToConsole[T <: X2CpgConfig[_]](config: T)(f: T => Try[?]): Try[?] 
= + f(config) match + case Failure(exception) => + exception.printStackTrace() + Failure(exception) + case Success(v) => + Success(v) + + /** For a CPG generated by a frontend, run the default passes that turn a frontend-CPG into a + * complete CPG. + */ + def applyDefaultOverlays(cpg: Cpg): Unit = + val context = new LayerCreatorContext(cpg) + defaultOverlayCreators().foreach { creator => + creator.run(context) + } + + /** This should be the only place where we define the list of default overlays. + */ + def defaultOverlayCreators(): List[LayerCreator] = + List(new Base(), new ControlFlow(), new TypeRelations(), new CallGraph()) + + /** Write `sourceCode` to a temporary file inside a temporary directory. The prefix for the + * temporary directory is given by `tmpDirPrefix`. The suffix for the temporary file is given + * by `suffix`. Both file and directory are deleted on exit. + */ + def writeCodeToFile(sourceCode: String, tmpDirPrefix: String, suffix: String): java.io.File = + val tmpDir = Files.createTempDirectory(tmpDirPrefix).toFile + tmpDir.deleteOnExit() + val codeFile = java.io.File.createTempFile("Test", suffix, tmpDir) + codeFile.deleteOnExit() + new PrintWriter(codeFile): + write(sourceCode); close() + tmpDir + + /** Strips surrounding quotation characters from a string. + * @param s + * the target string. + * @return + * the stripped string. 
+ */ + def stripQuotes(str: String): String = str.replaceAll("^(\"|')|(\"|')$", "") +end X2Cpg diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala index bf72a551..57fffac3 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala @@ -2,8 +2,6 @@ package io.appthreat.x2cpg.datastructures import java.util.concurrent.ConcurrentHashMap -class Global { +class Global: - val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() - -} + val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala index f9418ccd..2c09bc26 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala @@ -9,37 +9,30 @@ package io.appthreat.x2cpg.datastructures * @tparam S * Scope type. 
*/ -class Scope[I, V, S] { - protected var stack: List[ScopeElement[I, V, S]] = List[ScopeElement[I, V, S]]() +class Scope[I, V, S]: + protected var stack: List[ScopeElement[I, V, S]] = List[ScopeElement[I, V, S]]() - def isEmpty: Boolean = { - stack.isEmpty - } + def isEmpty: Boolean = + stack.isEmpty - def pushNewScope(scopeNode: S): Unit = { - stack = ScopeElement[I, V, S](scopeNode) :: stack - } + def pushNewScope(scopeNode: S): Unit = + stack = ScopeElement[I, V, S](scopeNode) :: stack - def popScope(): Option[S] = { - stack match { - case Nil => None + def popScope(): Option[S] = + stack match + case Nil => None - case head :: tail => - stack = tail - Some(head.scopeNode) - } - } + case head :: tail => + stack = tail + Some(head.scopeNode) - def addToScope(identifier: I, variable: V): S = { - stack = stack.head.addVariable(identifier, variable) :: stack.tail - stack.head.scopeNode - } + def addToScope(identifier: I, variable: V): S = + stack = stack.head.addVariable(identifier, variable) :: stack.tail + stack.head.scopeNode - def lookupVariable(identifier: I): Option[V] = { - stack.collectFirst { - case scopeElement if scopeElement.variables.contains(identifier) => - scopeElement.variables(identifier) - } - } - -} + def lookupVariable(identifier: I): Option[V] = + stack.collectFirst { + case scopeElement if scopeElement.variables.contains(identifier) => + scopeElement.variables(identifier) + } +end Scope diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala index ea1f27fa..fed9d6f4 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala @@ -8,8 +8,6 @@ package io.appthreat.x2cpg.datastructures * @tparam S * Scope type. 
*/ -case class ScopeElement[I, V, S](scopeNode: S, variables: Map[I, V] = Map[I, V]()) { - def addVariable(identifier: I, variable: V): ScopeElement[I, V, S] = { - ScopeElement(scopeNode, variables + (identifier -> variable)) - } -} +case class ScopeElement[I, V, S](scopeNode: S, variables: Map[I, V] = Map[I, V]()): + def addVariable(identifier: I, variable: V): ScopeElement[I, V, S] = + ScopeElement(scopeNode, variables + (identifier -> variable)) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala index b77f8dc9..f197cfd1 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala @@ -2,18 +2,13 @@ package io.appthreat.x2cpg.datastructures import scala.collection.mutable -object Stack { +object Stack: - type Stack[StackElement] = mutable.ListBuffer[StackElement] + type Stack[StackElement] = mutable.ListBuffer[StackElement] - implicit class StackWrapper[StackElement](val parentStack: Stack[StackElement]) extends AnyVal { - def push(parent: StackElement): Unit = { - parentStack.prepend(parent) - } + implicit class StackWrapper[StackElement](val parentStack: Stack[StackElement]) extends AnyVal: + def push(parent: StackElement): Unit = + parentStack.prepend(parent) - def pop(): Unit = { - parentStack.remove(0) - } - } - -} + def pop(): Unit = + parentStack.remove(0) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala index 7f546714..cefbae62 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala @@ -3,40 +3,36 @@ package io.appthreat.x2cpg.layers import 
io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.passes.CpgPassBase -import io.appthreat.x2cpg.passes.base._ +import io.appthreat.x2cpg.passes.base.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -object Base { - val overlayName: String = "base" - val description: String = "base layer (linked frontend CPG)" - def defaultOpts = new LayerCreatorOptions() +object Base: + val overlayName: String = "base" + val description: String = "base layer (linked frontend CPG)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator( - new FileCreationPass(cpg), - new NamespaceCreator(cpg), - new TypeDeclStubCreator(cpg), - new MethodStubCreator(cpg), - new ParameterIndexCompatPass(cpg), - new MethodDecoratorPass(cpg), - new AstLinkerPass(cpg), - new ContainsEdgePass(cpg), - new TypeUsagePass(cpg) - ) + def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator( + new FileCreationPass(cpg), + new NamespaceCreator(cpg), + new TypeDeclStubCreator(cpg), + new MethodStubCreator(cpg), + new ParameterIndexCompatPass(cpg), + new MethodDecoratorPass(cpg), + new AstLinkerPass(cpg), + new ContainsEdgePass(cpg), + new TypeUsagePass(cpg) + ) -} +class Base extends LayerCreator: + override val overlayName: String = Base.overlayName + override val description: String = Base.description -class Base extends LayerCreator { - override val overlayName: String = Base.overlayName - override val description: String = Base.description + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.graph.indexManager.createNodePropertyIndex(PropertyNames.FULL_NAME) + Base.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - 
cpg.graph.indexManager.createNodePropertyIndex(PropertyNames.FULL_NAME) - Base.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } - } - - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() -} + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala index 1695faa4..3e104604 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala @@ -5,29 +5,24 @@ import io.shiftleft.passes.CpgPassBase import io.appthreat.x2cpg.passes.callgraph.{DynamicCallLinker, MethodRefLinker, StaticCallLinker} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -object CallGraph { - val overlayName: String = "callgraph" - val description: String = "Call graph layer" - def defaultOpts = new LayerCreatorOptions() +object CallGraph: + val overlayName: String = "callgraph" + val description: String = "Call graph layer" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = { - Iterator(new MethodRefLinker(cpg), new StaticCallLinker(cpg), new DynamicCallLinker(cpg)) - } + def passes(cpg: Cpg): Iterator[CpgPassBase] = + Iterator(new MethodRefLinker(cpg), new StaticCallLinker(cpg), new DynamicCallLinker(cpg)) -} +class CallGraph extends LayerCreator: + override val overlayName: String = CallGraph.overlayName + override val description: String = CallGraph.description + override val dependsOn: List[String] = List(TypeRelations.overlayName) 
-class CallGraph extends LayerCreator { - override val overlayName: String = CallGraph.overlayName - override val description: String = CallGraph.description - override val dependsOn: List[String] = List(TypeRelations.overlayName) + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + CallGraph.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - CallGraph.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } - } - - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() -} + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala index 2cfc5104..136a1a73 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala @@ -3,40 +3,34 @@ package io.appthreat.x2cpg.layers import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.passes.CpgPassBase -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.appthreat.x2cpg.passes.controlflow.CfgCreationPass import io.appthreat.x2cpg.passes.controlflow.cfgdominator.CfgDominatorPass import io.appthreat.x2cpg.passes.controlflow.codepencegraph.CdgPass import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, 
LayerCreatorOptions} -object ControlFlow { - val overlayName: String = "controlflow" - val description: String = "Control flow layer (including dominators and CDG edges)" - def defaultOpts = new LayerCreatorOptions() +object ControlFlow: + val overlayName: String = "controlflow" + val description: String = "Control flow layer (including dominators and CDG edges)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = { - val cfgCreationPass = cpg.metaData.language.lastOption match { - case Some(Languages.GHIDRA) => Iterator[CpgPassBase]() - case Some(Languages.LLVM) => Iterator[CpgPassBase]() - case _ => Iterator[CpgPassBase](new CfgCreationPass(cpg)) - } - cfgCreationPass ++ Iterator(new CfgDominatorPass(cpg), new CdgPass(cpg)) - } + def passes(cpg: Cpg): Iterator[CpgPassBase] = + val cfgCreationPass = cpg.metaData.language.lastOption match + case Some(Languages.GHIDRA) => Iterator[CpgPassBase]() + case Some(Languages.LLVM) => Iterator[CpgPassBase]() + case _ => Iterator[CpgPassBase](new CfgCreationPass(cpg)) + cfgCreationPass ++ Iterator(new CfgDominatorPass(cpg), new CdgPass(cpg)) -} +class ControlFlow extends LayerCreator: + override val overlayName: String = ControlFlow.overlayName + override val description: String = ControlFlow.description + override val dependsOn: List[String] = List(Base.overlayName) -class ControlFlow extends LayerCreator { - override val overlayName: String = ControlFlow.overlayName - override val description: String = ControlFlow.description - override val dependsOn: List[String] = List(Base.overlayName) + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + ControlFlow.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - ControlFlow.passes(cpg).zipWithIndex.foreach { case (pass, 
index) => - runPass(pass, context, storeUndoInfo, index) - } - } - - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() -} + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala index 9ee6b7b6..7a7b2b47 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala @@ -1,31 +1,27 @@ package io.appthreat.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class AstDumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpAst { +object DumpAst: - val overlayName = "dumpAst" + val overlayName = "dumpAst" - val description = "Dump abstract syntax trees to out/" + val description = "Dump abstract syntax trees to out/" - def defaultOpts: AstDumpOptions = AstDumpOptions("out") -} + def defaultOpts: AstDumpOptions = AstDumpOptions("out") -class DumpAst(options: AstDumpOptions) extends LayerCreator { - override val overlayName: String = DumpAst.overlayName - override val description: String = DumpAst.description - override val storeOverlayName: Boolean = false +class DumpAst(options: AstDumpOptions) extends LayerCreator: + override val overlayName: String = DumpAst.overlayName + override val description: String = DumpAst.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = 
context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotAst.head - (File(options.outDir) / s"$i-ast.dot").write(str) - } - } - -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotAst.head + (File(options.outDir) / s"$i-ast.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala index 513601a8..e981900a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala @@ -1,30 +1,27 @@ package io.appthreat.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class CdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpCdg { +object DumpCdg: - val overlayName = "dumpCdg" + val overlayName = "dumpCdg" - val description = "Dump control dependence graph to out/" + val description = "Dump control dependence graph to out/" - def defaultOpts: CdgDumpOptions = CdgDumpOptions("out") -} + def defaultOpts: CdgDumpOptions = CdgDumpOptions("out") -class DumpCdg(options: CdgDumpOptions) extends LayerCreator { - override val overlayName: String = DumpCdg.overlayName - override val description: String = DumpCdg.description - override val storeOverlayName: Boolean = false +class DumpCdg(options: CdgDumpOptions) extends LayerCreator: + override val overlayName: String = DumpCdg.overlayName + override val description: String = DumpCdg.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit 
= { - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCdg.head - (File(options.outDir) / s"$i-cdg.dot").write(str) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCdg.head + (File(options.outDir) / s"$i-cdg.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala index 22fef5ed..c80adff4 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala @@ -1,30 +1,27 @@ package io.appthreat.x2cpg.layers import better.files.File -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} case class CfgDumpOptions(var outDir: String) extends LayerCreatorOptions {} -object DumpCfg { +object DumpCfg: - val overlayName = "dumpCfg" + val overlayName = "dumpCfg" - val description = "Dump control flow graph to out/" + val description = "Dump control flow graph to out/" - def defaultOpts: CfgDumpOptions = CfgDumpOptions("out") -} + def defaultOpts: CfgDumpOptions = CfgDumpOptions("out") -class DumpCfg(options: CfgDumpOptions) extends LayerCreator { - override val overlayName: String = DumpCfg.overlayName - override val description: String = DumpCfg.description - override val storeOverlayName: Boolean = false +class DumpCfg(options: CfgDumpOptions) extends LayerCreator: + override val overlayName: String = DumpCfg.overlayName + override val description: String = DumpCfg.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): 
Unit = { - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCfg.head - (File(options.outDir) / s"$i-cfg.dot").write(str) - } - } -} + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCfg.head + (File(options.outDir) / s"$i-cfg.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala index febd5d7c..419d406c 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala @@ -5,26 +5,24 @@ import io.shiftleft.passes.CpgPassBase import io.appthreat.x2cpg.passes.typerelations.{AliasLinkerPass, TypeHierarchyPass} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -object TypeRelations { - val overlayName: String = "typerel" - val description: String = "Type relations layer (hierarchy and aliases)" - def defaultOpts = new LayerCreatorOptions() +object TypeRelations: + val overlayName: String = "typerel" + val description: String = "Type relations layer (hierarchy and aliases)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg)) -} + def passes(cpg: Cpg): Iterator[CpgPassBase] = + Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg)) -class TypeRelations extends LayerCreator { - override val overlayName: String = TypeRelations.overlayName - override val description: String = TypeRelations.description - override val dependsOn: List[String] = List(Base.overlayName) +class TypeRelations extends LayerCreator: + override val overlayName: String = 
TypeRelations.overlayName + override val description: String = TypeRelations.description + override val dependsOn: List[String] = List(Base.overlayName) - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = { - val cpg = context.cpg - TypeRelations.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + TypeRelations.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - // Layers need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() -} + // Layers need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala index 5f1eeb5a..7b4b6dd6 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala @@ -5,49 +5,67 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -class AstLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { +class AstLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = { - cpg.method.whereNot(_.astParent).foreach { method => - addAstParent(method, method.fullName, 
method.astParentType, method.astParentFullName, dstGraph) - } - cpg.typeDecl.whereNot(_.astParent).foreach { typeDecl => - addAstParent(typeDecl, typeDecl.fullName, typeDecl.astParentType, typeDecl.astParentFullName, dstGraph) - } - } + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.method.whereNot(_.astParent).foreach { method => + addAstParent( + method, + method.fullName, + method.astParentType, + method.astParentFullName, + dstGraph + ) + } + cpg.typeDecl.whereNot(_.astParent).foreach { typeDecl => + addAstParent( + typeDecl, + typeDecl.fullName, + typeDecl.astParentType, + typeDecl.astParentFullName, + dstGraph + ) + } + end run - /** For the given method or type declaration, determine its parent in the AST via the AST_PARENT_TYPE and - * AST_PARENT_FULL_NAME fields and create an AST edge from the parent to it. AST creation to methods and type - * declarations is deferred in frontends in order to allow them to process methods/type- declarations independently. - */ - private def addAstParent( - astChild: StoredNode, - astChildFullName: String, - astParentType: String, - astParentFullName: String, - dstGraph: DiffGraphBuilder - ): Unit = { - val astParentOption: Option[StoredNode] = - astParentType match { - case NodeTypes.METHOD => methodFullNameToNode(cpg, astParentFullName) - case NodeTypes.TYPE_DECL => typeDeclFullNameToNode(cpg, astParentFullName) - case NodeTypes.NAMESPACE_BLOCK => namespaceBlockFullNameToNode(cpg, astParentFullName) - case _ => - logger.debug( - s"Invalid AST_PARENT_TYPE=$astParentFullName;" + - s" astChild LABEL=${astChild.label};" + - s" astChild FULL_NAME=$astChildFullName" - ) - None - } + /** For the given method or type declaration, determine its parent in the AST via the + * AST_PARENT_TYPE and AST_PARENT_FULL_NAME fields and create an AST edge from the parent to + * it. AST creation to methods and type declarations is deferred in frontends in order to allow + * them to process methods/type- declarations independently. 
+ */ + private def addAstParent( + astChild: StoredNode, + astChildFullName: String, + astParentType: String, + astParentFullName: String, + dstGraph: DiffGraphBuilder + ): Unit = + val astParentOption: Option[StoredNode] = + astParentType match + case NodeTypes.METHOD => methodFullNameToNode(cpg, astParentFullName) + case NodeTypes.TYPE_DECL => typeDeclFullNameToNode(cpg, astParentFullName) + case NodeTypes.NAMESPACE_BLOCK => + namespaceBlockFullNameToNode(cpg, astParentFullName) + case _ => + logger.debug( + s"Invalid AST_PARENT_TYPE=$astParentFullName;" + + s" astChild LABEL=${astChild.label};" + + s" astChild FULL_NAME=$astChildFullName" + ) + None - astParentOption match { - case Some(astParent) => - dstGraph.addEdge(astParent, astChild, EdgeTypes.AST) - case None => - logFailedSrcLookup(EdgeTypes.AST, astParentType, astParentFullName, astChild.label, astChild.id.toString) - } - } -} + astParentOption match + case Some(astParent) => + dstGraph.addEdge(astParent, astChild, EdgeTypes.AST) + case None => + logFailedSrcLookup( + EdgeTypes.AST, + astParentType, + astParentFullName, + astChild.label, + astChild.id.toString + ) + end addAstParent +end AstLinkerPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala index 5f81215e..88855d61 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala @@ -1,49 +1,42 @@ package io.appthreat.x2cpg.passes.base import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.passes.ConcurrentWriterCpgPass import scala.collection.mutable -import 
scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -/** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends which do not provide - * method stubs and type decl stubs. +/** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends + * which do not provide method stubs and type decl stubs. */ -class ContainsEdgePass(cpg: Cpg) extends ConcurrentWriterCpgPass[AstNode](cpg) { - import ContainsEdgePass._ - - override def generateParts(): Array[AstNode] = - cpg.graph.nodes(sourceTypes: _*).asScala.map(_.asInstanceOf[AstNode]).toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, source: AstNode): Unit = { - // AST is assumed to be a tree. If it contains cycles, then this will give a nice endless loop with OOM - val queue = mutable.ArrayDeque[StoredNode](source) - while (queue.nonEmpty) { - val parent = queue.removeHead() - for (nextNode <- parent._astOut) { - if (isDestinationType(nextNode)) dstGraph.addEdge(source, nextNode, EdgeTypes.CONTAINS) - if (!isSourceType(nextNode)) queue.append(nextNode) - } - } - } -} - -object ContainsEdgePass { - - private def isSourceType(node: StoredNode): Boolean = node match { - case _: Method | _: TypeDecl | _: File => true - case _ => false - } - - private def isDestinationType(node: StoredNode): Boolean = node match { - case _: Block | _: Identifier | _: FieldIdentifier | _: Return | _: Method | _: TypeDecl | _: Call | _: Literal | - _: MethodRef | _: TypeRef | _: ControlStructure | _: JumpTarget | _: Unknown | _: TemplateDom => - true - case _ => false - } - - private val sourceTypes = List(NodeTypes.METHOD, NodeTypes.TYPE_DECL, NodeTypes.FILE) - -} +class ContainsEdgePass(cpg: Cpg) extends ConcurrentWriterCpgPass[AstNode](cpg): + import ContainsEdgePass.* + + override def generateParts(): Array[AstNode] = + cpg.graph.nodes(sourceTypes*).asScala.map(_.asInstanceOf[AstNode]).toArray + + override def runOnPart(dstGraph: 
DiffGraphBuilder, source: AstNode): Unit = + // AST is assumed to be a tree. If it contains cycles, then this will give a nice endless loop with OOM + val queue = mutable.ArrayDeque[StoredNode](source) + while queue.nonEmpty do + val parent = queue.removeHead() + for nextNode <- parent._astOut do + if isDestinationType(nextNode) then + dstGraph.addEdge(source, nextNode, EdgeTypes.CONTAINS) + if !isSourceType(nextNode) then queue.append(nextNode) + +object ContainsEdgePass: + + private def isSourceType(node: StoredNode): Boolean = node match + case _: Method | _: TypeDecl | _: File => true + case _ => false + + private def isDestinationType(node: StoredNode): Boolean = node match + case _: Block | _: Identifier | _: FieldIdentifier | _: Return | _: Method | _: TypeDecl | _: Call | _: Literal | + _: MethodRef | _: TypeRef | _: ControlStructure | _: JumpTarget | _: Unknown | _: TemplateDom => + true + case _ => false + + private val sourceTypes = List(NodeTypes.METHOD, NodeTypes.TYPE_DECL, NodeTypes.FILE) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala index 7df39585..e6a9490a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala @@ -4,53 +4,54 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{NewFile, StoredNode} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.appthreat.x2cpg.utils.LinkingUtil import io.shiftleft.semanticcpg.language.types.structure.FileTraversal import scala.collection.mutable -/** For all nodes with FILENAME fields, create 
corresponding FILE nodes and connect node with FILE node via outgoing - * SOURCE_FILE edges. +/** For all nodes with FILENAME fields, create corresponding FILE nodes and connect node with FILE + * node via outgoing SOURCE_FILE edges. */ -class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { - - override def run(dstGraph: DiffGraphBuilder): Unit = { - val originalFileNameToNode = mutable.Map.empty[String, StoredNode] - val newFileNameToNode = mutable.Map.empty[String, NewFile] - - cpg.file.foreach { node => - originalFileNameToNode += node.name -> node - } - - def createFileIfDoesNotExist(srcNode: StoredNode, destFullName: String): Unit = { - if (destFullName != srcNode.propertyDefaultValue(PropertyNames.FILENAME)) { - val dstFullName = if (destFullName == "") { FileTraversal.UNKNOWN } - else { destFullName } - val newFile = newFileNameToNode.getOrElseUpdate( - dstFullName, { - val file = NewFile().name(dstFullName).order(0) - dstGraph.addNode(file) - file - } +class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: + + override def run(dstGraph: DiffGraphBuilder): Unit = + val originalFileNameToNode = mutable.Map.empty[String, StoredNode] + val newFileNameToNode = mutable.Map.empty[String, NewFile] + + cpg.file.foreach { node => + originalFileNameToNode += node.name -> node + } + + def createFileIfDoesNotExist(srcNode: StoredNode, destFullName: String): Unit = + if destFullName != srcNode.propertyDefaultValue(PropertyNames.FILENAME) then + val dstFullName = if destFullName == "" then FileTraversal.UNKNOWN + else destFullName + val newFile = newFileNameToNode.getOrElseUpdate( + dstFullName, { + val file = NewFile().name(dstFullName).order(0) + dstGraph.addNode(file) + file + } + ) + dstGraph.addEdge(srcNode, newFile, EdgeTypes.SOURCE_FILE) + + // Create SOURCE_FILE edges from nodes of various types to FILE + linkToSingle( + cpg, + srcLabels = List( + NodeTypes.NAMESPACE_BLOCK, + NodeTypes.TYPE_DECL, + NodeTypes.METHOD, + 
NodeTypes.COMMENT + ), + dstNodeLabel = NodeTypes.FILE, + edgeType = EdgeTypes.SOURCE_FILE, + dstNodeMap = x => + originalFileNameToNode.get(x), + dstFullNameKey = PropertyNames.FILENAME, + dstGraph, + Some(createFileIfDoesNotExist) ) - dstGraph.addEdge(srcNode, newFile, EdgeTypes.SOURCE_FILE) - } - } - - // Create SOURCE_FILE edges from nodes of various types to FILE - linkToSingle( - cpg, - srcLabels = List(NodeTypes.NAMESPACE_BLOCK, NodeTypes.TYPE_DECL, NodeTypes.METHOD, NodeTypes.COMMENT), - dstNodeLabel = NodeTypes.FILE, - edgeType = EdgeTypes.SOURCE_FILE, - dstNodeMap = { x => - originalFileNameToNode.get(x) - }, - dstFullNameKey = PropertyNames.FILENAME, - dstGraph, - Some(createFileIfDoesNotExist) - ) - } - -} + end run +end FileCreationPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala index 94641bf9..a9a993b3 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala @@ -3,60 +3,58 @@ package io.appthreat.x2cpg.passes.base import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} -/** Adds a METHOD_PARAMETER_OUT for each METHOD_PARAMETER_IN to the graph and connects those with a PARAMETER_LINK edge. - * It also creates an AST edge from METHOD to the new METHOD_PARAMETER_OUT nodes. +/** Adds a METHOD_PARAMETER_OUT for each METHOD_PARAMETER_IN to the graph and connects those with a + * PARAMETER_LINK edge. It also creates an AST edge from METHOD to the new METHOD_PARAMETER_OUT + * nodes. 
* - * This pass has MethodStubCreator as prerequisite for language frontends which do not provide method stubs. + * This pass has MethodStubCreator as prerequisite for language frontends which do not provide + * method stubs. */ -class MethodDecoratorPass(cpg: Cpg) extends CpgPass(cpg) { - import MethodDecoratorPass.logger +class MethodDecoratorPass(cpg: Cpg) extends CpgPass(cpg): + import MethodDecoratorPass.logger - private[this] var loggedDeprecatedWarning = false - private[this] var loggedMissingTypeFullName = false + private[this] var loggedDeprecatedWarning = false + private[this] var loggedMissingTypeFullName = false - override def run(dstGraph: DiffGraphBuilder): Unit = { - cpg.parameter.foreach { parameterIn => - if (!parameterIn._parameterLinkOut.hasNext) { - val parameterOut = nodes - .NewMethodParameterOut() - .code(parameterIn.code) - .order(parameterIn.order) - .index(parameterIn.index) - .name(parameterIn.name) - .evaluationStrategy(parameterIn.evaluationStrategy) - .typeFullName(parameterIn.typeFullName) - .isVariadic(parameterIn.isVariadic) - .lineNumber(parameterIn.lineNumber) - .columnNumber(parameterIn.columnNumber) + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.parameter.foreach { parameterIn => + if !parameterIn._parameterLinkOut.hasNext then + val parameterOut = nodes + .NewMethodParameterOut() + .code(parameterIn.code) + .order(parameterIn.order) + .index(parameterIn.index) + .name(parameterIn.name) + .evaluationStrategy(parameterIn.evaluationStrategy) + .typeFullName(parameterIn.typeFullName) + .isVariadic(parameterIn.isVariadic) + .lineNumber(parameterIn.lineNumber) + .columnNumber(parameterIn.columnNumber) - val method = parameterIn.astIn.headOption - if (method.isEmpty) { - logger.debug("Parameter without method encountered: " + parameterIn.toString) - } else { - if (parameterIn.typeFullName == null) { - val evalType = parameterIn.typ - dstGraph.addEdge(parameterOut, evalType, EdgeTypes.EVAL_TYPE) - if 
(!loggedMissingTypeFullName) { - logger.debug("Using deprecated CPG format with missing TYPE_FULL_NAME on METHOD_PARAMETER_IN nodes.") - loggedMissingTypeFullName = true - } - } + val method = parameterIn.astIn.headOption + if method.isEmpty then + logger.debug("Parameter without method encountered: " + parameterIn.toString) + else + if parameterIn.typeFullName == null then + val evalType = parameterIn.typ + dstGraph.addEdge(parameterOut, evalType, EdgeTypes.EVAL_TYPE) + if !loggedMissingTypeFullName then + logger.debug( + "Using deprecated CPG format with missing TYPE_FULL_NAME on METHOD_PARAMETER_IN nodes." + ) + loggedMissingTypeFullName = true - dstGraph.addNode(parameterOut) - dstGraph.addEdge(method.get, parameterOut, EdgeTypes.AST) - dstGraph.addEdge(parameterIn, parameterOut, EdgeTypes.PARAMETER_LINK) + dstGraph.addNode(parameterOut) + dstGraph.addEdge(method.get, parameterOut, EdgeTypes.AST) + dstGraph.addEdge(parameterIn, parameterOut, EdgeTypes.PARAMETER_LINK) + else if !loggedDeprecatedWarning then + logger.debug("Using deprecated CPG format with PARAMETER_LINK edges") + loggedDeprecatedWarning = true } - } else if (!loggedDeprecatedWarning) { - logger.debug("Using deprecated CPG format with PARAMETER_LINK edges") - loggedDeprecatedWarning = true - } - } - } -} +end MethodDecoratorPass -object MethodDecoratorPass { - private val logger: Logger = LoggerFactory.getLogger(classOf[MethodDecoratorPass]) -} +object MethodDecoratorPass: + private val logger: Logger = LoggerFactory.getLogger(classOf[MethodDecoratorPass]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala index f8ba5c4e..19c72046 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala @@ -3,10 
+3,15 @@ package io.appthreat.x2cpg.passes.base import io.appthreat.x2cpg.Defines import io.appthreat.x2cpg.passes.base.MethodStubCreator.createMethodStub import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, EvaluationStrategies, NodeTypes} +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{ + DispatchTypes, + EdgeTypes, + EvaluationStrategies, + NodeTypes +} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder @@ -17,125 +22,117 @@ case class CallSummary(name: String, signature: String, fullName: String, dispat /** This pass has no other pass as prerequisite. */ -class MethodStubCreator(cpg: Cpg) extends CpgPass(cpg) { - - // Since the method fullNames for fuzzyc are not unique, we do not have - // a 1to1 relation and may overwrite some values. This is ok for now. 
- private val methodFullNameToNode = mutable.LinkedHashMap[String, Method]() - private val methodToParameterCount = mutable.LinkedHashMap[CallSummary, Int]() - - override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = { - for (method <- cpg.method) { - methodFullNameToNode.put(method.fullName, method) - } - - for (call <- cpg.call if call.methodFullName != Defines.DynamicCallUnknownFullName) { - methodToParameterCount.put( - CallSummary(call.name, call.signature, call.methodFullName, call.dispatchType), - call.argument.size - ) - } - - for ( - (CallSummary(name, signature, fullName, dispatchType), parameterCount) <- methodToParameterCount - if !methodFullNameToNode.contains(fullName) - ) { - createMethodStub(name, fullName, signature, dispatchType, parameterCount, dstGraph) - } - } - - override def finish(): Unit = { - methodFullNameToNode.clear() - methodToParameterCount.clear() - super.finish() - } - -} - -object MethodStubCreator { - - private def addLineNumberInfo(methodNode: NewMethod, fullName: String): NewMethod = { - val s = fullName.split(":") - if ( - s.size == 5 && Try { - s(1).toInt - }.isSuccess && Try { - s(2).toInt - }.isSuccess - ) { - val filename = s(0) - val lineNumber = s(1).toInt - val lineNumberEnd = s(2).toInt - methodNode - .filename(filename) - .lineNumber(lineNumber) - .lineNumberEnd(lineNumberEnd) - } else { - methodNode - } - } - - def createMethodStub( - name: String, - fullName: String, - signature: String, - dispatchType: String, - parameterCount: Int, - dstGraph: DiffGraphBuilder, - isExternal: Boolean = true, - astParentType: String = NodeTypes.NAMESPACE_BLOCK, - astParentFullName: String = "" - ): NewMethod = { - val methodNode = NewMethod() - .name(name) - .fullName(fullName) - .isExternal(isExternal) - .signature(signature) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .order(0) - - addLineNumberInfo(methodNode, fullName) - - dstGraph.addNode(methodNode) - - val firstParameterIndex = 
dispatchType match { - case DispatchTypes.DYNAMIC_DISPATCH => - 0 - case _ => - 1 - } - - (firstParameterIndex to parameterCount).foreach { parameterOrder => - val nameAndCode = s"p$parameterOrder" - val param = NewMethodParameterIn() - .code(nameAndCode) - .order(parameterOrder) - .name(nameAndCode) - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .typeFullName("ANY") - - dstGraph.addNode(param) - dstGraph.addEdge(methodNode, param, EdgeTypes.AST) - } - - val blockNode = NewBlock() - .order(1) - .argumentIndex(1) - .typeFullName("ANY") - - dstGraph.addNode(blockNode) - dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) - - val methodReturn = NewMethodReturn() - .order(2) - .code("RET") - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .typeFullName("ANY") - - dstGraph.addNode(methodReturn) - dstGraph.addEdge(methodNode, methodReturn, EdgeTypes.AST) - - methodNode - } -} +class MethodStubCreator(cpg: Cpg) extends CpgPass(cpg): + + // Since the method fullNames for fuzzyc are not unique, we do not have + // a 1to1 relation and may overwrite some values. This is ok for now. 
+ private val methodFullNameToNode = mutable.LinkedHashMap[String, Method]() + private val methodToParameterCount = mutable.LinkedHashMap[CallSummary, Int]() + + override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = + for method <- cpg.method do + methodFullNameToNode.put(method.fullName, method) + + for call <- cpg.call if call.methodFullName != Defines.DynamicCallUnknownFullName do + methodToParameterCount.put( + CallSummary(call.name, call.signature, call.methodFullName, call.dispatchType), + call.argument.size + ) + + for + (CallSummary(name, signature, fullName, dispatchType), parameterCount) <- + methodToParameterCount + if !methodFullNameToNode.contains(fullName) + do + createMethodStub(name, fullName, signature, dispatchType, parameterCount, dstGraph) + + override def finish(): Unit = + methodFullNameToNode.clear() + methodToParameterCount.clear() + super.finish() +end MethodStubCreator + +object MethodStubCreator: + + private def addLineNumberInfo(methodNode: NewMethod, fullName: String): NewMethod = + val s = fullName.split(":") + if + s.size == 5 && Try { + s(1).toInt + }.isSuccess && Try { + s(2).toInt + }.isSuccess + then + val filename = s(0) + val lineNumber = s(1).toInt + val lineNumberEnd = s(2).toInt + methodNode + .filename(filename) + .lineNumber(lineNumber) + .lineNumberEnd(lineNumberEnd) + else + methodNode + + def createMethodStub( + name: String, + fullName: String, + signature: String, + dispatchType: String, + parameterCount: Int, + dstGraph: DiffGraphBuilder, + isExternal: Boolean = true, + astParentType: String = NodeTypes.NAMESPACE_BLOCK, + astParentFullName: String = "" + ): NewMethod = + val methodNode = NewMethod() + .name(name) + .fullName(fullName) + .isExternal(isExternal) + .signature(signature) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .order(0) + + addLineNumberInfo(methodNode, fullName) + + dstGraph.addNode(methodNode) + + val firstParameterIndex = dispatchType match + case 
DispatchTypes.DYNAMIC_DISPATCH => + 0 + case _ => + 1 + + (firstParameterIndex to parameterCount).foreach { parameterOrder => + val nameAndCode = s"p$parameterOrder" + val param = NewMethodParameterIn() + .code(nameAndCode) + .order(parameterOrder) + .name(nameAndCode) + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(param) + dstGraph.addEdge(methodNode, param, EdgeTypes.AST) + } + + val blockNode = NewBlock() + .order(1) + .argumentIndex(1) + .typeFullName("ANY") + + dstGraph.addNode(blockNode) + dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) + + val methodReturn = NewMethodReturn() + .order(2) + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(methodReturn) + dstGraph.addEdge(methodNode, methodReturn, EdgeTypes.AST) + + methodNode + end createMethodStub +end MethodStubCreator diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala index 853eb568..6d08e718 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala @@ -4,24 +4,21 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.NewNamespace import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. * * This pass has no other pass as prerequisite. */ -class NamespaceCreator(cpg: Cpg) extends CpgPass(cpg) { +class NamespaceCreator(cpg: Cpg) extends CpgPass(cpg): - /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. 
- */ - override def run(dstGraph: DiffGraphBuilder): Unit = { - cpg.namespaceBlock - .groupBy(_.name) - .foreach { case (name: String, blocks) => - val namespace = NewNamespace().name(name) - dstGraph.addNode(namespace) - blocks.foreach(block => dstGraph.addEdge(block, namespace, EdgeTypes.REF)) - } - } - -} + /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. + */ + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.namespaceBlock + .groupBy(_.name) + .foreach { case (name: String, blocks) => + val namespace = NewNamespace().name(name) + dstGraph.addNode(namespace) + blocks.foreach(block => dstGraph.addEdge(block, namespace, EdgeTypes.REF)) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala index da129c6a..f59f8d25 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala @@ -5,19 +5,16 @@ import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn.PropertyDefaults import io.shiftleft.passes.CpgPass import overflowdb.BatchedUpdate -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -/** Old CPGs use the `order` field to indicate the parameter index while newer CPGs use the `parameterIndex` field. This - * pass checks whether `parameterIndex` is not set, in which case the value of `order` is copied over. +/** Old CPGs use the `order` field to indicate the parameter index while newer CPGs use the + * `parameterIndex` field. This pass checks whether `parameterIndex` is not set, in which case the + * value of `order` is copied over. 
*/ -class ParameterIndexCompatPass(cpg: Cpg) extends CpgPass(cpg) { +class ParameterIndexCompatPass(cpg: Cpg) extends CpgPass(cpg): - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = { - cpg.parameter.foreach { param => - if (param.index == PropertyDefaults.Index) { - diffGraph.setNodeProperty(param, PropertyNames.INDEX, param.order) - } - } - } - -} + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.parameter.foreach { param => + if param.index == PropertyDefaults.Index then + diffGraph.setNodeProperty(param, PropertyNames.INDEX, param.order) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala index bdc54cfb..f5717955 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala @@ -4,55 +4,50 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.NodeTypes import io.shiftleft.codepropertygraph.generated.nodes.{NewTypeDecl, TypeDeclBase} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, NamespaceTraversal} -/** This pass has no other pass as prerequisite. For each `TYPE` node that does not have a corresponding `TYPE_DECL` - * node, this pass creates a `TYPE_DECL` node. The `TYPE_DECL` is considered external. +/** This pass has no other pass as prerequisite. For each `TYPE` node that does not have a + * corresponding `TYPE_DECL` node, this pass creates a `TYPE_DECL` node. The `TYPE_DECL` is + * considered external. 
*/ -class TypeDeclStubCreator(cpg: Cpg) extends CpgPass(cpg) { - - private var typeDeclFullNameToNode = Map[String, TypeDeclBase]() - - private def privateInit(): Unit = { - cpg.typeDecl - .foreach { typeDecl => - typeDeclFullNameToNode += typeDecl.fullName -> typeDecl - } - } - - override def run(dstGraph: DiffGraphBuilder): Unit = { - privateInit() - - cpg.typ - .filterNot(typ => typeDeclFullNameToNode.isDefinedAt(typ.fullName)) - .foreach { typ => - val newTypeDecl = TypeDeclStubCreator.createTypeDeclStub(typ.name, typ.fullName) - typeDeclFullNameToNode += typ.fullName -> newTypeDecl - dstGraph.addNode(newTypeDecl) - } - } - -} - -object TypeDeclStubCreator { - - def createTypeDeclStub( - name: String, - fullName: String, - isExternal: Boolean = true, - astParentType: String = NodeTypes.NAMESPACE_BLOCK, - astParentFullName: String = NamespaceTraversal.globalNamespaceName, - fileName: String = FileTraversal.UNKNOWN - ): NewTypeDecl = { - NewTypeDecl() - .name(name) - .fullName(fullName) - .isExternal(isExternal) - .inheritsFromTypeFullName(IndexedSeq.empty) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .filename(fileName) - } - -} +class TypeDeclStubCreator(cpg: Cpg) extends CpgPass(cpg): + + private var typeDeclFullNameToNode = Map[String, TypeDeclBase]() + + private def privateInit(): Unit = + cpg.typeDecl + .foreach { typeDecl => + typeDeclFullNameToNode += typeDecl.fullName -> typeDecl + } + + override def run(dstGraph: DiffGraphBuilder): Unit = + privateInit() + + cpg.typ + .filterNot(typ => typeDeclFullNameToNode.isDefinedAt(typ.fullName)) + .foreach { typ => + val newTypeDecl = TypeDeclStubCreator.createTypeDeclStub(typ.name, typ.fullName) + typeDeclFullNameToNode += typ.fullName -> newTypeDecl + dstGraph.addNode(newTypeDecl) + } +end TypeDeclStubCreator + +object TypeDeclStubCreator: + + def createTypeDeclStub( + name: String, + fullName: String, + isExternal: Boolean = true, + astParentType: String = NodeTypes.NAMESPACE_BLOCK, 
+ astParentFullName: String = NamespaceTraversal.globalNamespaceName, + fileName: String = FileTraversal.UNKNOWN + ): NewTypeDecl = + NewTypeDecl() + .name(name) + .fullName(fullName) + .isExternal(isExternal) + .inheritsFromTypeFullName(IndexedSeq.empty) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .filename(fileName) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala index 3db4ec4d..c46d5ca0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala @@ -5,44 +5,44 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -class TypeUsagePass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { +class TypeUsagePass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = { - // Create REF edges from TYPE nodes to TYPE_DECL - linkToSingle( - cpg, - srcLabels = List(NodeTypes.TYPE), - dstNodeLabel = NodeTypes.TYPE_DECL, - edgeType = EdgeTypes.REF, - dstNodeMap = typeDeclFullNameToNode(cpg, _), - dstFullNameKey = PropertyNames.TYPE_DECL_FULL_NAME, - dstGraph, - None - ) + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create REF edges from TYPE nodes to TYPE_DECL + linkToSingle( + cpg, + srcLabels = List(NodeTypes.TYPE), + dstNodeLabel = NodeTypes.TYPE_DECL, + edgeType = EdgeTypes.REF, + dstNodeMap = typeDeclFullNameToNode(cpg, _), + dstFullNameKey = PropertyNames.TYPE_DECL_FULL_NAME, + dstGraph, + None + ) - // Create EVAL_TYPE edges from nodes of various types to TYPE - linkToSingle( - cpg, - srcLabels = List( - NodeTypes.METHOD_PARAMETER_IN, - NodeTypes.METHOD_PARAMETER_OUT, - 
NodeTypes.METHOD_RETURN, - NodeTypes.MEMBER, - NodeTypes.LITERAL, - NodeTypes.CALL, - NodeTypes.LOCAL, - NodeTypes.IDENTIFIER, - NodeTypes.BLOCK, - NodeTypes.METHOD_REF, - NodeTypes.TYPE_REF, - NodeTypes.UNKNOWN - ), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.EVAL_TYPE, - dstNodeMap = typeFullNameToNode(cpg, _), - dstFullNameKey = "TYPE_FULL_NAME", - dstGraph, - None - ) - } -} + // Create EVAL_TYPE edges from nodes of various types to TYPE + linkToSingle( + cpg, + srcLabels = List( + NodeTypes.METHOD_PARAMETER_IN, + NodeTypes.METHOD_PARAMETER_OUT, + NodeTypes.METHOD_RETURN, + NodeTypes.MEMBER, + NodeTypes.LITERAL, + NodeTypes.CALL, + NodeTypes.LOCAL, + NodeTypes.IDENTIFIER, + NodeTypes.BLOCK, + NodeTypes.METHOD_REF, + NodeTypes.TYPE_REF, + NodeTypes.UNKNOWN + ), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.EVAL_TYPE, + dstNodeMap = typeFullNameToNode(cpg, _), + dstFullNameKey = "TYPE_FULL_NAME", + dstGraph, + None + ) + end run +end TypeUsagePass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala index 42f285dc..9565d0ea 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala @@ -5,225 +5,219 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method, TypeDecl} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} import overflowdb.{NodeDb, NodeRef} import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -/** We 
compute the set of possible call-targets for each dynamic call, and add them as CALL edges to the graph, based on - * call.methodFullName, method.name and method.signature, the inheritance hierarchy and the AST of typedecls and - * methods. +/** We compute the set of possible call-targets for each dynamic call, and add them as CALL edges to + * the graph, based on call.methodFullName, method.name and method.signature, the inheritance + * hierarchy and the AST of typedecls and methods. * - * This pass intentionally ignores the vtable mechanism based on BINDING nodes but does check for an existing call edge - * before adding one. It assumes non-circular inheritance, on pain of endless recursion / stack overflow. + * This pass intentionally ignores the vtable mechanism based on BINDING nodes but does check for + * an existing call edge before adding one. It assumes non-circular inheritance, on pain of endless + * recursion / stack overflow. * - * Based on the algorithm by Jang, Dongseok & Tatlock, Zachary & Lerner, Sorin. (2014). SAFEDISPATCH: Securing C++ - * Virtual Calls from Memory Corruption Attacks. 10.14722/ndss.2014.23287. + * Based on the algorithm by Jang, Dongseok & Tatlock, Zachary & Lerner, Sorin. (2014). + * SAFEDISPATCH: Securing C++ Virtual Calls from Memory Corruption Attacks. + * 10.14722/ndss.2014.23287. */ -class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg) { - - import DynamicCallLinker._ - // Used to track potential method candidates for a given method fullname. Since our method full names contain the type - // decl we don't need to specify an addition map to wrap this in. LinkedHashSets are used here to preserve order in - // the best interest of reproducibility during debugging. 
- private val validM = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - // Used for dynamic programming as subtree's don't need to be recalculated later - private val subclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - private val superclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - // Used for O(1) lookups on methods that will work without indexManager - private val typeMap = mutable.Map.empty[String, TypeDecl] - // For linking loose method stubs that cannot be resolved by crawling parent types - private val methodMap = mutable.Map.empty[String, Method] - - private def initMaps(): Unit = { - cpg.typeDecl.foreach { typeDecl => - typeMap += (typeDecl.fullName -> typeDecl) - } - cpg.method - .filter(m => !m.name.startsWith("")) - .foreach { method => methodMap += (method.fullName -> method) } - } - - /** Main method of enhancement - to be implemented by child class - */ - override def run(dstGraph: DiffGraphBuilder): Unit = { - // Perform early stopping in the case of no virtual calls - if (!cpg.call.exists(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH)) { - return - } - initMaps() - // ValidM maps class C and method name N to the set of - // func ptrs implementing N for C and its subclasses - for ( - typeDecl <- cpg.typeDecl; - method <- typeDecl._methodViaAstOut - ) { - val methodName = method.fullName - val candidates = allSubclasses(typeDecl.fullName).flatMap { staticLookup(_, method) } - validM.put(methodName, candidates) - } - - subclassCache.clear() - - cpg.call.filter(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH).foreach { call => - try { - linkDynamicCall(call, dstGraph) - } catch { - case exception: Exception => - throw new RuntimeException(exception) - } - } - } - - /** Recursively returns all the sub-types of the given type declaration. Does account for circular hierarchies. 
- */ - private def allSubclasses(typDeclFullName: String): mutable.LinkedHashSet[String] = - inheritanceTraversal(typDeclFullName, subclassCache, inSuperDirection = false) - - /** Recursively returns all the super-types of the given type declaration. Does account for circular hierarchies. - */ - private def allSuperClasses(typDeclFullName: String): mutable.LinkedHashSet[String] = - inheritanceTraversal(typDeclFullName, superclassCache, inSuperDirection = true) - - private def inheritanceTraversal( - typDeclFullName: String, - cache: mutable.Map[String, mutable.LinkedHashSet[String]], - inSuperDirection: Boolean - ): mutable.LinkedHashSet[String] = { - cache.get(typDeclFullName) match { - case Some(superClasses) => superClasses - case None => - val totalSuperclasses = (cpg.typeDecl - .fullNameExact(typDeclFullName) - .headOption match { - case Some(curr) => inheritTraversal(curr, inSuperDirection) - case None => mutable.LinkedHashSet.empty - }).map(_.fullName) - cache.put(typDeclFullName, totalSuperclasses) - totalSuperclasses - } - } - - private def inheritTraversal( - cur: TypeDecl, - inSuperDirection: Boolean, - visitedNodes: mutable.LinkedHashSet[TypeDecl] = mutable.LinkedHashSet.empty - ): mutable.LinkedHashSet[TypeDecl] = { - if (visitedNodes.contains(cur)) return visitedNodes - visitedNodes.addOne(cur) - - (if (inSuperDirection) cpg.typeDecl.fullNameExact(cur.fullName).flatMap(_.inheritsFromOut.referencedTypeDecl) - else cpg.typ.fullNameExact(cur.fullName).flatMap(_.inheritsFromIn)) - .collectAll[TypeDecl] - .to(mutable.LinkedHashSet) match { - case classesToEval if classesToEval.isEmpty => visitedNodes - case classesToEval => - classesToEval.flatMap(t => inheritTraversal(t, inSuperDirection, visitedNodes)) - visitedNodes - } - } - - /** Returns the method from a sub-class implementing a method for the given subclass. 
- */ - private def staticLookup(subclass: String, method: Method): Option[String] = { - typeMap.get(subclass) match { - case Some(sc) => - sc._methodViaAstOut - .nameExact(method.name) - .and(_.signatureExact(method.signature)) - .map(_.fullName) - .headOption - case None => None - } - } - - private def resolveCallInSuperClasses(call: Call): Boolean = { - if (!call.methodFullName.contains(":")) return false - def split(str: String, n: Int) = (str.take(n), str.drop(n + 1)) - val (fullName, signature) = split(call.methodFullName, call.methodFullName.lastIndexOf(":")) - val typeDeclFullName = fullName.replace(s".${call.name}", "") - val candidateInheritedMethods = - cpg.typeDecl - .fullNameExact(allSuperClasses(typeDeclFullName).toIndexedSeq: _*) - .astChildren - .isMethod - .name(call.name) - .and(_.signatureExact(signature)) - .fullName - .l - if (candidateInheritedMethods.nonEmpty) { - validM.put( - call.methodFullName, - validM.getOrElse(call.methodFullName, mutable.LinkedHashSet.empty) ++ mutable.LinkedHashSet.from( - candidateInheritedMethods +class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg): + + import DynamicCallLinker.* + // Used to track potential method candidates for a given method fullname. Since our method full names contain the type + // decl we don't need to specify an addition map to wrap this in. LinkedHashSets are used here to preserve order in + // the best interest of reproducibility during debugging. 
+ private val validM = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + // Used for dynamic programming as subtree's don't need to be recalculated later + private val subclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + private val superclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + // Used for O(1) lookups on methods that will work without indexManager + private val typeMap = mutable.Map.empty[String, TypeDecl] + // For linking loose method stubs that cannot be resolved by crawling parent types + private val methodMap = mutable.Map.empty[String, Method] + + private def initMaps(): Unit = + cpg.typeDecl.foreach { typeDecl => + typeMap += (typeDecl.fullName -> typeDecl) + } + cpg.method + .filter(m => !m.name.startsWith("")) + .foreach { method => methodMap += (method.fullName -> method) } + + /** Main method of enhancement - to be implemented by child class + */ + override def run(dstGraph: DiffGraphBuilder): Unit = + // Perform early stopping in the case of no virtual calls + if !cpg.call.exists(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH) then + return + initMaps() + // ValidM maps class C and method name N to the set of + // func ptrs implementing N for C and its subclasses + for + typeDecl <- cpg.typeDecl; + method <- typeDecl._methodViaAstOut + do + val methodName = method.fullName + val candidates = allSubclasses(typeDecl.fullName).flatMap { staticLookup(_, method) } + validM.put(methodName, candidates) + + subclassCache.clear() + + cpg.call.filter(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH).foreach { call => + try + linkDynamicCall(call, dstGraph) + catch + case exception: Exception => + throw new RuntimeException(exception) + } + end run + + /** Recursively returns all the sub-types of the given type declaration. Does account for + * circular hierarchies. 
+ */ + private def allSubclasses(typDeclFullName: String): mutable.LinkedHashSet[String] = + inheritanceTraversal(typDeclFullName, subclassCache, inSuperDirection = false) + + /** Recursively returns all the super-types of the given type declaration. Does account for + * circular hierarchies. + */ + private def allSuperClasses(typDeclFullName: String): mutable.LinkedHashSet[String] = + inheritanceTraversal(typDeclFullName, superclassCache, inSuperDirection = true) + + private def inheritanceTraversal( + typDeclFullName: String, + cache: mutable.Map[String, mutable.LinkedHashSet[String]], + inSuperDirection: Boolean + ): mutable.LinkedHashSet[String] = + cache.get(typDeclFullName) match + case Some(superClasses) => superClasses + case None => + val totalSuperclasses = (cpg.typeDecl + .fullNameExact(typDeclFullName) + .headOption match + case Some(curr) => inheritTraversal(curr, inSuperDirection) + case None => mutable.LinkedHashSet.empty + ).map(_.fullName) + cache.put(typDeclFullName, totalSuperclasses) + totalSuperclasses + + private def inheritTraversal( + cur: TypeDecl, + inSuperDirection: Boolean, + visitedNodes: mutable.LinkedHashSet[TypeDecl] = mutable.LinkedHashSet.empty + ): mutable.LinkedHashSet[TypeDecl] = + if visitedNodes.contains(cur) then return visitedNodes + visitedNodes.addOne(cur) + + (if inSuperDirection then + cpg.typeDecl.fullNameExact(cur.fullName).flatMap(_.inheritsFromOut.referencedTypeDecl) + else cpg.typ.fullNameExact(cur.fullName).flatMap(_.inheritsFromIn)) + .collectAll[TypeDecl] + .to(mutable.LinkedHashSet) match + case classesToEval if classesToEval.isEmpty => visitedNodes + case classesToEval => + classesToEval.flatMap(t => inheritTraversal(t, inSuperDirection, visitedNodes)) + visitedNodes + + /** Returns the method from a sub-class implementing a method for the given subclass. 
+ */ + private def staticLookup(subclass: String, method: Method): Option[String] = + typeMap.get(subclass) match + case Some(sc) => + sc._methodViaAstOut + .nameExact(method.name) + .and(_.signatureExact(method.signature)) + .map(_.fullName) + .headOption + case None => None + + private def resolveCallInSuperClasses(call: Call): Boolean = + if !call.methodFullName.contains(":") then return false + def split(str: String, n: Int) = (str.take(n), str.drop(n + 1)) + val (fullName, signature) = split(call.methodFullName, call.methodFullName.lastIndexOf(":")) + val typeDeclFullName = fullName.replace(s".${call.name}", "") + val candidateInheritedMethods = + cpg.typeDecl + .fullNameExact(allSuperClasses(typeDeclFullName).toIndexedSeq*) + .astChildren + .isMethod + .name(call.name) + .and(_.signatureExact(signature)) + .fullName + .l + if candidateInheritedMethods.nonEmpty then + validM.put( + call.methodFullName, + validM.getOrElse( + call.methodFullName, + mutable.LinkedHashSet.empty + ) ++ mutable.LinkedHashSet.from( + candidateInheritedMethods + ) + ) + true + else + false + end resolveCallInSuperClasses + + private def linkDynamicCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + // This call linker requires a method full name entry + if call.methodFullName.equals("") || call.methodFullName.equals( + DynamicCallUnknownFullName + ) + then return + // Support for overriding + resolveCallInSuperClasses(call) + + validM.get(call.methodFullName) match + case Some(tgts) => + val callsOut = call.callOut.fullName.toSetImmutable + val tgtMs = tgts + .flatMap(destMethod => + if cpg.graph.indexManager.isIndexed(PropertyNames.FULL_NAME) then + methodFullNameToNode(destMethod) + else + cpg.method.fullNameExact(destMethod).headOption + ) + .toSet + // Non-overridden methods linked as external stubs should be excluded if they are detected + val (externalMs, internalMs) = tgtMs.partition(_.isExternal) + (if externalMs.nonEmpty && internalMs.nonEmpty then internalMs else tgtMs) + 
.foreach { tgtM => + if !callsOut.contains(tgtM.fullName) then + dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) + else + fallbackToStaticResolution(call, dstGraph) + } + case None => + fallbackToStaticResolution(call, dstGraph) + end match + end linkDynamicCall + + /** In the case where the method isn't an internal method and cannot be resolved by crawling + * TYPE_DECL nodes it can be resolved from the map of external methods. + */ + private def fallbackToStaticResolution(call: Call, dstGraph: DiffGraphBuilder): Unit = + methodMap.get(call.methodFullName) match + case Some(tgtM) => dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) + case None => printLinkingError(call) + + private def nodesWithFullName(x: String): Iterable[NodeRef[? <: NodeDb]] = + cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + + private def methodFullNameToNode(x: String): Option[Method] = + nodesWithFullName(x).collectFirst { case x: Method => x } + + @inline + private def printLinkingError(call: Call): Unit = + logger.debug( + s"Unable to link dynamic CALL with METHOD_FULL_NAME ${call.methodFullName} and context: " + + s"${call.code} @ line ${call.lineNumber}" ) - ) - true - } else { - false - } - } - - private def linkDynamicCall(call: Call, dstGraph: DiffGraphBuilder): Unit = { - // This call linker requires a method full name entry - if (call.methodFullName.equals("") || call.methodFullName.equals(DynamicCallUnknownFullName)) return - // Support for overriding - resolveCallInSuperClasses(call) - - validM.get(call.methodFullName) match { - case Some(tgts) => - val callsOut = call.callOut.fullName.toSetImmutable - val tgtMs = tgts - .flatMap(destMethod => - if (cpg.graph.indexManager.isIndexed(PropertyNames.FULL_NAME)) { - methodFullNameToNode(destMethod) - } else { - cpg.method.fullNameExact(destMethod).headOption - } - ) - .toSet - // Non-overridden methods linked as external stubs should be excluded if they are detected - val (externalMs, internalMs) = 
tgtMs.partition(_.isExternal) - (if (externalMs.nonEmpty && internalMs.nonEmpty) internalMs else tgtMs) - .foreach { tgtM => - if (!callsOut.contains(tgtM.fullName)) { - dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) - } else { - fallbackToStaticResolution(call, dstGraph) - } - } - case None => - fallbackToStaticResolution(call, dstGraph) - } - } - - /** In the case where the method isn't an internal method and cannot be resolved by crawling TYPE_DECL nodes it can be - * resolved from the map of external methods. - */ - private def fallbackToStaticResolution(call: Call, dstGraph: DiffGraphBuilder): Unit = { - methodMap.get(call.methodFullName) match { - case Some(tgtM) => dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) - case None => printLinkingError(call) - } - } - - private def nodesWithFullName(x: String): Iterable[NodeRef[_ <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala - - private def methodFullNameToNode(x: String): Option[Method] = - nodesWithFullName(x).collectFirst { case x: Method => x } - - @inline - private def printLinkingError(call: Call): Unit = { - logger.debug( - s"Unable to link dynamic CALL with METHOD_FULL_NAME ${call.methodFullName} and context: " + - s"${call.code} @ line ${call.lineNumber}" - ) - } -} - -object DynamicCallLinker { - private val logger: Logger = LoggerFactory.getLogger(classOf[DynamicCallLinker]) -} +end DynamicCallLinker + +object DynamicCallLinker: + private val logger: Logger = LoggerFactory.getLogger(classOf[DynamicCallLinker]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala index 5cbc3b97..64b32a42 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala @@ -1,27 +1,24 @@ 
package io.appthreat.x2cpg.passes.callgraph import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated._ +import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.passes.CpgPass import io.appthreat.x2cpg.utils.LinkingUtil -/** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends which do not provide - * method stubs and type decl stubs. +/** This pass has MethodStubCreator and TypeDeclStubCreator as prerequisite for language frontends + * which do not provide method stubs and type decl stubs. */ -class MethodRefLinker(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { +class MethodRefLinker(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = { - // Create REF edges from METHOD_REFs to METHOD - linkToSingle( - cpg, - srcLabels = List(NodeTypes.METHOD_REF), - dstNodeLabel = NodeTypes.METHOD, - edgeType = EdgeTypes.REF, - dstNodeMap = methodFullNameToNode(cpg, _), - dstFullNameKey = PropertyNames.METHOD_FULL_NAME, - dstGraph, - None - ) - } - -} + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create REF edges from METHOD_REFs to METHOD + linkToSingle( + cpg, + srcLabels = List(NodeTypes.METHOD_REF), + dstNodeLabel = NodeTypes.METHOD, + edgeType = EdgeTypes.REF, + dstNodeMap = methodFullNameToNode(cpg, _), + dstFullNameKey = PropertyNames.METHOD_FULL_NAME, + dstGraph, + None + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala index 58b233a1..4fca6e32 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala @@ -3,28 +3,24 @@ package io.appthreat.x2cpg.passes.callgraph import 
io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.jIteratortoTraversal /** Link remaining unlinked calls to methods only by their name (not full name) * @param cpg * the target code property graph. */ -class NaiveCallLinker(cpg: Cpg) extends CpgPass(cpg) { +class NaiveCallLinker(cpg: Cpg) extends CpgPass(cpg): - override def run(dstGraph: DiffGraphBuilder): Unit = { - val methodNameToNode = cpg.method.toList.groupBy(_.name) - def calls = cpg.call.filter(_.outE(EdgeTypes.CALL).isEmpty) - for { - call <- calls - methods <- methodNameToNode.get(call.name) - method <- methods - } { - dstGraph.addEdge(call, method, EdgeTypes.CALL) - // If we can only find one name with the exact match then we can semi-confidently set it as the full name - if (methods.sizeIs == 1) - dstGraph.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, method.fullName) - } - } - -} + override def run(dstGraph: DiffGraphBuilder): Unit = + val methodNameToNode = cpg.method.toList.groupBy(_.name) + def calls = cpg.call.filter(_.outE(EdgeTypes.CALL).isEmpty) + for + call <- calls + methods <- methodNameToNode.get(call.name) + method <- methods + do + dstGraph.addEdge(call, method, EdgeTypes.CALL) + // If we can only find one name with the exact match then we can semi-confidently set it as the full name + if methods.sizeIs == 1 then + dstGraph.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, method.fullName) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala index 175e73b8..36b93da0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala +++ 
b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala @@ -1,64 +1,56 @@ package io.appthreat.x2cpg.passes.callgraph import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable -class StaticCallLinker(cpg: Cpg) extends CpgPass(cpg) { - - import StaticCallLinker._ - private val methodFullNameToNode = mutable.Map.empty[String, List[Method]] - - override def run(dstGraph: DiffGraphBuilder): Unit = { - - cpg.method.foreach { method => - methodFullNameToNode.updateWith(method.fullName) { - case Some(l) => Some(method :: l) - case None => Some(List(method)) - } - } - - cpg.call.foreach { call => - try { - linkCall(call, dstGraph) - } catch { - case exception: Exception => - throw new RuntimeException(exception) - } - } - } - - private def linkCall(call: Call, dstGraph: DiffGraphBuilder): Unit = { - call.dispatchType match { - case DispatchTypes.STATIC_DISPATCH | DispatchTypes.INLINED => - linkStaticCall(call, dstGraph) - case DispatchTypes.DYNAMIC_DISPATCH => - // Do nothing - case _ => logger.debug(s"Unknown dispatch type on dynamic CALL ${call.code}") - } - } - - private def linkStaticCall(call: Call, dstGraph: DiffGraphBuilder): Unit = { - val resolvedMethodOption = methodFullNameToNode.get(call.methodFullName) - if (resolvedMethodOption.isDefined) { - resolvedMethodOption.get.foreach { dst => - dstGraph.addEdge(call, dst, EdgeTypes.CALL) - } - } else { - logger.debug( - s"Unable to link static CALL with METHOD_FULL_NAME ${call.methodFullName}, NAME ${call.name}, " + - s"SIGNATURE ${call.signature}, CODE ${call.code}" - ) - } - } - -} - -object StaticCallLinker { 
- private val logger: Logger = LoggerFactory.getLogger(classOf[StaticCallLinker]) -} +class StaticCallLinker(cpg: Cpg) extends CpgPass(cpg): + + import StaticCallLinker.* + private val methodFullNameToNode = mutable.Map.empty[String, List[Method]] + + override def run(dstGraph: DiffGraphBuilder): Unit = + + cpg.method.foreach { method => + methodFullNameToNode.updateWith(method.fullName) { + case Some(l) => Some(method :: l) + case None => Some(List(method)) + } + } + + cpg.call.foreach { call => + try + linkCall(call, dstGraph) + catch + case exception: Exception => + throw new RuntimeException(exception) + } + + private def linkCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + call.dispatchType match + case DispatchTypes.STATIC_DISPATCH | DispatchTypes.INLINED => + linkStaticCall(call, dstGraph) + case DispatchTypes.DYNAMIC_DISPATCH => + // Do nothing + case _ => logger.debug(s"Unknown dispatch type on dynamic CALL ${call.code}") + + private def linkStaticCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + val resolvedMethodOption = methodFullNameToNode.get(call.methodFullName) + if resolvedMethodOption.isDefined then + resolvedMethodOption.get.foreach { dst => + dstGraph.addEdge(call, dst, EdgeTypes.CALL) + } + else + logger.debug( + s"Unable to link static CALL with METHOD_FULL_NAME ${call.methodFullName}, NAME ${call.name}, " + + s"SIGNATURE ${call.signature}, CODE ${call.code}" + ) +end StaticCallLinker + +object StaticCallLinker: + private val logger: Logger = LoggerFactory.getLogger(classOf[StaticCallLinker]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala index 6e6f5123..9ba4af18 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala +++ 
b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala @@ -3,7 +3,7 @@ package io.appthreat.x2cpg.passes.controlflow import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.passes.ConcurrentWriterCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator /** A pass that creates control flow graphs from abstract syntax trees. @@ -11,17 +11,15 @@ import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator * Control flow graphs can be calculated independently per method. Therefore, we inherit from * `ConcurrentWriterCpgPass`. * - * Note: the version of OverflowDB that we currently use as a storage backend does not assign ids to edges and this - * pass only creates edges at the moment. Therefore, we currently do without key pools. + * Note: the version of OverflowDB that we currently use as a storage backend does not assign ids + * to edges and this pass only creates edges at the moment. Therefore, we currently do without key + * pools. 
*/ -class CfgCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Method](cpg) { +class CfgCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Method](cpg): - override def generateParts(): Array[Method] = cpg.method.toArray + override def generateParts(): Array[Method] = cpg.method.toArray - override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = { - val localDiff = new DiffGraphBuilder - new CfgCreator(method, localDiff).run() - diffGraph.absorb(localDiff) - } - -} + override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = + val localDiff = new DiffGraphBuilder + new CfgCreator(method, localDiff).run() + diffGraph.absorb(localDiff) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala index 3ad493eb..4f9065db 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala @@ -7,27 +7,27 @@ import org.slf4j.LoggerFactory /** A control flow graph that is under construction, consisting of: * * @param entryNode - * the control flow graph's first node, that is, the node to which a CFG that appends this CFG should attach itself - * to. + * the control flow graph's first node, that is, the node to which a CFG that appends this CFG + * should attach itself to. * @param edges * control flow edges between nodes of the code property graph. * @param fringe - * nodes of the CFG for which an outgoing edge type is already known but the destination node is not. These nodes are - * connected when another CFG is appended to this CFG. + * nodes of the CFG for which an outgoing edge type is already known but the destination node is + * not. These nodes are connected when another CFG is appended to this CFG. 
* - * In addition to these three core building blocks, we store labels and jump statements that have not been resolved and - * may be resolvable as parent sub trees or sibblings are translated. + * In addition to these three core building blocks, we store labels and jump statements that have + * not been resolved and may be resolvable as parent sub trees or sibblings are translated. * * @param labeledNodes * labels contained in the abstract syntax tree from which this CPG was generated * @param caseLabels * labels beginning with "case" * @param breaks - * unresolved breaks collected along the way together with an integer value which indicates the number of loop/switch - * levels to break + * unresolved breaks collected along the way together with an integer value which indicates the + * number of loop/switch levels to break * @param continues - * unresolved continues collected along the way together with an integer value which indicates the number of - * loop/switch levels after which to continue + * unresolved continues collected along the way together with an integer value which indicates + * the number of loop/switch levels after which to continue * @param jumpsToLabel * unresolved gotos, labeled break and labeld continues collected along the way */ @@ -40,158 +40,154 @@ case class Cfg( continues: List[(CfgNode, Int)] = List(), caseLabels: List[CfgNode] = List(), jumpsToLabel: List[(CfgNode, String)] = List() -) { - - import Cfg._ - - /** Create a new CFG in which `other` is appended to this CFG. All nodes of the fringe are connected to `other`'s - * entry node and the new fringe is `other`'s fringe. The diffgraphs, jumps, and labels are the sum of those present - * in `this` and `other`. 
- */ - def ++(other: Cfg): Cfg = { - if (other == Cfg.empty) { - this - } else if (this == Cfg.empty) { - other - } else { - this.copy( - fringe = other.fringe, - edges = this.edges ++ other.edges ++ - edgesFromFringeTo(this, other.entryNode), - jumpsToLabel = this.jumpsToLabel ++ other.jumpsToLabel, - labeledNodes = this.labeledNodes ++ other.labeledNodes, - breaks = this.breaks ++ other.breaks, - continues = this.continues ++ other.continues, - caseLabels = this.caseLabels ++ other.caseLabels - ) - } - } - - def withFringeEdgeType(cfgEdgeType: CfgEdgeType): Cfg = { - this.copy(fringe = fringe.map { case (x, _) => (x, cfgEdgeType) }) - } - - /** Upon completing traversal of the abstract syntax tree, this method creates CFG edges between jumps like gotos, - * labeled breaks, labeled continues and respective labels. - */ - def withResolvedJumpToLabel(): Cfg = { - val edges = jumpsToLabel.flatMap { - case (jumpToLabel, label) if label != "*" => - labeledNodes.get(label) match { - case Some(labeledNode) => - // TODO set edge type of Always once the backend - // supports it - Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) - case None => - logger.debug("Unable to wire jump statement. Missing label {}.", label) - None +): + + import Cfg.* + + /** Create a new CFG in which `other` is appended to this CFG. All nodes of the fringe are + * connected to `other`'s entry node and the new fringe is `other`'s fringe. The diffgraphs, + * jumps, and labels are the sum of those present in `this` and `other`. 
+ */ + def ++(other: Cfg): Cfg = + if other == Cfg.empty then + this + else if this == Cfg.empty then + other + else + this.copy( + fringe = other.fringe, + edges = this.edges ++ other.edges ++ + edgesFromFringeTo(this, other.entryNode), + jumpsToLabel = this.jumpsToLabel ++ other.jumpsToLabel, + labeledNodes = this.labeledNodes ++ other.labeledNodes, + breaks = this.breaks ++ other.breaks, + continues = this.continues ++ other.continues, + caseLabels = this.caseLabels ++ other.caseLabels + ) + + def withFringeEdgeType(cfgEdgeType: CfgEdgeType): Cfg = + this.copy(fringe = fringe.map { case (x, _) => (x, cfgEdgeType) }) + + /** Upon completing traversal of the abstract syntax tree, this method creates CFG edges between + * jumps like gotos, labeled breaks, labeled continues and respective labels. + */ + def withResolvedJumpToLabel(): Cfg = + val edges = jumpsToLabel.flatMap { + case (jumpToLabel, label) if label != "*" => + labeledNodes.get(label) match + case Some(labeledNode) => + // TODO set edge type of Always once the backend + // supports it + Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) + case None => + logger.debug("Unable to wire jump statement. Missing label {}.", label) + None + case (jumpToLabel, _) => + // We come here for: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html + // For such GOTOs we cannot statically determine the target label. As a quick + // hack we simply put edges to all labels found. This might be an over-taint. + labeledNodes.flatMap { case (_, labeledNode) => + Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) + } } - case (jumpToLabel, _) => - // We come here for: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html - // For such GOTOs we cannot statically determine the target label. As a quick - // hack we simply put edges to all labels found. This might be an over-taint. 
- labeledNodes.flatMap { case (_, labeledNode) => - Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) + this.copy(edges = this.edges ++ edges) + end withResolvedJumpToLabel +end Cfg + +case class CfgEdge(src: CfgNode, dst: CfgNode, edgeType: CfgEdgeType) + +object Cfg: + + private val logger = LoggerFactory.getLogger(getClass) + + def from(cfgs: Cfg*): Cfg = + Cfg( + jumpsToLabel = cfgs.map(_.jumpsToLabel).reduceOption((x, y) => x ++ y).getOrElse(List()), + breaks = cfgs.map(_.breaks).reduceOption((x, y) => x ++ y).getOrElse(List()), + continues = cfgs.map(_.continues).reduceOption((x, y) => x ++ y).getOrElse(List()), + caseLabels = cfgs.map(_.caseLabels).reduceOption((x, y) => x ++ y).getOrElse(List()), + labeledNodes = cfgs.map(_.labeledNodes).reduceOption((x, y) => x ++ y).getOrElse(Map()) + ) + + /** The safe "null" Cfg. + */ + val empty: Cfg = new Cfg() + + trait CfgEdgeType + object TrueEdge extends CfgEdgeType: + override def toString: String = "TrueEdge" + object FalseEdge extends CfgEdgeType: + override def toString: String = "FalseEdge" + object AlwaysEdge extends CfgEdgeType: + override def toString: String = "AlwaysEdge" + object CaseEdge extends CfgEdgeType: + override def toString: String = "CaseEdge" + + /** Create edges from all nodes of cfg's fringe to `node`. + */ + def edgesFromFringeTo(cfg: Cfg, node: Option[CfgNode]): List[CfgEdge] = + edgesFromFringeTo(cfg.fringe, node) + + /** Create edges from all nodes of cfg's fringe to `node`, ignoring fringe edge types and using + * `cfgEdgeType` instead. 
+ */ + def edgesFromFringeTo( + cfg: Cfg, + node: Option[CfgNode], + cfgEdgeType: CfgEdgeType + ): List[CfgEdge] = + edges(cfg.fringe.map(_._1), node, cfgEdgeType) + + /** Create edges from a list (node, cfgEdgeType) pairs to `node` + */ + def edgesFromFringeTo( + fringeElems: List[(CfgNode, CfgEdgeType)], + node: Option[CfgNode] + ): List[CfgEdge] = + fringeElems.flatMap { case (sourceNode, cfgEdgeType) => + node.map { dstNode => + CfgEdge(sourceNode, dstNode, cfgEdgeType) + } } - } - this.copy(edges = this.edges ++ edges) - } -} + /** Create edges of given type from a list of source nodes to a destination node + */ + def edges( + sources: List[CfgNode], + dstNode: Option[CfgNode], + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + edgesToMultiple(sources, dstNode.toList, cfgEdgeType) + + def singleEdge( + source: CfgNode, + destination: CfgNode, + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + edgesToMultiple(List(source), List(destination), cfgEdgeType) + + /** Create edges of given type from all nodes in `sources` to `node`. 
+ */ + def edgesToMultiple( + sources: List[CfgNode], + destinations: List[CfgNode], + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + sources.flatMap { l => + destinations.map { n => + CfgEdge(l, n, cfgEdgeType) + } + } -case class CfgEdge(src: CfgNode, dst: CfgNode, edgeType: CfgEdgeType) + def takeCurrentLevel(nodesWithLevel: List[(CfgNode, Int)]): List[CfgNode] = + nodesWithLevel.collect { + case (node, level) if level == 1 => + node + } -object Cfg { - - private val logger = LoggerFactory.getLogger(getClass) - - def from(cfgs: Cfg*): Cfg = { - Cfg( - jumpsToLabel = cfgs.map(_.jumpsToLabel).reduceOption((x, y) => x ++ y).getOrElse(List()), - breaks = cfgs.map(_.breaks).reduceOption((x, y) => x ++ y).getOrElse(List()), - continues = cfgs.map(_.continues).reduceOption((x, y) => x ++ y).getOrElse(List()), - caseLabels = cfgs.map(_.caseLabels).reduceOption((x, y) => x ++ y).getOrElse(List()), - labeledNodes = cfgs.map(_.labeledNodes).reduceOption((x, y) => x ++ y).getOrElse(Map()) - ) - } - - /** The safe "null" Cfg. - */ - val empty: Cfg = new Cfg() - - trait CfgEdgeType - object TrueEdge extends CfgEdgeType { - override def toString: String = "TrueEdge" - } - object FalseEdge extends CfgEdgeType { - override def toString: String = "FalseEdge" - } - object AlwaysEdge extends CfgEdgeType { - override def toString: String = "AlwaysEdge" - } - object CaseEdge extends CfgEdgeType { - override def toString: String = "CaseEdge" - } - - /** Create edges from all nodes of cfg's fringe to `node`. - */ - def edgesFromFringeTo(cfg: Cfg, node: Option[CfgNode]): List[CfgEdge] = { - edgesFromFringeTo(cfg.fringe, node) - } - - /** Create edges from all nodes of cfg's fringe to `node`, ignoring fringe edge types and using `cfgEdgeType` instead. 
- */ - def edgesFromFringeTo(cfg: Cfg, node: Option[CfgNode], cfgEdgeType: CfgEdgeType): List[CfgEdge] = { - edges(cfg.fringe.map(_._1), node, cfgEdgeType) - } - - /** Create edges from a list (node, cfgEdgeType) pairs to `node` - */ - def edgesFromFringeTo(fringeElems: List[(CfgNode, CfgEdgeType)], node: Option[CfgNode]): List[CfgEdge] = { - fringeElems.flatMap { case (sourceNode, cfgEdgeType) => - node.map { dstNode => - CfgEdge(sourceNode, dstNode, cfgEdgeType) - } - } - } - - /** Create edges of given type from a list of source nodes to a destination node - */ - def edges(sources: List[CfgNode], dstNode: Option[CfgNode], cfgEdgeType: CfgEdgeType = AlwaysEdge): List[CfgEdge] = { - edgesToMultiple(sources, dstNode.toList, cfgEdgeType) - } - - def singleEdge(source: CfgNode, destination: CfgNode, cfgEdgeType: CfgEdgeType = AlwaysEdge): List[CfgEdge] = { - edgesToMultiple(List(source), List(destination), cfgEdgeType) - } - - /** Create edges of given type from all nodes in `sources` to `node`. 
- */ - def edgesToMultiple( - sources: List[CfgNode], - destinations: List[CfgNode], - cfgEdgeType: CfgEdgeType = AlwaysEdge - ): List[CfgEdge] = { - - sources.flatMap { l => - destinations.map { n => - CfgEdge(l, n, cfgEdgeType) - } - } - } - - def takeCurrentLevel(nodesWithLevel: List[(CfgNode, Int)]): List[CfgNode] = { - nodesWithLevel.collect { - case (node, level) if level == 1 => - node - } - } - - def reduceAndFilterLevel(nodesWithLevel: List[(CfgNode, Int)]): List[(CfgNode, Int)] = { - nodesWithLevel.collect { - case (node, level) if level != 1 => - (node, level - 1) - } - } - -} + def reduceAndFilterLevel(nodesWithLevel: List[(CfgNode, Int)]): List[(CfgNode, Int)] = + nodesWithLevel.collect { + case (node, level) if level != 1 => + (node, level - 1) + } +end Cfg diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala index 0f964222..7f825739 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala @@ -2,16 +2,21 @@ package io.appthreat.x2cpg.passes.controlflow.cfgcreation import io.appthreat.x2cpg.passes.controlflow.cfgcreation.Cfg.CfgEdgeType import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, EdgeTypes, Operators} +import io.shiftleft.codepropertygraph.generated.{ + ControlStructureTypes, + DispatchTypes, + EdgeTypes, + Operators +} import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate.DiffGraphBuilder /** Translation of abstract syntax trees into control flow graphs * - * The problem of translating an abstract syntax tree into a corresponding control flow graph can be formulated as a - * recursive 
problem in which sub trees of the syntax tree are translated and their corresponding control flow graphs - * are connected according to the control flow semantics of the root node. For example, consider the abstract syntax - * tree for an if-statement: + * The problem of translating an abstract syntax tree into a corresponding control flow graph can + * be formulated as a recursive problem in which sub trees of the syntax tree are translated and + * their corresponding control flow graphs are connected according to the control flow semantics of + * the root node. For example, consider the abstract syntax tree for an if-statement: * {{{ * ( if ) * / \ @@ -19,8 +24,9 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder * / \ / \ * x 10 x 1 * }}} - * This tree can be translated into a control flow graph, by translating the sub tree rooted in `x < 10` and that of - * `x+= 1` and connecting their control flow graphs according to the semantics of `if`: + * This tree can be translated into a control flow graph, by translating the sub tree rooted in `x + * < 10` and that of `x+= 1` and connecting their control flow graphs according to the semantics of + * `if`: * {{{ * [x < 10]---- * |t f| @@ -28,582 +34,591 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder * | * }}} * - * The semantics of if dictate that the first sub tree to the left is a condition, which is connected to the CFG of the - * second sub tree - the body of the if statement - via a control flow edge with the `true` label (indicated in the - * illustration by `t`), and to the CFG of any follow-up code via a `false` edge (indicated by `f`). + * The semantics of if dictate that the first sub tree to the left is a condition, which is + * connected to the CFG of the second sub tree - the body of the if statement - via a control flow + * edge with the `true` label (indicated in the illustration by `t`), and to the CFG of any + * follow-up code via a `false` edge (indicated by `f`). 
* - * A problem that becomes immediately apparent in the illustration is that the result of translating a sub tree may - * leave us with edges for which a source node is known but the destination node depends on parents or siblings that - * were not considered in the translation. For example, we know that an outgoing edge from [x<10] must exist, but we do - * not yet know where it should lead. We refer to the set of nodes of the control flow graph with outgoing edges for + * A problem that becomes immediately apparent in the illustration is that the result of + * translating a sub tree may leave us with edges for which a source node is known but the + * destination node depends on parents or siblings that were not considered in the translation. For + * example, we know that an outgoing edge from [x<10] must exist, but we do not yet know where it + * should lead. We refer to the set of nodes of the control flow graph with outgoing edges for * which the destination node is yet to be determined as the "fringe" of the control flow graph. */ -class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder) { - - import io.appthreat.x2cpg.passes.controlflow.cfgcreation.Cfg.* - import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator.* - - /** Control flow graph definitions often feature a designated entry and exit node for each method. While these nodes - * are no-ops from a computational point of view, they are useful to guarantee that a method has exactly one entry - * and one exit. - * - * For the CPG-based control flow graph, we do not need to introduce fake entry and exit node. Instead, we can use - * the METHOD and METHOD_RETURN nodes as entry and exit nodes respectively. Note that METHOD_RETURN nodes are the - * nodes representing formal return parameters, of which there exists exactly one per method. 
- */ - private val exitNode: MethodReturn = entryNode.methodReturn - - /** We return the CFG as a sequence of Diff Graphs that is calculated by first obtaining the CFG for the method and - * then resolving gotos. - */ - def run(): Unit = { - cfgForMethod(entryNode).withResolvedJumpToLabel().edges.foreach { edge => - // TODO: we are ignoring edge.edgeType because the - // CFG spec doesn't define an edge type at the moment - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CFG) - } - } - - /** Conversion of a method to a CFG, showing the decomposition of the control flow graph generation problem into that - * of translating sub trees according to the node type. In the particular case of a method, the CFG is obtained by - * creating a CFG containing the single method node and a fringe containing the node and an outgoing AlwaysEdge, to - * the CFG obtained by translating child CFGs one by one and appending them. - */ - private def cfgForMethod(node: Method): Cfg = - cfgForSingleNode(node) ++ cfgForChildren(node) - - /** For any single AST node, we can construct a CFG containing that single node by setting it as the entry node and - * placing it in the fringe. - */ - private def cfgForSingleNode(node: CfgNode): Cfg = - Cfg(entryNode = Option(node), fringe = List((node, AlwaysEdge))) - - /** The CFG for all children is obtained by translating child ASTs one by one from left to right and appending them. - */ - private def cfgForChildren(node: AstNode): Cfg = - node.astChildren.l.map(cfgFor).reduceOption((x, y) => x ++ y).getOrElse(Cfg.empty) - - /** Returns true if this node is a child to some `try` control structure, false if otherwise. - */ - private def withinATryBlock(x: AstNode): Boolean = - x.inAst.isControlStructure.exists(_.controlStructureType == ControlStructureTypes.TRY) - - /** This method dispatches AST nodes by type and calls corresponding conversion methods. 
- */ - protected def cfgFor(node: AstNode): Cfg = - node match { - case _: Method | _: MethodParameterIn | _: Modifier | _: Local | _: TypeDecl | _: Member => - Cfg.empty - case _: MethodRef | _: TypeRef | _: MethodReturn => - cfgForSingleNode(node.asInstanceOf[CfgNode]) - case controlStructure: ControlStructure => - cfgForControlStructure(controlStructure) - case jumpTarget: JumpTarget => - cfgForJumpTarget(jumpTarget) - case ret: Return if withinATryBlock(ret) => - cfgForReturn(ret, inheritFringe = true) - case ret: Return => - cfgForReturn(ret) - case call: Call if call.name == Operators.logicalAnd => - cfgForAndExpression(call) - case call: Call if call.name == Operators.logicalOr => - cfgForOrExpression(call) - case call: Call if call.name == Operators.conditional => - cfgForConditionalExpression(call) - case call: Call if call.dispatchType == DispatchTypes.INLINED => - cfgForInlinedCall(call) - case block: Block if blockMatches(block) => - cfgForChildren(block) - case _: Block => - cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) - case _: Call | _: FieldIdentifier | _: Identifier | _: Literal | _: Block | _: Unknown => - cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) - case _ => - cfgForChildren(node) - } - - private def isLogicalOperator(node: AstNode): Boolean = node match { - case call: Call => - call.name == Operators.conditional || call.name == Operators.logicalOr || call.name == Operators.logicalAnd - case _ => false - } - - private def isInlinedCall(node: AstNode): Boolean = node match { - case call: Call => call.dispatchType == DispatchTypes.INLINED - case _ => false - } - - /** Only include block nodes that do not describe the entire method body or the bodies of control structures or - * inlined calls or logical operators. 
- */ - private def blockMatches(block: Block): Boolean = { - if (block._astIn.hasNext) { - val parentNode = block.astParent - parentNode.isMethod || parentNode.isControlStructure || isLogicalOperator(parentNode) || isInlinedCall(parentNode) - } else false - } - - /** A second layer of dispatching for control structures. This could as well be part of `cfgFor` and has only been - * placed into a separate function to increase readability. - */ - protected def cfgForControlStructure(node: ControlStructure): Cfg = - node.controlStructureType match { - case ControlStructureTypes.BREAK => - cfgForBreakStatement(node) - case ControlStructureTypes.CONTINUE => - cfgForContinueStatement(node) - case ControlStructureTypes.WHILE => - cfgForWhileStatement(node) - case ControlStructureTypes.DO => - cfgForDoStatement(node) - case ControlStructureTypes.FOR => - cfgForForStatement(node) - case ControlStructureTypes.GOTO => - cfgForGotoStatement(node) - case ControlStructureTypes.IF => - cfgForIfStatement(node) - case ControlStructureTypes.ELSE => - cfgForChildren(node) - case ControlStructureTypes.SWITCH => - cfgForSwitchStatement(node) - case ControlStructureTypes.TRY => - cfgForTryStatement(node) - case ControlStructureTypes.MATCH => - cfgForMatchExpression(node) - case _ => - Cfg.empty - } - - /** The CFG for a break/continue statements contains only the break/continue statement as a single entry node. The - * fringe is empty, that is, appending another CFG to the break statement will not result in the creation of an edge - * from the break statement to the entry point of the other CFG. Labeled breaks are treated like gotos and are added - * to "jumpsToLabel". 
- */ - protected def cfgForBreakStatement(node: ControlStructure): Cfg = { - node.astChildren.find(_.order == 1) match { - case Some(jumpLabel: JumpLabel) => - val labelName = jumpLabel.name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case Some(literal: Literal) => - // In case we find a literal, it is assumed to be an integer literal which - // indicates how many loop/switch levels the break shall apply to. - val numberOfLevels = Integer.valueOf(literal.code) - Cfg(entryNode = Option(node), breaks = List((node, numberOfLevels))) - case Some(_) => - throw new NotImplementedError( - "Only jump labels and integer literals are currently supported for break statements." - ) - case None => - Cfg(entryNode = Option(node), breaks = List((node, 1))) - } - } - - protected def cfgForContinueStatement(node: ControlStructure): Cfg = { - node.astChildren.find(_.order == 1) match { - case Some(jumpLabel: JumpLabel) => - val labelName = jumpLabel.name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case Some(literal: Literal) => - // In case we find a literal, it is assumed to be an integer literal which - // indicates how many loop levels the continue shall apply to. - val numberOfLevels = Integer.valueOf(literal.code) - Cfg(entryNode = Option(node), continues = List((node, numberOfLevels))) - case Some(_) => - throw new NotImplementedError( - "Only jump labels and integer literals are currently supported for continue statements." - ) - case None => - Cfg(entryNode = Option(node), continues = List((node, 1))) - } - } - - /** Jump targets ("labels") are included in the CFG. As these should be connected to the next appended CFG, we specify - * that the label node is both the entry node and the only node in the fringe. This is achieved by calling - * `cfgForSingleNode` on the label node. Just like for breaks and continues, we record labels. 
We store case/default - * labels separately from other labels, but that is not a relevant implementation detail. - */ - protected def cfgForJumpTarget(n: JumpTarget): Cfg = { - val labelName = n.name - val cfg = cfgForSingleNode(n) - if (labelName.startsWith("case") || labelName.startsWith("default")) { - cfg.copy(caseLabels = List(n)) - } else { - cfg.copy(labeledNodes = Map(labelName -> n)) - } - } - - /** A CFG for a goto statement is one containing the goto node as an entry node and an empty fringe. Moreover, we - * store the goto for dispatching with `withResolvedJumpToLabel` once the CFG for the entire method has been - * calculated. - */ - protected def cfgForGotoStatement(node: ControlStructure): Cfg = { - node.astChildren.find(_.order == 1) match { - case Some(jumpLabel) => - val labelName = jumpLabel.asInstanceOf[JumpLabel].name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case None => - // Support for old format where the label name is parsed from the code field. - val target = node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)) - target.map(t => Cfg(entryNode = Some(node), jumpsToLabel = List((node, t)))).getOrElse(Cfg.empty) - } - } - - /** Return statements may contain expressions as return values, and therefore, the CFG for a return statement consists - * of the CFG for calculation of that expression, appended to a CFG containing only the return node, connected with a - * single edge to the method exit node. The fringe is empty. - * - * @param inheritFringe - * indicates if the resulting Cfg object must contain the fringe value of the return value's children. 
- */ - protected def cfgForReturn(actualRet: Return, inheritFringe: Boolean = false): Cfg = { - val childrenCfg = cfgForChildren(actualRet) - childrenCfg ++ - Cfg( - entryNode = Option(actualRet), - edges = singleEdge(actualRet, exitNode), - if (inheritFringe) childrenCfg.fringe else List() - ) - } - - /** The right hand side of a logical AND expression is only evaluated if the left hand side is true as the entire - * expression can only be true if both expressions are true. This is encoded in the corresponding control flow graph - * by creating control flow graphs for the left and right hand expressions and appending the two, where the fringe - * edge type of the left CFG is `TrueEdge`. - */ - protected def cfgForAndExpression(call: Call): Cfg = { - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo(leftCfg, rightCfg.entryNode, TrueEdge) ++ leftCfg.edges ++ rightCfg.edges - Cfg - .from(leftCfg, rightCfg) - .copy( - entryNode = leftCfg.entryNode, - edges = diffGraphs, - fringe = leftCfg.fringe ++ rightCfg.fringe - ) ++ cfgForSingleNode(call) - } - - /** Same construction recipe as for the AND expression, just that the fringe edge type of the left CFG is `FalseEdge`. - */ - protected def cfgForOrExpression(call: Call): Cfg = { - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo(leftCfg, rightCfg.entryNode, FalseEdge) ++ leftCfg.edges ++ rightCfg.edges - Cfg - .from(leftCfg, rightCfg) - .copy( - entryNode = leftCfg.entryNode, - edges = diffGraphs, - fringe = leftCfg.fringe ++ rightCfg.fringe - ) ++ cfgForSingleNode(call) - } - - /** A conditional expression is of the form `condition ? trueExpr ; falseExpr` where both `trueExpr` and `falseExpr` - * are optional. We create the corresponding CFGs by creating CFGs for the three expressions and adding edges between - * them. The new entry node is the condition entry node. 
- */ - protected def cfgForConditionalExpression(call: Call): Cfg = { - val conditionCfg = cfgFor(call.argument(1)) - val trueCfg = call.argumentOption(2).map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = call.argumentOption(3).map(cfgFor).getOrElse(Cfg.empty) - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode, TrueEdge) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode, FalseEdge) - - val trueFridge = if (trueCfg.entryNode.isDefined) { - trueCfg.fringe - } else { - conditionCfg.fringe.withEdgeType(TrueEdge) - } - val falseFridge = if (falseCfg.entryNode.isDefined) { - falseCfg.fringe - } else { - conditionCfg.fringe.withEdgeType(FalseEdge) - } - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges ++ diffGraphs, - fringe = trueFridge ++ falseFridge - ) ++ cfgForSingleNode(call) - } - - /** For macros, the AST contains a CALL node, along with child sub trees for all arguments, and a final sub tree that - * contains the inlined code. The corresponding CFG consists of the CFG for the call, an edge to the exit and an edge - * to the CFG of the inlined code. We choose this representation because it allows both queries that use the macro - * reference as well as queries that reference the inline code to be chosen as sources/sinks in data flow queries. 
- */ - def cfgForInlinedCall(call: Call): Cfg = { - val cfgForMacroCall = call.argument.l - .map(cfgFor) - .reduceOption((x, y) => x ++ y) - .getOrElse(Cfg.empty) ++ cfgForSingleNode(call) - val cfgForExpansion = call.astChildren.lastOption.map(cfgFor).getOrElse(Cfg.empty) - val cfg = Cfg - .from(cfgForMacroCall, cfgForExpansion) - .copy( - entryNode = cfgForMacroCall.entryNode, - edges = cfgForMacroCall.edges ++ cfgForExpansion.edges ++ cfgForExpansion.entryNode.toList - .flatMap(x => singleEdge(call, x)), - fringe = cfgForMacroCall.fringe ++ cfgForExpansion.fringe - ) - cfg - } - - /** A for statement is of the form `for(initExpr; condition; loopExpr) body` and all four components may be empty. The - * sequence (condition - body - loopExpr) form the inner part of the loop and we calculate the corresponding CFG - * `innerCfg` so that it is no longer relevant which of these three actually exist and we still have an entry node - * for the loop and a fringe. - */ - protected def cfgForForStatement(node: ControlStructure): Cfg = { - val children = node.astChildren.l - val nLocals = children.count(_.isLocal) - val initExprCfg = children.find(_.order == nLocals + 1).map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = children.find(_.order == nLocals + 2).map(cfgFor).getOrElse(Cfg.empty) - val loopExprCfg = children.find(_.order == nLocals + 3).map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = children.find(_.order == nLocals + 4).map(cfgFor).getOrElse(Cfg.empty) - - val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg - val entryNode = (initExprCfg ++ innerCfg).entryNode - - val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { - if (loopExprCfg.entryNode.isDefined) { - edges(takeCurrentLevel(bodyCfg.continues), loopExprCfg.entryNode) - } else { - edges(takeCurrentLevel(bodyCfg.continues), innerCfg.entryNode) +class 
CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder): + + import io.appthreat.x2cpg.passes.controlflow.cfgcreation.Cfg.* + import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator.* + + /** Control flow graph definitions often feature a designated entry and exit node for each + * method. While these nodes are no-ops from a computational point of view, they are useful to + * guarantee that a method has exactly one entry and one exit. + * + * For the CPG-based control flow graph, we do not need to introduce fake entry and exit node. + * Instead, we can use the METHOD and METHOD_RETURN nodes as entry and exit nodes respectively. + * Note that METHOD_RETURN nodes are the nodes representing formal return parameters, of which + * there exists exactly one per method. + */ + private val exitNode: MethodReturn = entryNode.methodReturn + + /** We return the CFG as a sequence of Diff Graphs that is calculated by first obtaining the CFG + * for the method and then resolving gotos. + */ + def run(): Unit = + cfgForMethod(entryNode).withResolvedJumpToLabel().edges.foreach { edge => + // TODO: we are ignoring edge.edgeType because the + // CFG spec doesn't define an edge type at the moment + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CFG) } - } - - Cfg - .from(initExprCfg, conditionCfg, loopExprCfg, bodyCfg) - .copy( - entryNode = entryNode, - edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(bodyCfg.breaks).map((_, AlwaysEdge)), - breaks = reduceAndFilterLevel(bodyCfg.breaks), - continues = reduceAndFilterLevel(bodyCfg.continues) - ) - } - - /** A Do-Statement is of the form `do body while(condition)` where body may be empty. We again first calculate the - * inner CFG as bodyCfg ++ conditionCfg and then connect edges according to the semantics of do-while. 
- */ - protected def cfgForDoStatement(node: ControlStructure): Cfg = { - val bodyCfg = node.astChildren.where(_.order(1)).headOption.map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val innerCfg = bodyCfg ++ conditionCfg - - val diffGraphs = - edges(takeCurrentLevel(bodyCfg.continues), conditionCfg.entryNode) ++ - edgesFromFringeTo(bodyCfg, conditionCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, innerCfg.entryNode, TrueEdge) - - Cfg - .from(bodyCfg, conditionCfg, innerCfg) - .copy( - entryNode = if (bodyCfg != Cfg.empty) { bodyCfg.entryNode } - else { conditionCfg.entryNode }, - edges = diffGraphs ++ bodyCfg.edges ++ conditionCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(bodyCfg.breaks).map((_, AlwaysEdge)), - breaks = reduceAndFilterLevel(bodyCfg.breaks), - continues = reduceAndFilterLevel(bodyCfg.continues) - ) - } - - /** CFG creation for while statements of the form `while(condition) body1 else body2` where body1 and the else block - * are optional. 
- */ - protected def cfgForWhileStatement(node: ControlStructure): Cfg = { - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) - - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(trueCfg, falseCfg.entryNode) ++ - edgesFromFringeTo(trueCfg, conditionCfg.entryNode) ++ - edges(takeCurrentLevel(trueCfg.continues), conditionCfg.entryNode) - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(trueCfg.breaks) - .map((_, AlwaysEdge)) ++ falseCfg.fringe, - breaks = reduceAndFilterLevel(trueCfg.breaks), - continues = reduceAndFilterLevel(trueCfg.continues) - ) - } - - /** CFG creation for switch statements of the form `switch { case condition: ... }`. - */ - protected def cfgForSwitchStatement(node: ControlStructure): Cfg = { - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - - cfgForSwitchLike(conditionCfg, bodyCfg :: Nil) - } - - /** CFG creation for if statements of the form `if(condition) body`, optionally followed by `else body2`. 
- */ - protected def cfgForIfStatement(node: ControlStructure): Cfg = { - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) - - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode) - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = trueCfg.fringe ++ { - if (falseCfg.entryNode.isDefined) { + + /** Conversion of a method to a CFG, showing the decomposition of the control flow graph + * generation problem into that of translating sub trees according to the node type. In the + * particular case of a method, the CFG is obtained by creating a CFG containing the single + * method node and a fringe containing the node and an outgoing AlwaysEdge, to the CFG obtained + * by translating child CFGs one by one and appending them. + */ + private def cfgForMethod(node: Method): Cfg = + cfgForSingleNode(node) ++ cfgForChildren(node) + + /** For any single AST node, we can construct a CFG containing that single node by setting it as + * the entry node and placing it in the fringe. + */ + private def cfgForSingleNode(node: CfgNode): Cfg = + Cfg(entryNode = Option(node), fringe = List((node, AlwaysEdge))) + + /** The CFG for all children is obtained by translating child ASTs one by one from left to right + * and appending them. + */ + private def cfgForChildren(node: AstNode): Cfg = + node.astChildren.l.map(cfgFor).reduceOption((x, y) => x ++ y).getOrElse(Cfg.empty) + + /** Returns true if this node is a child to some `try` control structure, false if otherwise. 
+ */ + private def withinATryBlock(x: AstNode): Boolean = + x.inAst.isControlStructure.exists(_.controlStructureType == ControlStructureTypes.TRY) + + /** This method dispatches AST nodes by type and calls corresponding conversion methods. + */ + protected def cfgFor(node: AstNode): Cfg = + node match + case _: Method | _: MethodParameterIn | _: Modifier | _: Local | _: TypeDecl | _: Member => + Cfg.empty + case _: MethodRef | _: TypeRef | _: MethodReturn => + cfgForSingleNode(node.asInstanceOf[CfgNode]) + case controlStructure: ControlStructure => + cfgForControlStructure(controlStructure) + case jumpTarget: JumpTarget => + cfgForJumpTarget(jumpTarget) + case ret: Return if withinATryBlock(ret) => + cfgForReturn(ret, inheritFringe = true) + case ret: Return => + cfgForReturn(ret) + case call: Call if call.name == Operators.logicalAnd => + cfgForAndExpression(call) + case call: Call if call.name == Operators.logicalOr => + cfgForOrExpression(call) + case call: Call if call.name == Operators.conditional => + cfgForConditionalExpression(call) + case call: Call if call.dispatchType == DispatchTypes.INLINED => + cfgForInlinedCall(call) + case block: Block if blockMatches(block) => + cfgForChildren(block) + case _: Block => + cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) + case _: Call | _: FieldIdentifier | _: Identifier | _: Literal | _: Block | _: Unknown => + cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) + case _ => + cfgForChildren(node) + + private def isLogicalOperator(node: AstNode): Boolean = node match + case call: Call => + call.name == Operators.conditional || call.name == Operators.logicalOr || call.name == Operators.logicalAnd + case _ => false + + private def isInlinedCall(node: AstNode): Boolean = node match + case call: Call => call.dispatchType == DispatchTypes.INLINED + case _ => false + + /** Only include block nodes that do not describe the entire method body or the bodies of + * control structures or 
inlined calls or logical operators. + */ + private def blockMatches(block: Block): Boolean = + if block._astIn.hasNext then + val parentNode = block.astParent + parentNode.isMethod || parentNode.isControlStructure || isLogicalOperator( + parentNode + ) || isInlinedCall(parentNode) + else false + + /** A second layer of dispatching for control structures. This could as well be part of `cfgFor` + * and has only been placed into a separate function to increase readability. + */ + protected def cfgForControlStructure(node: ControlStructure): Cfg = + node.controlStructureType match + case ControlStructureTypes.BREAK => + cfgForBreakStatement(node) + case ControlStructureTypes.CONTINUE => + cfgForContinueStatement(node) + case ControlStructureTypes.WHILE => + cfgForWhileStatement(node) + case ControlStructureTypes.DO => + cfgForDoStatement(node) + case ControlStructureTypes.FOR => + cfgForForStatement(node) + case ControlStructureTypes.GOTO => + cfgForGotoStatement(node) + case ControlStructureTypes.IF => + cfgForIfStatement(node) + case ControlStructureTypes.ELSE => + cfgForChildren(node) + case ControlStructureTypes.SWITCH => + cfgForSwitchStatement(node) + case ControlStructureTypes.TRY => + cfgForTryStatement(node) + case ControlStructureTypes.MATCH => + cfgForMatchExpression(node) + case _ => + Cfg.empty + + /** The CFG for a break/continue statements contains only the break/continue statement as a + * single entry node. The fringe is empty, that is, appending another CFG to the break + * statement will not result in the creation of an edge from the break statement to the entry + * point of the other CFG. Labeled breaks are treated like gotos and are added to + * "jumpsToLabel". 
+ */ + protected def cfgForBreakStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel: JumpLabel) => + val labelName = jumpLabel.name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case Some(literal: Literal) => + // In case we find a literal, it is assumed to be an integer literal which + // indicates how many loop/switch levels the break shall apply to. + val numberOfLevels = Integer.valueOf(literal.code) + Cfg(entryNode = Option(node), breaks = List((node, numberOfLevels))) + case Some(_) => + throw new NotImplementedError( + "Only jump labels and integer literals are currently supported for break statements." + ) + case None => + Cfg(entryNode = Option(node), breaks = List((node, 1))) + + protected def cfgForContinueStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel: JumpLabel) => + val labelName = jumpLabel.name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case Some(literal: Literal) => + // In case we find a literal, it is assumed to be an integer literal which + // indicates how many loop levels the continue shall apply to. + val numberOfLevels = Integer.valueOf(literal.code) + Cfg(entryNode = Option(node), continues = List((node, numberOfLevels))) + case Some(_) => + throw new NotImplementedError( + "Only jump labels and integer literals are currently supported for continue statements." + ) + case None => + Cfg(entryNode = Option(node), continues = List((node, 1))) + + /** Jump targets ("labels") are included in the CFG. As these should be connected to the next + * appended CFG, we specify that the label node is both the entry node and the only node in the + * fringe. This is achieved by calling `cfgForSingleNode` on the label node. Just like for + * breaks and continues, we record labels. 
We store case/default labels separately from other + * labels, but that is not a relevant implementation detail. + */ + protected def cfgForJumpTarget(n: JumpTarget): Cfg = + val labelName = n.name + val cfg = cfgForSingleNode(n) + if labelName.startsWith("case") || labelName.startsWith("default") then + cfg.copy(caseLabels = List(n)) + else + cfg.copy(labeledNodes = Map(labelName -> n)) + + /** A CFG for a goto statement is one containing the goto node as an entry node and an empty + * fringe. Moreover, we store the goto for dispatching with `withResolvedJumpToLabel` once the + * CFG for the entire method has been calculated. + */ + protected def cfgForGotoStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel) => + val labelName = jumpLabel.asInstanceOf[JumpLabel].name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case None => + // Support for old format where the label name is parsed from the code field. + val target = node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)) + target.map(t => + Cfg(entryNode = Some(node), jumpsToLabel = List((node, t))) + ).getOrElse(Cfg.empty) + + /** Return statements may contain expressions as return values, and therefore, the CFG for a + * return statement consists of the CFG for calculation of that expression, appended to a CFG + * containing only the return node, connected with a single edge to the method exit node. The + * fringe is empty. + * + * @param inheritFringe + * indicates if the resulting Cfg object must contain the fringe value of the return value's + * children. 
+ */ + protected def cfgForReturn(actualRet: Return, inheritFringe: Boolean = false): Cfg = + val childrenCfg = cfgForChildren(actualRet) + childrenCfg ++ + Cfg( + entryNode = Option(actualRet), + edges = singleEdge(actualRet, exitNode), + if inheritFringe then childrenCfg.fringe else List() + ) + + /** The right hand side of a logical AND expression is only evaluated if the left hand side is + * true as the entire expression can only be true if both expressions are true. This is encoded + * in the corresponding control flow graph by creating control flow graphs for the left and + * right hand expressions and appending the two, where the fringe edge type of the left CFG is + * `TrueEdge`. + */ + protected def cfgForAndExpression(call: Call): Cfg = + val leftCfg = cfgFor(call.argument(1)) + val rightCfg = cfgFor(call.argument(2)) + val diffGraphs = edgesFromFringeTo( + leftCfg, + rightCfg.entryNode, + TrueEdge + ) ++ leftCfg.edges ++ rightCfg.edges + Cfg + .from(leftCfg, rightCfg) + .copy( + entryNode = leftCfg.entryNode, + edges = diffGraphs, + fringe = leftCfg.fringe ++ rightCfg.fringe + ) ++ cfgForSingleNode(call) + end cfgForAndExpression + + /** Same construction recipe as for the AND expression, just that the fringe edge type of the + * left CFG is `FalseEdge`. + */ + protected def cfgForOrExpression(call: Call): Cfg = + val leftCfg = cfgFor(call.argument(1)) + val rightCfg = cfgFor(call.argument(2)) + val diffGraphs = edgesFromFringeTo( + leftCfg, + rightCfg.entryNode, + FalseEdge + ) ++ leftCfg.edges ++ rightCfg.edges + Cfg + .from(leftCfg, rightCfg) + .copy( + entryNode = leftCfg.entryNode, + edges = diffGraphs, + fringe = leftCfg.fringe ++ rightCfg.fringe + ) ++ cfgForSingleNode(call) + + /** A conditional expression is of the form `condition ? trueExpr ; falseExpr` where both + * `trueExpr` and `falseExpr` are optional. We create the corresponding CFGs by creating CFGs + * for the three expressions and adding edges between them. 
The new entry node is the condition + * entry node. + */ + protected def cfgForConditionalExpression(call: Call): Cfg = + val conditionCfg = cfgFor(call.argument(1)) + val trueCfg = call.argumentOption(2).map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = call.argumentOption(3).map(cfgFor).getOrElse(Cfg.empty) + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode, TrueEdge) ++ + edgesFromFringeTo(conditionCfg, falseCfg.entryNode, FalseEdge) + + val trueFridge = if trueCfg.entryNode.isDefined then + trueCfg.fringe + else + conditionCfg.fringe.withEdgeType(TrueEdge) + val falseFridge = if falseCfg.entryNode.isDefined then falseCfg.fringe - } else { + else conditionCfg.fringe.withEdgeType(FalseEdge) - } + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges ++ diffGraphs, + fringe = trueFridge ++ falseFridge + ) ++ cfgForSingleNode(call) + end cfgForConditionalExpression + + /** For macros, the AST contains a CALL node, along with child sub trees for all arguments, and + * a final sub tree that contains the inlined code. The corresponding CFG consists of the CFG + * for the call, an edge to the exit and an edge to the CFG of the inlined code. We choose this + * representation because it allows both queries that use the macro reference as well as + * queries that reference the inline code to be chosen as sources/sinks in data flow queries. 
+ */ + def cfgForInlinedCall(call: Call): Cfg = + val cfgForMacroCall = call.argument.l + .map(cfgFor) + .reduceOption((x, y) => x ++ y) + .getOrElse(Cfg.empty) ++ cfgForSingleNode(call) + val cfgForExpansion = call.astChildren.lastOption.map(cfgFor).getOrElse(Cfg.empty) + val cfg = Cfg + .from(cfgForMacroCall, cfgForExpansion) + .copy( + entryNode = cfgForMacroCall.entryNode, + edges = + cfgForMacroCall.edges ++ cfgForExpansion.edges ++ cfgForExpansion.entryNode.toList + .flatMap(x => singleEdge(call, x)), + fringe = cfgForMacroCall.fringe ++ cfgForExpansion.fringe + ) + cfg + end cfgForInlinedCall + + /** A for statement is of the form `for(initExpr; condition; loopExpr) body` and all four + * components may be empty. The sequence (condition - body - loopExpr) form the inner part of + * the loop and we calculate the corresponding CFG `innerCfg` so that it is no longer relevant + * which of these three actually exist and we still have an entry node for the loop and a + * fringe. + */ + protected def cfgForForStatement(node: ControlStructure): Cfg = + val children = node.astChildren.l + val nLocals = children.count(_.isLocal) + val initExprCfg = children.find(_.order == nLocals + 1).map(cfgFor).getOrElse(Cfg.empty) + val conditionCfg = children.find(_.order == nLocals + 2).map(cfgFor).getOrElse(Cfg.empty) + val loopExprCfg = children.find(_.order == nLocals + 3).map(cfgFor).getOrElse(Cfg.empty) + val bodyCfg = children.find(_.order == nLocals + 4).map(cfgFor).getOrElse(Cfg.empty) + + val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg + val entryNode = (initExprCfg ++ innerCfg).entryNode + + val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ + edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { + if loopExprCfg.entryNode.isDefined then + edges(takeCurrentLevel(bodyCfg.continues), loopExprCfg.entryNode) + else + edges(takeCurrentLevel(bodyCfg.continues), innerCfg.entryNode) + } + 
+ Cfg + .from(initExprCfg, conditionCfg, loopExprCfg, bodyCfg) + .copy( + entryNode = entryNode, + edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, + fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( + bodyCfg.breaks + ).map((_, AlwaysEdge)), + breaks = reduceAndFilterLevel(bodyCfg.breaks), + continues = reduceAndFilterLevel(bodyCfg.continues) + ) + end cfgForForStatement + + /** A Do-Statement is of the form `do body while(condition)` where body may be empty. We again + * first calculate the inner CFG as bodyCfg ++ conditionCfg and then connect edges according to + * the semantics of do-while. + */ + protected def cfgForDoStatement(node: ControlStructure): Cfg = + val bodyCfg = node.astChildren.where(_.order(1)).headOption.map(cfgFor).getOrElse(Cfg.empty) + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val innerCfg = bodyCfg ++ conditionCfg + + val diffGraphs = + edges(takeCurrentLevel(bodyCfg.continues), conditionCfg.entryNode) ++ + edgesFromFringeTo(bodyCfg, conditionCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, innerCfg.entryNode, TrueEdge) + + Cfg + .from(bodyCfg, conditionCfg, innerCfg) + .copy( + entryNode = if bodyCfg != Cfg.empty then bodyCfg.entryNode + else conditionCfg.entryNode, + edges = diffGraphs ++ bodyCfg.edges ++ conditionCfg.edges, + fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( + bodyCfg.breaks + ).map((_, AlwaysEdge)), + breaks = reduceAndFilterLevel(bodyCfg.breaks), + continues = reduceAndFilterLevel(bodyCfg.continues) + ) + end cfgForDoStatement + + /** CFG creation for while statements of the form `while(condition) body1 else body2` where + * body1 and the else block are optional. 
+ */ + protected def cfgForWhileStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) + + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ + edgesFromFringeTo(trueCfg, falseCfg.entryNode) ++ + edgesFromFringeTo(trueCfg, conditionCfg.entryNode) ++ + edges(takeCurrentLevel(trueCfg.continues), conditionCfg.entryNode) + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, + fringe = + conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(trueCfg.breaks) + .map((_, AlwaysEdge)) ++ falseCfg.fringe, + breaks = reduceAndFilterLevel(trueCfg.breaks), + continues = reduceAndFilterLevel(trueCfg.continues) + ) + end cfgForWhileStatement + + /** CFG creation for switch statements of the form `switch { case condition: ... }`. + */ + protected def cfgForSwitchStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val bodyCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + + cfgForSwitchLike(conditionCfg, bodyCfg :: Nil) + + /** CFG creation for if statements of the form `if(condition) body`, optionally followed by + * `else body2`. 
+ */ + protected def cfgForIfStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) + + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, falseCfg.entryNode) + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, + fringe = trueCfg.fringe ++ { + if falseCfg.entryNode.isDefined then + falseCfg.fringe + else + conditionCfg.fringe.withEdgeType(FalseEdge) + } + ) + end cfgForIfStatement + + /** CFG creation for try statements of the form `try { tryBody ] catch { catchBody } `, + * optionally followed by `finally { finallyBody }`. + * + * To avoid very large CFGs for try statements, only edges from the last statement in the `try` + * block to each `catch` block (and optionally the `finally` block) are created. The last + * statement in each `catch` block should then have an outgoing edge to the `finally` block if + * it exists (and not to any subsequent catch blocks), or otherwise * be part of the fringe. + * + * By default, the first child of the `TRY` node is treated as the try body, while every + * subsequent node is treated as a `catch`, with no `finally` present. To treat the last child + * of the node as the `finally` block, the `code` field of the `Block` node must be set to + * `finally`. 
+ */ + protected def cfgForTryStatement(node: ControlStructure): Cfg = + val maybeTryBlock = + node.astChildren + .where(_.order(1)) + .where(_.astChildren) // Filter out empty `try` bodies + .headOption + + val tryBodyCfg: Cfg = maybeTryBlock.map(cfgFor).getOrElse(Cfg.empty) + + val catchBodyCfgs: List[Cfg] = + node.astChildren + .where(_.order(2)) + .toList match + case Nil => List(Cfg.empty) + case asts => asts.map(cfgFor) + + val maybeFinallyBodyCfg: List[Cfg] = + node.astChildren + .where(_.order(3)) + .map(cfgFor) + .headOption // Assume there can only be one + .toList + + val tryToCatchEdges = catchBodyCfgs.flatMap { catchBodyCfg => + edgesFromFringeTo(tryBodyCfg, catchBodyCfg.entryNode) } - ) - } - - /** CFG creation for try statements of the form `try { tryBody ] catch { catchBody } `, optionally followed by - * `finally { finallyBody }`. - * - * To avoid very large CFGs for try statements, only edges from the last statement in the `try` block to each `catch` - * block (and optionally the `finally` block) are created. The last statement in each `catch` block should then have - * an outgoing edge to the `finally` block if it exists (and not to any subsequent catch blocks), or otherwise * be - * part of the fringe. - * - * By default, the first child of the `TRY` node is treated as the try body, while every subsequent node is treated - * as a `catch`, with no `finally` present. To treat the last child of the node as the `finally` block, the `code` - * field of the `Block` node must be set to `finally`. 
- */ - protected def cfgForTryStatement(node: ControlStructure): Cfg = { - val maybeTryBlock = - node.astChildren - .where(_.order(1)) - .where(_.astChildren) // Filter out empty `try` bodies - .headOption - - val tryBodyCfg: Cfg = maybeTryBlock.map(cfgFor).getOrElse(Cfg.empty) - - val catchBodyCfgs: List[Cfg] = - node.astChildren - .where(_.order(2)) - .toList match { - case Nil => List(Cfg.empty) - case asts => asts.map(cfgFor) - } - - val maybeFinallyBodyCfg: List[Cfg] = - node.astChildren - .where(_.order(3)) - .map(cfgFor) - .headOption // Assume there can only be one - .toList - - val tryToCatchEdges = catchBodyCfgs.flatMap { catchBodyCfg => - edgesFromFringeTo(tryBodyCfg, catchBodyCfg.entryNode) - } - - val catchToFinallyEdges = ( - for ( - catchBodyCfg <- catchBodyCfgs; - finallyBodyCfg <- maybeFinallyBodyCfg - ) yield edgesFromFringeTo(catchBodyCfg, finallyBodyCfg.entryNode) - ).flatten - - val tryToFinallyEdges = maybeFinallyBodyCfg.flatMap { cfg => - edgesFromFringeTo(tryBodyCfg, cfg.entryNode) - } - - val diffGraphs = tryToCatchEdges ++ catchToFinallyEdges ++ tryToFinallyEdges - - if (maybeTryBlock.isEmpty) { - // This case deals with the situation where the try block is empty. In this case, - // no catch block can be executed since nothing can be thrown, but the finally block - // will still be executed. 
- maybeFinallyBodyCfg.headOption.getOrElse(Cfg.empty) - } else { - Cfg - .from(Seq(tryBodyCfg) ++ catchBodyCfgs ++ maybeFinallyBodyCfg: _*) - .copy( - entryNode = tryBodyCfg.entryNode, - edges = - diffGraphs ++ tryBodyCfg.edges ++ catchBodyCfgs.flatMap(_.edges) ++ maybeFinallyBodyCfg.flatMap(_.edges), - fringe = if (maybeFinallyBodyCfg.flatMap(_.entryNode).nonEmpty) { - maybeFinallyBodyCfg.head.fringe - } else { - tryBodyCfg.fringe ++ catchBodyCfgs.flatMap(_.fringe) - } - ) - } - } - - /** The CFGs for match cases are modeled after PHP match expressions and assumes that a case will always consist of - * one or more JumpTargets followed by a single expression. The CFG also assumes an implicit `break` at the end of - * each match case. - */ - protected def cfgsForMatchCases(body: AstNode): List[Cfg] = { - body.astChildren - .foldLeft(List(Cfg.empty)) { - case (currCfg :: prevCfgs, astNode) => - astNode match { - case jumpTarget: JumpTarget => - val jumpCfg = cfgForJumpTarget(jumpTarget) - (currCfg ++ jumpCfg) :: prevCfgs - - case node: AstNode => - val nodeCfg = cfgFor(node) - Cfg.empty :: (currCfg ++ nodeCfg) :: prevCfgs - } - case _ => List.empty - } - .reverse - } - - /** CFG creation for match expressions of the form `match { case condition: expr ... 
}` - */ - protected def cfgForMatchExpression(node: ControlStructure): Cfg = { - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val bodyCfgs = node.whenTrue.headOption.map(cfgsForMatchCases).getOrElse(Nil) - - cfgForSwitchLike(conditionCfg, bodyCfgs) - } - - protected def cfgForSwitchLike(conditionCfg: Cfg, bodyCfgs: List[Cfg]): Cfg = { - val hasDefaultCase = bodyCfgs.flatMap(_.caseLabels).exists(x => x.asInstanceOf[JumpTarget].name == "default") - val caseEdges = edgesToMultiple(conditionCfg.fringe.map(_._1), bodyCfgs.flatMap(_.caseLabels), CaseEdge) - val breakFringe = takeCurrentLevel(bodyCfgs.flatMap(_.breaks)).map((_, AlwaysEdge)) - - Cfg - .from(conditionCfg :: bodyCfgs: _*) - .copy( - entryNode = conditionCfg.entryNode, - edges = caseEdges ++ conditionCfg.edges ++ bodyCfgs.flatMap(_.edges), - fringe = { - if (!hasDefaultCase) { conditionCfg.fringe.withEdgeType(FalseEdge) } - else { Nil } - } ++ breakFringe ++ bodyCfgs.flatMap(_.fringe), - caseLabels = List(), - breaks = reduceAndFilterLevel(bodyCfgs.flatMap(_.breaks)), - continues = bodyCfgs.flatMap(_.continues) - ) - } -} -object CfgCreator { + val catchToFinallyEdges = ( + for ( + catchBodyCfg <- catchBodyCfgs; + finallyBodyCfg <- maybeFinallyBodyCfg + ) yield edgesFromFringeTo(catchBodyCfg, finallyBodyCfg.entryNode) + ).flatten - implicit class FringeWrapper(fringe: List[(CfgNode, CfgEdgeType)]) { - def withEdgeType(edgeType: CfgEdgeType): List[(CfgNode, CfgEdgeType)] = { - fringe.map { case (x, _) => (x, edgeType) } - } - } + val tryToFinallyEdges = maybeFinallyBodyCfg.flatMap { cfg => + edgesFromFringeTo(tryBodyCfg, cfg.entryNode) + } -} + val diffGraphs = tryToCatchEdges ++ catchToFinallyEdges ++ tryToFinallyEdges + + if maybeTryBlock.isEmpty then + // This case deals with the situation where the try block is empty. In this case, + // no catch block can be executed since nothing can be thrown, but the finally block + // will still be executed. 
+ maybeFinallyBodyCfg.headOption.getOrElse(Cfg.empty) + else + Cfg + .from(Seq(tryBodyCfg) ++ catchBodyCfgs ++ maybeFinallyBodyCfg: _*) + .copy( + entryNode = tryBodyCfg.entryNode, + edges = + diffGraphs ++ tryBodyCfg.edges ++ catchBodyCfgs.flatMap( + _.edges + ) ++ maybeFinallyBodyCfg.flatMap(_.edges), + fringe = if maybeFinallyBodyCfg.flatMap(_.entryNode).nonEmpty then + maybeFinallyBodyCfg.head.fringe + else + tryBodyCfg.fringe ++ catchBodyCfgs.flatMap(_.fringe) + ) + end if + end cfgForTryStatement + + /** The CFGs for match cases are modeled after PHP match expressions and assumes that a case + * will always consist of one or more JumpTargets followed by a single expression. The CFG also + * assumes an implicit `break` at the end of each match case. + */ + protected def cfgsForMatchCases(body: AstNode): List[Cfg] = + body.astChildren + .foldLeft(List(Cfg.empty)) { + case (currCfg :: prevCfgs, astNode) => + astNode match + case jumpTarget: JumpTarget => + val jumpCfg = cfgForJumpTarget(jumpTarget) + (currCfg ++ jumpCfg) :: prevCfgs + + case node: AstNode => + val nodeCfg = cfgFor(node) + Cfg.empty :: (currCfg ++ nodeCfg) :: prevCfgs + case _ => List.empty + } + .reverse + + /** CFG creation for match expressions of the form `match { case condition: expr ... 
}` + */ + protected def cfgForMatchExpression(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val bodyCfgs = node.whenTrue.headOption.map(cfgsForMatchCases).getOrElse(Nil) + + cfgForSwitchLike(conditionCfg, bodyCfgs) + + protected def cfgForSwitchLike(conditionCfg: Cfg, bodyCfgs: List[Cfg]): Cfg = + val hasDefaultCase = + bodyCfgs.flatMap(_.caseLabels).exists(x => x.asInstanceOf[JumpTarget].name == "default") + val caseEdges = + edgesToMultiple(conditionCfg.fringe.map(_._1), bodyCfgs.flatMap(_.caseLabels), CaseEdge) + val breakFringe = takeCurrentLevel(bodyCfgs.flatMap(_.breaks)).map((_, AlwaysEdge)) + + Cfg + .from(conditionCfg :: bodyCfgs: _*) + .copy( + entryNode = conditionCfg.entryNode, + edges = caseEdges ++ conditionCfg.edges ++ bodyCfgs.flatMap(_.edges), + fringe = { + if !hasDefaultCase then conditionCfg.fringe.withEdgeType(FalseEdge) + else Nil + } ++ breakFringe ++ bodyCfgs.flatMap(_.fringe), + caseLabels = List(), + breaks = reduceAndFilterLevel(bodyCfgs.flatMap(_.breaks)), + continues = bodyCfgs.flatMap(_.continues) + ) + end cfgForSwitchLike +end CfgCreator + +object CfgCreator: + + implicit class FringeWrapper(fringe: List[(CfgNode, CfgEdgeType)]): + def withEdgeType(edgeType: CfgEdgeType): List[(CfgNode, CfgEdgeType)] = + fringe.map { case (x, _) => (x, edgeType) } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala index 70823e8d..59bcb21e 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala @@ -1,6 +1,5 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator -trait CfgAdapter[Node] { - def successors(node: Node): 
IterableOnce[Node] - def predecessors(node: Node): IterableOnce[Node] -} +trait CfgAdapter[Node]: + def successors(node: Node): IterableOnce[Node] + def predecessors(node: Node): IterableOnce[Node] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala index 19bf856d..5b446ef0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala @@ -3,87 +3,85 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator import io.shiftleft.semanticcpg.language.NodeOrdering import scala.collection.mutable -class CfgDominator[NodeType](adapter: CfgAdapter[NodeType]) { +class CfgDominator[NodeType](adapter: CfgAdapter[NodeType]): - /** Calculates the immediate dominators of all CFG nodes reachable from cfgEntry. Since the cfgEntry does not have an - * immediate dominator, it has no entry in the return map. - * - * The algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy J. Harvey, and Ken - * Kennedy". - */ - def calculate(cfgEntry: NodeType): mutable.LinkedHashMap[NodeType, NodeType] = { - val UNDEFINED = -1 - def expand(x: NodeType) = { adapter.successors(x).iterator } - def expandBack(x: NodeType) = { adapter.predecessors(x).iterator } + /** Calculates the immediate dominators of all CFG nodes reachable from cfgEntry. Since the + * cfgEntry does not have an immediate dominator, it has no entry in the return map. + * + * The algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy + * J. Harvey, and Ken Kennedy". 
+ */ + def calculate(cfgEntry: NodeType): mutable.LinkedHashMap[NodeType, NodeType] = + val UNDEFINED = -1 + def expand(x: NodeType) = adapter.successors(x).iterator + def expandBack(x: NodeType) = adapter.predecessors(x).iterator - val postOrderNumbering = NodeOrdering.postOrderNumbering(cfgEntry, expand) - val nodesInReversePostOrder = NodeOrdering.reverseNodeList(postOrderNumbering.toList).filterNot(_ == cfgEntry) - // Index of each node into dominators array. - val indexOf = postOrderNumbering.withDefaultValue(UNDEFINED) - // We use withDefault because unreachable/dead - // code nodes are not numbered but may be touched - // as predecessors of reachable nodes. + val postOrderNumbering = NodeOrdering.postOrderNumbering(cfgEntry, expand) + val nodesInReversePostOrder = + NodeOrdering.reverseNodeList(postOrderNumbering.toList).filterNot(_ == cfgEntry) + // Index of each node into dominators array. + val indexOf = postOrderNumbering.withDefaultValue(UNDEFINED) + // We use withDefault because unreachable/dead + // code nodes are not numbered but may be touched + // as predecessors of reachable nodes. - val dominators = Array.fill(indexOf.size)(UNDEFINED) - dominators(indexOf(cfgEntry)) = indexOf(cfgEntry) + val dominators = Array.fill(indexOf.size)(UNDEFINED) + dominators(indexOf(cfgEntry)) = indexOf(cfgEntry) - /* Retrieve index of immediate dominator for node with given index. If the index is `UNDEFINED`, UNDEFINED is - * returned. */ - def safeDominators(index: Int): Int = { - if (index != UNDEFINED) { - dominators(index) - } else { - UNDEFINED - } - } + /* Retrieve index of immediate dominator for node with given index. If the index is `UNDEFINED`, UNDEFINED is + * returned. 
*/ + def safeDominators(index: Int): Int = + if index != UNDEFINED then + dominators(index) + else + UNDEFINED - var changed = true - while (changed) { - changed = false - nodesInReversePostOrder.foreach { node => - val firstNotUndefinedPred = expandBack(node).find { predecessor => - safeDominators(indexOf(predecessor)) != UNDEFINED - }.get + var changed = true + while changed do + changed = false + nodesInReversePostOrder.foreach { node => + val firstNotUndefinedPred = expandBack(node).find { predecessor => + safeDominators(indexOf(predecessor)) != UNDEFINED + }.get - var newImmediateDominator = indexOf(firstNotUndefinedPred) - expandBack(node).foreach { predecessor => - val predecessorIndex = indexOf(predecessor) - if (safeDominators(predecessorIndex) != UNDEFINED) { - newImmediateDominator = intersect(dominators, predecessorIndex, newImmediateDominator) - } - } - - val nodeIndex = indexOf(node) - if (dominators(nodeIndex) != newImmediateDominator) { - dominators(nodeIndex) = newImmediateDominator - changed = true - } - } + var newImmediateDominator = indexOf(firstNotUndefinedPred) + expandBack(node).foreach { predecessor => + val predecessorIndex = indexOf(predecessor) + if safeDominators(predecessorIndex) != UNDEFINED then + newImmediateDominator = + intersect(dominators, predecessorIndex, newImmediateDominator) + } - } + val nodeIndex = indexOf(node) + if dominators(nodeIndex) != newImmediateDominator then + dominators(nodeIndex) = newImmediateDominator + changed = true + } + end while - val postOrderNumberingToNode = postOrderNumbering.map { case (node, index) => (index, node) } - - postOrderNumbering.collect { - case (node, index) if node != cfgEntry => - val immediateDominatorIndex = dominators(index) - (node, postOrderNumberingToNode(immediateDominatorIndex)) - } - } + val postOrderNumberingToNode = postOrderNumbering.map { case (node, index) => + (index, node) + } - private def intersect(dominators: Array[Int], immediateDomIndex1: Int, immediateDomIndex2: 
Int): Int = { - var finger1 = immediateDomIndex1 - var finger2 = immediateDomIndex2 + postOrderNumbering.collect { + case (node, index) if node != cfgEntry => + val immediateDominatorIndex = dominators(index) + (node, postOrderNumberingToNode(immediateDominatorIndex)) + } + end calculate - while (finger1 != finger2) { - while (finger1 < finger2) { - finger1 = dominators(finger1) - } - while (finger2 < finger1) { - finger2 = dominators(finger2) - } - } - finger1 - } + private def intersect( + dominators: Array[Int], + immediateDomIndex1: Int, + immediateDomIndex2: Int + ): Int = + var finger1 = immediateDomIndex1 + var finger2 = immediateDomIndex2 -} + while finger1 != finger2 do + while finger1 < finger2 do + finger1 = dominators(finger1) + while finger2 < finger1 do + finger2 = dominators(finger2) + finger1 +end CfgDominator diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala index 1e4518a9..434d0b93 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala @@ -2,37 +2,40 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator import scala.collection.mutable -/** Calculates the dominator frontier for a set of CFG nodes. The returned multimap associates the frontier nodes to - * each CFG node. +/** Calculates the dominator frontier for a set of CFG nodes. The returned multimap associates the + * frontier nodes to each CFG node. * - * The used algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy J. Harvey, and Ken - * Kennedy". + * The used algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy + * J. 
Harvey, and Ken Kennedy". */ -class CfgDominatorFrontier[NodeType](cfgAdapter: CfgAdapter[NodeType], domTreeAdapter: DomTreeAdapter[NodeType]) { +class CfgDominatorFrontier[NodeType]( + cfgAdapter: CfgAdapter[NodeType], + domTreeAdapter: DomTreeAdapter[NodeType] +): - private def doms(x: NodeType): Option[NodeType] = domTreeAdapter.immediateDominator(x) - private def pred(x: NodeType): Seq[NodeType] = cfgAdapter.predecessors(x).iterator.to(Seq) - private def onlyJoinNodes(x: NodeType): Option[(NodeType, Seq[NodeType])] = - Option(pred(x)).filter(_.size > 1).map(p => (x, p)) - private def withIDom(x: NodeType, preds: Seq[NodeType]) = - doms(x).map(i => (x, preds, i)) + private def doms(x: NodeType): Option[NodeType] = domTreeAdapter.immediateDominator(x) + private def pred(x: NodeType): Seq[NodeType] = cfgAdapter.predecessors(x).iterator.to(Seq) + private def onlyJoinNodes(x: NodeType): Option[(NodeType, Seq[NodeType])] = + Option(pred(x)).filter(_.size > 1).map(p => (x, p)) + private def withIDom(x: NodeType, preds: Seq[NodeType]) = + doms(x).map(i => (x, preds, i)) - def calculate(cfgNodes: Seq[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = { - val domFrontier = mutable.Map.empty[NodeType, mutable.Set[NodeType]] + def calculate(cfgNodes: Seq[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = + val domFrontier = mutable.Map.empty[NodeType, mutable.Set[NodeType]] - for { - cfgNode <- cfgNodes - (nodeType, joinNodes) <- onlyJoinNodes(cfgNode) - (joinNode, preds, joinNodeIDom) <- withIDom(nodeType, joinNodes) - } preds.foreach { p => - var currentPred = Option(p) - while (currentPred.isDefined && currentPred.get != joinNodeIDom) { - val frontierNodes = domFrontier.getOrElseUpdate(currentPred.get, mutable.Set.empty) - frontierNodes.add(joinNode) - currentPred = doms(currentPred.get) - } - } + for + cfgNode <- cfgNodes + (nodeType, joinNodes) <- onlyJoinNodes(cfgNode) + (joinNode, preds, joinNodeIDom) <- withIDom(nodeType, joinNodes) + do + 
preds.foreach { p => + var currentPred = Option(p) + while currentPred.isDefined && currentPred.get != joinNodeIDom do + val frontierNodes = + domFrontier.getOrElseUpdate(currentPred.get, mutable.Set.empty) + frontierNodes.add(joinNode) + currentPred = doms(currentPred.get) + } - domFrontier - } -} + domFrontier +end CfgDominatorFrontier diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala index 40e64cd7..e424b989 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala @@ -4,45 +4,41 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{Method, StoredNode} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable /** This pass has no prerequisites. 
*/ -class CfgDominatorPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg) { - override def generateParts(): Array[Method] = cpg.method.toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = { - val cfgAdapter = new CpgCfgAdapter() - val dominatorCalculator = new CfgDominator(cfgAdapter) - - val reverseCfgAdapter = new ReverseCpgCfgAdapter() - val postDominatorCalculator = new CfgDominator(reverseCfgAdapter) - - val cfgNodeToImmediateDominator = dominatorCalculator.calculate(method) - addDomTreeEdges(dstGraph, cfgNodeToImmediateDominator) - - val cfgNodeToPostImmediateDominator = postDominatorCalculator.calculate(method.methodReturn) - addPostDomTreeEdges(dstGraph, cfgNodeToPostImmediateDominator) - - } - - private def addDomTreeEdges( - dstGraph: DiffGraphBuilder, - cfgNodeToImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] - ): Unit = { - cfgNodeToImmediateDominator.foreach { case (node, immediateDominator) => - dstGraph.addEdge(immediateDominator, node, EdgeTypes.DOMINATE) - } - } - - private def addPostDomTreeEdges( - dstGraph: DiffGraphBuilder, - cfgNodeToPostImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] - ): Unit = { - cfgNodeToPostImmediateDominator.foreach { case (node, immediatePostDominator) => - dstGraph.addEdge(immediatePostDominator, node, EdgeTypes.POST_DOMINATE) - } - } -} +class CfgDominatorPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg): + override def generateParts(): Array[Method] = cpg.method.toArray + + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + val cfgAdapter = new CpgCfgAdapter() + val dominatorCalculator = new CfgDominator(cfgAdapter) + + val reverseCfgAdapter = new ReverseCpgCfgAdapter() + val postDominatorCalculator = new CfgDominator(reverseCfgAdapter) + + val cfgNodeToImmediateDominator = dominatorCalculator.calculate(method) + addDomTreeEdges(dstGraph, cfgNodeToImmediateDominator) + + val 
cfgNodeToPostImmediateDominator = postDominatorCalculator.calculate(method.methodReturn) + addPostDomTreeEdges(dstGraph, cfgNodeToPostImmediateDominator) + + private def addDomTreeEdges( + dstGraph: DiffGraphBuilder, + cfgNodeToImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] + ): Unit = + cfgNodeToImmediateDominator.foreach { case (node, immediateDominator) => + dstGraph.addEdge(immediateDominator, node, EdgeTypes.DOMINATE) + } + + private def addPostDomTreeEdges( + dstGraph: DiffGraphBuilder, + cfgNodeToPostImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] + ): Unit = + cfgNodeToPostImmediateDominator.foreach { case (node, immediatePostDominator) => + dstGraph.addEdge(immediatePostDominator, node, EdgeTypes.POST_DOMINATE) + } +end CfgDominatorPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala index 9647f806..f880b578 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala @@ -2,12 +2,10 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator import io.shiftleft.codepropertygraph.generated.nodes.StoredNode -class CpgCfgAdapter extends CfgAdapter[StoredNode] { +class CpgCfgAdapter extends CfgAdapter[StoredNode]: - override def successors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgOut + override def successors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgOut - override def predecessors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgIn - -} + override def predecessors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgIn diff --git 
a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala index 94a28d52..f50385b0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala @@ -1,10 +1,10 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator -trait DomTreeAdapter[Node] { +trait DomTreeAdapter[Node]: - /** Returns the immediate dominator of a cfgNode. The returned value can be None if cfgNode was the cfg entry node - * while calculating the dominator relation or if cfgNode is dead code. In the post dominator case "dead code" means - * code which does lead to the normal method exit. An example would be a thrown excpetion. - */ - def immediateDominator(cfgNode: Node): Option[Node] -} + /** Returns the immediate dominator of a cfgNode. The returned value can be None if cfgNode was + * the cfg entry node while calculating the dominator relation or if cfgNode is dead code. In + * the post dominator case "dead code" means code which does lead to the normal method exit. An + * example would be a thrown excpetion. 
+ */ + def immediateDominator(cfgNode: Node): Option[Node] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala index d8886556..7505aed4 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala @@ -2,12 +2,10 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator import io.shiftleft.codepropertygraph.generated.nodes.StoredNode -class ReverseCpgCfgAdapter extends CfgAdapter[StoredNode] { +class ReverseCpgCfgAdapter extends CfgAdapter[StoredNode]: - override def successors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgIn + override def successors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgIn - override def predecessors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgOut - -} + override def predecessors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgOut diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala index abfc3fd5..59a28f0e 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala @@ -3,60 +3,64 @@ package io.appthreat.x2cpg.passes.controlflow.codepencegraph import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{ - Call, - ControlStructure, - Identifier, - JumpTarget, - Literal, - 
Method, - MethodRef, - Unknown + Call, + ControlStructure, + Identifier, + JumpTarget, + Literal, + Method, + MethodRef, + Unknown } import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ -import io.appthreat.x2cpg.passes.controlflow.cfgdominator.{CfgDominatorFrontier, ReverseCpgCfgAdapter} +import io.shiftleft.semanticcpg.language.* +import io.appthreat.x2cpg.passes.controlflow.cfgdominator.{ + CfgDominatorFrontier, + ReverseCpgCfgAdapter +} import org.slf4j.{Logger, LoggerFactory} /** This pass has ContainsEdgePass and CfgDominatorPass as prerequisites. */ -class CdgPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg) { - import CdgPass.logger - - override def generateParts(): Array[Method] = cpg.method.toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = { - - val dominanceFrontier = new CfgDominatorFrontier(new ReverseCpgCfgAdapter, new CpgPostDomTreeAdapter) - - val cfgNodes = method._containsOut.toList - val postDomFrontiers = dominanceFrontier.calculate(method :: cfgNodes) - - postDomFrontiers.foreach { case (node, postDomFrontierNodes) => - postDomFrontierNodes.foreach { - case postDomFrontierNode @ (_: Literal | _: Identifier | _: Call | _: MethodRef | _: Unknown | - _: ControlStructure | _: JumpTarget) => - dstGraph.addEdge(postDomFrontierNode, node, EdgeTypes.CDG) - case postDomFrontierNode => - val nodeLabel = postDomFrontierNode.label - val containsIn = postDomFrontierNode._containsIn - if (containsIn == null || !containsIn.hasNext) { - logger.debug(s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG.") - } else { - val method = containsIn.next() - logger.debug( - s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." 
+ - s" Method: ${method match { - case m: Method => m.fullName; - case other => other.label - }}" + - s" number of outgoing CFG edges from $nodeLabel node: ${postDomFrontierNode._cfgOut.size}" - ) - } - } - } - } -} +class CdgPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg): + import CdgPass.logger -object CdgPass { - private val logger: Logger = LoggerFactory.getLogger(classOf[CdgPass]) -} + override def generateParts(): Array[Method] = cpg.method.toArray + + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + + val dominanceFrontier = + new CfgDominatorFrontier(new ReverseCpgCfgAdapter, new CpgPostDomTreeAdapter) + + val cfgNodes = method._containsOut.toList + val postDomFrontiers = dominanceFrontier.calculate(method :: cfgNodes) + + postDomFrontiers.foreach { case (node, postDomFrontierNodes) => + postDomFrontierNodes.foreach { + case postDomFrontierNode @ (_: Literal | _: Identifier | _: Call | _: MethodRef | _: Unknown | + _: ControlStructure | _: JumpTarget) => + dstGraph.addEdge(postDomFrontierNode, node, EdgeTypes.CDG) + case postDomFrontierNode => + val nodeLabel = postDomFrontierNode.label + val containsIn = postDomFrontierNode._containsIn + if containsIn == null || !containsIn.hasNext then + logger.debug( + s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." + ) + else + val method = containsIn.next() + logger.debug( + s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." 
+ + s" Method: ${method match + case m: Method => m.fullName; + case other => other.label + }" + + s" number of outgoing CFG edges from $nodeLabel node: ${postDomFrontierNode._cfgOut.size}" + ) + } + } + end runOnPart +end CdgPass + +object CdgPass: + private val logger: Logger = LoggerFactory.getLogger(classOf[CdgPass]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala index 7bffcd7f..873a40ef 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala @@ -3,9 +3,7 @@ package io.appthreat.x2cpg.passes.controlflow.codepencegraph import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import io.appthreat.x2cpg.passes.controlflow.cfgdominator.DomTreeAdapter -class CpgPostDomTreeAdapter extends DomTreeAdapter[StoredNode] { +class CpgPostDomTreeAdapter extends DomTreeAdapter[StoredNode]: - override def immediateDominator(cfgNode: StoredNode): Option[StoredNode] = - cfgNode._postDominateIn.nextOption() - -} + override def immediateDominator(cfgNode: StoredNode): Option[StoredNode] = + cfgNode._postDominateIn.nextOption() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala index b4bc9d95..0efed058 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala @@ -2,34 +2,26 @@ package io.appthreat.x2cpg.passes.frontend import io.shiftleft.codepropertygraph.generated.Cpg import 
io.shiftleft.codepropertygraph.generated.Languages -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -object Dereference { +object Dereference: - def apply(cpg: Cpg): Dereference = cpg.metaData.language.headOption match { - case Some(Languages.NEWC) => CDereference() - case _ => DefaultDereference() - } + def apply(cpg: Cpg): Dereference = cpg.metaData.language.headOption match + case Some(Languages.NEWC) => CDereference() + case _ => DefaultDereference() -} +sealed trait Dereference: -sealed trait Dereference { + def dereferenceTypeFullName(fullName: String): String - def dereferenceTypeFullName(fullName: String): String +case class CDereference() extends Dereference: -} + /** Types from C/C++ can be annotated with * to indicate being a reference. As our CPG schema + * currently lacks a separate field for that information the * is part of the type full name + * and needs to be removed when linking. + */ + override def dereferenceTypeFullName(fullName: String): String = fullName.replace("*", "") -case class CDereference() extends Dereference { +case class DefaultDereference() extends Dereference: - /** Types from C/C++ can be annotated with * to indicate being a reference. As our CPG schema currently lacks a - * separate field for that information the * is part of the type full name and needs to be removed when linking. 
- */ - override def dereferenceTypeFullName(fullName: String): String = fullName.replace("*", "") - -} - -case class DefaultDereference() extends Dereference { - - override def dereferenceTypeFullName(fullName: String): String = fullName - -} + override def dereferenceTypeFullName(fullName: String): String = fullName diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala index 3db94fef..7ba7a831 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala @@ -6,40 +6,32 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewMetaData, NewNamespace import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, NamespaceTraversal} -/** A pass that creates a MetaData node, specifying that this is a CPG for language, and a NamespaceBlock for anything - * that cannot be assigned to any other namespace. +/** A pass that creates a MetaData node, specifying that this is a CPG for language, and a + * NamespaceBlock for anything that cannot be assigned to any other namespace. 
*/ -class MetaDataPass(cpg: Cpg, language: String, root: String) extends CpgPass(cpg) { - override def run(diffGraph: DiffGraphBuilder): Unit = { - def addMetaDataNode(diffGraph: DiffGraphBuilder): Unit = { - val absolutePathToRoot = File(root).path.toAbsolutePath.toString - val metaNode = NewMetaData().language(language).root(absolutePathToRoot).version("0.1") - diffGraph.addNode(metaNode) - } +class MetaDataPass(cpg: Cpg, language: String, root: String) extends CpgPass(cpg): + override def run(diffGraph: DiffGraphBuilder): Unit = + def addMetaDataNode(diffGraph: DiffGraphBuilder): Unit = + val absolutePathToRoot = File(root).path.toAbsolutePath.toString + val metaNode = NewMetaData().language(language).root(absolutePathToRoot).version("0.1") + diffGraph.addNode(metaNode) - def addAnyNamespaceBlock(diffGraph: DiffGraphBuilder): Unit = { - val node = NewNamespaceBlock() - .name(NamespaceTraversal.globalNamespaceName) - .fullName(MetaDataPass.getGlobalNamespaceBlockFullName(None)) - .filename(FileTraversal.UNKNOWN) - .order(1) - diffGraph.addNode(node) - } + def addAnyNamespaceBlock(diffGraph: DiffGraphBuilder): Unit = + val node = NewNamespaceBlock() + .name(NamespaceTraversal.globalNamespaceName) + .fullName(MetaDataPass.getGlobalNamespaceBlockFullName(None)) + .filename(FileTraversal.UNKNOWN) + .order(1) + diffGraph.addNode(node) - addMetaDataNode(diffGraph) - addAnyNamespaceBlock(diffGraph) - } -} + addMetaDataNode(diffGraph) + addAnyNamespaceBlock(diffGraph) -object MetaDataPass { +object MetaDataPass: - def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = { - fileNameOption match { - case Some(fileName) => - s"$fileName:${NamespaceTraversal.globalNamespaceName}" - case None => - NamespaceTraversal.globalNamespaceName - } - } - -} + def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = + fileNameOption match + case Some(fileName) => + s"$fileName:${NamespaceTraversal.globalNamespaceName}" + case None => + 
NamespaceTraversal.globalNamespaceName diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala index 7c92db59..a279ec5f 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala @@ -9,42 +9,43 @@ import scala.collection.concurrent.TrieMap /** Represents an identifier of some AST node at a specific scope. */ -abstract class SBKey(val identifier: String) { - - /** Convenience methods to convert a node to a [[SBKey]]. - * - * @param node - * the node to convert. - * @return - * the corresponding [[SBKey]] if the node is supported as a key variable - */ - def fromNode(node: AstNode): Option[SBKey] - -} - -object SBKey { - protected val logger: Logger = LoggerFactory.getLogger(getClass) - def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = { - Option(node match { - case n: Identifier => LocalVar(n.name) - case n: Local => LocalVar(n.name) - case n: Call => - CallAlias(n.name, n.argument.collectFirst { case x: Identifier if x.argumentIndex == 0 => x.name }) - case n: Method => CallAlias(n.name, Option("this")) - case n: MethodRef => CallAlias(n.code) - case n: FieldIdentifier => LocalVar(n.canonicalName) - case n: MethodParameterIn => LocalVar(n.name) - case _ => logger.debug(s"Local node of type ${node.label} is not supported in the type recovery pass."); null - }) - } - -} +abstract class SBKey(val identifier: String): + + /** Convenience methods to convert a node to a [[SBKey]]. + * + * @param node + * the node to convert. 
+ * @return + * the corresponding [[SBKey]] if the node is supported as a key variable + */ + def fromNode(node: AstNode): Option[SBKey] + +object SBKey: + protected val logger: Logger = LoggerFactory.getLogger(getClass) + def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = + Option(node match + case n: Identifier => LocalVar(n.name) + case n: Local => LocalVar(n.name) + case n: Call => + CallAlias( + n.name, + n.argument.collectFirst { case x: Identifier if x.argumentIndex == 0 => x.name } + ) + case n: Method => CallAlias(n.name, Option("this")) + case n: MethodRef => CallAlias(n.code) + case n: FieldIdentifier => LocalVar(n.canonicalName) + case n: MethodParameterIn => LocalVar(n.name) + case _ => + logger.debug( + s"Local node of type ${node.label} is not supported in the type recovery pass." + ); null + ) +end SBKey /** Represents an identifier of some AST node at an intraprocedural scope. */ -sealed class LocalKey(identifier: String) extends SBKey(identifier) { - override def fromNode(node: AstNode): Option[SBKey] = SBKey.fromNodeToLocalKey(node) -} +sealed class LocalKey(identifier: String) extends SBKey(identifier): + override def fromNode(node: AstNode): Option[SBKey] = SBKey.fromNodeToLocalKey(node) /** A variable that holds data within an intraprocedural scope. */ @@ -56,102 +57,93 @@ case class CollectionVar(override val identifier: String, idx: String) extends L /** A name that refers to some kind of callee. */ -case class CallAlias(override val identifier: String, receiverName: Option[String] = None) extends LocalKey(identifier) +case class CallAlias(override val identifier: String, receiverName: Option[String] = None) + extends LocalKey(identifier) -/** A thread-safe symbol table that can represent multiple types per symbol. Each node in an AST gets converted to an - * [[SBKey]] which gives contextual information to identify an AST entity. Each value in this table represents a set of - * types that the key could be in a flow-insensitive manner. 
+/** A thread-safe symbol table that can represent multiple types per symbol. Each node in an AST + * gets converted to an [[SBKey]] which gives contextual information to identify an AST entity. + * Each value in this table represents a set of types that the key could be in a flow-insensitive + * manner. * - * The [[SymbolTable]] operates like a map with a few convenient methods that are designed for this structure's - * purpose. + * The [[SymbolTable]] operates like a map with a few convenient methods that are designed for this + * structure's purpose. */ -class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]) { +class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]): - private val table = TrieMap.empty[K, Set[String]] + private val table = TrieMap.empty[K, Set[String]] - /** The set limit is to bound the set of possible types, since by using dummy types we could have an unbounded number - * of permutations of various access paths - */ - private val setLimit = 10 + /** The set limit is to bound the set of possible types, since by using dummy types we could + * have an unbounded number of permutations of various access paths + */ + private val setLimit = 10 - private def coalesce(oldEntries: Set[String], newEntries: Set[String]): Set[String] = { - val allTypes = (oldEntries ++ newEntries).toSeq // convert to ordered set to make `take` work predictably - val (dummies, noDummies) = allTypes.partition(XTypeRecovery.isDummyType) - (noDummies ++ dummies).take(setLimit).toSet - } + private def coalesce(oldEntries: Set[String], newEntries: Set[String]): Set[String] = + val allTypes = + (oldEntries ++ newEntries).toSeq // convert to ordered set to make `take` work predictably + val (dummies, noDummies) = allTypes.partition(XTypeRecovery.isDummyType) + (noDummies ++ dummies).take(setLimit).toSet - def apply(sbKey: K): Set[String] = table(sbKey) + def apply(sbKey: K): Set[String] = table(sbKey) - def apply(node: AstNode): Set[String] = - 
keyFromNode(node) match { - case Some(key) => table(key) - case None => Set.empty - } + def apply(node: AstNode): Set[String] = + keyFromNode(node) match + case Some(key) => table(key) + case None => Set.empty - def from(sb: IterableOnce[(K, Set[String])]): SymbolTable[K] = { - table.addAll(sb); this - } + def from(sb: IterableOnce[(K, Set[String])]): SymbolTable[K] = + table.addAll(sb); this - def put(sbKey: K, typeFullNames: Set[String]): Set[String] = - if (typeFullNames.nonEmpty) { - val newEntry = coalesce(Set.empty, typeFullNames) - table.put(sbKey, newEntry) - newEntry - } else { - Set.empty - } + def put(sbKey: K, typeFullNames: Set[String]): Set[String] = + if typeFullNames.nonEmpty then + val newEntry = coalesce(Set.empty, typeFullNames) + table.put(sbKey, newEntry) + newEntry + else + Set.empty - def put(sbKey: K, typeFullName: String): Set[String] = - put(sbKey, Set(typeFullName)) + def put(sbKey: K, typeFullName: String): Set[String] = + put(sbKey, Set(typeFullName)) - def put(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match { - case Some(key) => put(key, typeFullNames) - case None => Set.empty - } + def put(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match + case Some(key) => put(key, typeFullNames) + case None => Set.empty - def append(node: AstNode, typeFullName: String): Set[String] = - append(node, Set(typeFullName)) + def append(node: AstNode, typeFullName: String): Set[String] = + append(node, Set(typeFullName)) - def append(node: K, typeFullName: String): Set[String] = - append(node, Set(typeFullName)) + def append(node: K, typeFullName: String): Set[String] = + append(node, Set(typeFullName)) - def append(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match { - case Some(key) => append(key, typeFullNames) - case None => Set.empty - } + def append(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match + case Some(key) => 
append(key, typeFullNames) + case None => Set.empty - def append(sbKey: K, typeFullNames: Set[String]): Set[String] = { - table.get(sbKey) match { - case Some(ts) if ts == typeFullNames => ts - case Some(ts) if typeFullNames.nonEmpty => put(sbKey, coalesce(ts, typeFullNames)) - case None if typeFullNames.nonEmpty => put(sbKey, coalesce(Set.empty, typeFullNames)) - case _ => Set.empty - } - } + def append(sbKey: K, typeFullNames: Set[String]): Set[String] = + table.get(sbKey) match + case Some(ts) if ts == typeFullNames => ts + case Some(ts) if typeFullNames.nonEmpty => put(sbKey, coalesce(ts, typeFullNames)) + case None if typeFullNames.nonEmpty => put(sbKey, coalesce(Set.empty, typeFullNames)) + case _ => Set.empty - def contains(sbKey: K): Boolean = table.contains(sbKey) + def contains(sbKey: K): Boolean = table.contains(sbKey) - def contains(node: AstNode): Boolean = keyFromNode(node) match { - case Some(key) => contains(key) - case None => false - } + def contains(node: AstNode): Boolean = keyFromNode(node) match + case Some(key) => contains(key) + case None => false - def get(sbKey: K): Set[String] = table.getOrElse(sbKey, Set.empty) + def get(sbKey: K): Set[String] = table.getOrElse(sbKey, Set.empty) - def get(node: AstNode): Set[String] = keyFromNode(node) match { - case Some(key) => get(key) - case None => Set.empty - } + def get(node: AstNode): Set[String] = keyFromNode(node) match + case Some(key) => get(key) + case None => Set.empty - def remove(sbKey: K): Set[String] = table.remove(sbKey).getOrElse(Set.empty) + def remove(sbKey: K): Set[String] = table.remove(sbKey).getOrElse(Set.empty) - def remove(node: AstNode): Set[String] = keyFromNode(node) match { - case Some(key) => remove(key) - case None => Set.empty - } + def remove(node: AstNode): Set[String] = keyFromNode(node) match + case Some(key) => remove(key) + case None => Set.empty - def view: MapView[K, Set[String]] = table.view + def view: MapView[K, Set[String]] = table.view - def clear(): Unit = 
table.clear() - -} + def clear(): Unit = table.clear() +end SymbolTable diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala index dad381eb..4e060c2c 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala @@ -4,85 +4,91 @@ import io.appthreat.x2cpg.passes.frontend.TypeNodePass.fullToShortName import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.NewType import io.shiftleft.passes.{KeyPool, CpgPass} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import scala.collection.mutable -/** Creates a `TYPE` node for each type in `usedTypes` as well as all inheritsFrom type names in the CPG +/** Creates a `TYPE` node for each type in `usedTypes` as well as all inheritsFrom type names in the + * CPG * - * Alternatively, set `getTypesFromCpg = true`. If this is set, the `registeredTypes` argument will be ignored. - * Instead, type nodes will be created for every unique `TYPE_FULL_NAME` value in the CPG. + * Alternatively, set `getTypesFromCpg = true`. If this is set, the `registeredTypes` argument will + * be ignored. Instead, type nodes will be created for every unique `TYPE_FULL_NAME` value in the + * CPG. 
*/ -class TypeNodePass private (registeredTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool], getTypesFromCpg: Boolean) - extends CpgPass(cpg, "types", keyPool) { +class TypeNodePass private ( + registeredTypes: List[String], + cpg: Cpg, + keyPool: Option[KeyPool], + getTypesFromCpg: Boolean +) extends CpgPass(cpg, "types", keyPool): - private def getTypeDeclTypes(): mutable.Set[String] = { - val typeDeclTypes = mutable.Set[String]() - cpg.typeDecl.foreach { typeDecl => - typeDeclTypes += typeDecl.fullName - typeDeclTypes ++= typeDecl.inheritsFromTypeFullName - } - typeDeclTypes - } + private def getTypeDeclTypes(): mutable.Set[String] = + val typeDeclTypes = mutable.Set[String]() + cpg.typeDecl.foreach { typeDecl => + typeDeclTypes += typeDecl.fullName + typeDeclTypes ++= typeDecl.inheritsFromTypeFullName + } + typeDeclTypes - def getTypeFullNamesFromCpg(): Set[String] = { - cpg.all - .map(_.property(PropertyNames.TYPE_FULL_NAME)) - .filter(_ != null) - .map(_.toString) - .toSet - } + def getTypeFullNamesFromCpg(): Set[String] = + cpg.all + .map(_.property(PropertyNames.TYPE_FULL_NAME)) + .filter(_ != null) + .map(_.toString) + .toSet - override def run(diffGraph: DiffGraphBuilder): Unit = { - val typeFullNameValues = - if (getTypesFromCpg) - getTypeFullNamesFromCpg() - else - registeredTypes.toSet + override def run(diffGraph: DiffGraphBuilder): Unit = + val typeFullNameValues = + if getTypesFromCpg then + getTypeFullNamesFromCpg() + else + registeredTypes.toSet - val usedTypesSet = getTypeDeclTypes() ++ typeFullNameValues - usedTypesSet.remove("") - val usedTypes = usedTypesSet.filterInPlace(!_.endsWith(NamespaceTraversal.globalNamespaceName)).toArray.sorted + val usedTypesSet = getTypeDeclTypes() ++ typeFullNameValues + usedTypesSet.remove("") + val usedTypes = usedTypesSet.filterInPlace( + !_.endsWith(NamespaceTraversal.globalNamespaceName) + ).toArray.sorted - diffGraph.addNode( - NewType() - .name("ANY") - .fullName("ANY") - .typeDeclFullName("ANY") - ) 
+ diffGraph.addNode( + NewType() + .name("ANY") + .fullName("ANY") + .typeDeclFullName("ANY") + ) - usedTypes.foreach { typeName => - val shortName = fullToShortName(typeName) - val node = NewType() - .name(shortName) - .fullName(typeName) - .typeDeclFullName(typeName) - diffGraph.addNode(node) - } - } -} + usedTypes.foreach { typeName => + val shortName = fullToShortName(typeName) + val node = NewType() + .name(shortName) + .fullName(typeName) + .typeDeclFullName(typeName) + diffGraph.addNode(node) + } + end run +end TypeNodePass -object TypeNodePass { - // Lambda typeDecl type names fit the structure - // `a.b.c.d.ClassName.lambda$method$name:returnType(paramTypes)` - // so this regex works by greedily matching the package and class names - // at the start and cutting off the matched group before the signature. - private val lambdaTypeRegex = raw".*\.(.*):.*\(.*\)".r +object TypeNodePass: + // Lambda typeDecl type names fit the structure + // `a.b.c.d.ClassName.lambda$method$name:returnType(paramTypes)` + // so this regex works by greedily matching the package and class names + // at the start and cutting off the matched group before the signature. 
+ private val lambdaTypeRegex = raw".*\.(.*):.*\(.*\)".r - def withTypesFromCpg(cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(Nil, cpg, keyPool, getTypesFromCpg = true) - } + def withTypesFromCpg(cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = + new TypeNodePass(Nil, cpg, keyPool, getTypesFromCpg = true) - def withRegisteredTypes(registeredTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = { - new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) - } + def withRegisteredTypes( + registeredTypes: List[String], + cpg: Cpg, + keyPool: Option[KeyPool] = None + ): TypeNodePass = + new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) - def fullToShortName(typeName: String): String = { - typeName match { - case lambdaTypeRegex(methodName) => methodName - case _ => typeName.split('.').lastOption.getOrElse(typeName) - } - } -} + def fullToShortName(typeName: String): String = + typeName match + case lambdaTypeRegex(methodName) => methodName + case _ => typeName.split('.').lastOption.getOrElse(typeName) +end TypeNodePass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala index d64b7b46..6568f7c6 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala @@ -11,69 +11,60 @@ import org.slf4j.LoggerFactory import java.nio.file.Paths import scala.util.{Failure, Success, Try} -/** Scans for and inserts configuration files into the CPG. Relies on the MetaData's `ROOT` property to provide the path - * to scan. +/** Scans for and inserts configuration files into the CPG. 
Relies on the MetaData's `ROOT` property + * to provide the path to scan. */ -abstract class XConfigFileCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[File](cpg) { - - private val logger = LoggerFactory.getLogger(this.getClass) - - // File filters to override by the implementing class - protected val configFileFilters: List[File => Boolean] - - private val rootDir = cpg.metaData.root.headOption.getOrElse("") - - override def generateParts(): Array[File] = - if (rootDir.isBlank) { - logger.debug("Unable to recover project directory for configuration file pass.") - Array.empty - } else { - Try(File(rootDir)) match { - case Success(file) if file.isDirectory => - file.listRecursively - .filter(isConfigFile) - .toArray - - case Success(file) if isConfigFile(file) => - Array(file) - - case _ => Array.empty - } - } - - override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = { - Try(IOUtils.readEntireFile(file.path)) match { - case Success(content) => - val name = configFileName(file) - val configNode = NewConfigFile().name(name).content(content) - logger.debug(s"Adding config file $name") - diffGraph.addNode(configNode) - - case Failure(error) => - logger.debug(s"Unable to create config file node for ${file.canonicalPath}: $error") - } - } - - private def configFileName(configFile: File): String = { - Try(Paths.get(rootDir).toAbsolutePath) - .map(_.relativize(configFile.path.toAbsolutePath).toString) - .orElse(Try(configFile.pathAsString)) - .getOrElse(configFile.name) - } - - protected def extensionFilter(extension: String)(file: File): Boolean = { - file.extension.contains(extension) - } - - protected def pathEndFilter(pathEnd: String)(file: File): Boolean = { - file.canonicalPath.endsWith(pathEnd) - } - - protected def pathRegexFilter(pathRegex: String)(file: File): Boolean = { - file.canonicalPath.matches(pathRegex) - } - - private def isConfigFile(file: File): Boolean = { - configFileFilters.exists(predicate => predicate(file)) - } -} 
+abstract class XConfigFileCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[File](cpg): + + private val logger = LoggerFactory.getLogger(this.getClass) + + // File filters to override by the implementing class + protected val configFileFilters: List[File => Boolean] + + private val rootDir = cpg.metaData.root.headOption.getOrElse("") + + override def generateParts(): Array[File] = + if rootDir.isBlank then + logger.debug("Unable to recover project directory for configuration file pass.") + Array.empty + else + Try(File(rootDir)) match + case Success(file) if file.isDirectory => + file.listRecursively + .filter(isConfigFile) + .toArray + + case Success(file) if isConfigFile(file) => + Array(file) + + case _ => Array.empty + + override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = + Try(IOUtils.readEntireFile(file.path)) match + case Success(content) => + val name = configFileName(file) + val configNode = NewConfigFile().name(name).content(content) + logger.debug(s"Adding config file $name") + diffGraph.addNode(configNode) + + case Failure(error) => + logger.debug(s"Unable to create config file node for ${file.canonicalPath}: $error") + + private def configFileName(configFile: File): String = + Try(Paths.get(rootDir).toAbsolutePath) + .map(_.relativize(configFile.path.toAbsolutePath).toString) + .orElse(Try(configFile.pathAsString)) + .getOrElse(configFile.name) + + protected def extensionFilter(extension: String)(file: File): Boolean = + file.extension.contains(extension) + + protected def pathEndFilter(pathEnd: String)(file: File): Boolean = + file.canonicalPath.endsWith(pathEnd) + + protected def pathRegexFilter(pathRegex: String)(file: File): Boolean = + file.canonicalPath.matches(pathRegex) + + private def isConfigFile(file: File): Boolean = + configFileFilters.exists(predicate => predicate(file)) +end XConfigFileCreationPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala 
b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala index d7fe2032..d803ca86 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala @@ -10,118 +10,121 @@ import org.slf4j.{Logger, LoggerFactory} import java.io.File as JFile -abstract class XImportResolverPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Import](cpg) { - - protected val logger: Logger = LoggerFactory.getLogger(this.getClass) - protected val codeRoot: String = cpg.metaData.root.headOption.getOrElse(JFile.separator) - - override def generateParts(): Array[Import] = cpg.imports.toArray - - override def runOnPart(builder: DiffGraphBuilder, part: Import): Unit = for { - call <- part.call - fileName = call.file.name.headOption.getOrElse("").stripPrefix(codeRoot) - importedAs <- part.importedAs - importedEntity <- part.importedEntity - } { - optionalResolveImport(fileName, call, importedEntity, importedAs, builder) - } - - protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit - - protected def resolvedImportToTag(x: ResolvedImport, importCall: Call, diffGraph: DiffGraphBuilder): Unit = - importCall.start.newTagNodePair(x.label, x.serialize).store()(diffGraph) - -} - -object ImportsPass { - - sealed trait ResolvedImport { - def label: String - - def serialize: String - } - - implicit class TagToResolvedImportExt(traversal: Iterator[Tag]) { - def toResolvedImport: Iterator[ResolvedImport] = - traversal.flatMap(ResolvedImport.tagToResolvedImport) - } - - object ResolvedImport { - - val RESOLVED_METHOD = "RESOLVED_METHOD" - val RESOLVED_TYPE_DECL = "RESOLVED_TYPE_DECL" - val RESOLVED_MEMBER = "RESOLVED_MEMBER" - val UNKNOWN_METHOD = "UNKNOWN_METHOD" - val UNKNOWN_TYPE_DECL = 
"UNKNOWN_TYPE_DECL" - val UNKNOWN_IMPORT = "UNKNOWN_IMPORT" - - val OPT_FULL_NAME = "FULL_NAME" - val OPT_ALIAS = "ALIAS" - val OPT_RECEIVER = "RECEIVER" - val OPT_BASE_PATH = "BASE_PATH" - val OPT_NAME = "NAME" - - def tagToResolvedImport(tag: Tag): Option[ResolvedImport] = Option(tag.name match { - case RESOLVED_METHOD => - val opts = valueToOptions(tag.value) - ResolvedMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value) - case RESOLVED_MEMBER => - val opts = valueToOptions(tag.value) - ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME)) - case UNKNOWN_METHOD => - val opts = valueToOptions(tag.value) - UnknownMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value) - case UNKNOWN_IMPORT => UnknownImport(tag.value) - case _ => null - }) - - private def valueToOptions(x: String): Map[String, String] = - x.split(',').grouped(2).map(xs => xs(0) -> xs(1)).toMap - } - - case class ResolvedMethod( - fullName: String, - alias: String, - receiver: Option[String] = None, - override val label: String = RESOLVED_METHOD - ) extends ResolvedImport { - override def serialize: String = - s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => s",$OPT_RECEIVER,$r").getOrElse("") - } - - case class ResolvedTypeDecl(fullName: String, override val label: String = RESOLVED_TYPE_DECL) - extends ResolvedImport { - override def serialize: String = fullName - } - - case class ResolvedMember(basePath: String, memberName: String, override val label: String = RESOLVED_MEMBER) - extends ResolvedImport { - override def serialize: String = s"$OPT_BASE_PATH,$basePath,$OPT_NAME,$memberName" - } - - case class UnknownMethod( - fullName: String, - alias: String, - receiver: Option[String] = None, - override val label: String = UNKNOWN_METHOD - ) extends ResolvedImport { - override def serialize: String = - 
s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => s",$OPT_RECEIVER,$r").getOrElse("") - } - - case class UnknownTypeDecl(fullName: String, override val label: String = UNKNOWN_TYPE_DECL) extends ResolvedImport { - override def serialize: String = fullName - } - - case class UnknownImport(path: String, override val label: String = UNKNOWN_IMPORT) extends ResolvedImport { - override def serialize: String = path - } -} +abstract class XImportResolverPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Import](cpg): + + protected val logger: Logger = LoggerFactory.getLogger(this.getClass) + protected val codeRoot: String = cpg.metaData.root.headOption.getOrElse(JFile.separator) + + override def generateParts(): Array[Import] = cpg.imports.toArray + + override def runOnPart(builder: DiffGraphBuilder, part: Import): Unit = for + call <- part.call + fileName = call.file.name.headOption.getOrElse("").stripPrefix(codeRoot) + importedAs <- part.importedAs + importedEntity <- part.importedEntity + do + optionalResolveImport(fileName, call, importedEntity, importedAs, builder) + + protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit + + protected def resolvedImportToTag( + x: ResolvedImport, + importCall: Call, + diffGraph: DiffGraphBuilder + ): Unit = + importCall.start.newTagNodePair(x.label, x.serialize).store()(diffGraph) +end XImportResolverPass + +object ImportsPass: + + sealed trait ResolvedImport: + def label: String + + def serialize: String + + implicit class TagToResolvedImportExt(traversal: Iterator[Tag]): + def toResolvedImport: Iterator[ResolvedImport] = + traversal.flatMap(ResolvedImport.tagToResolvedImport) + + object ResolvedImport: + + val RESOLVED_METHOD = "RESOLVED_METHOD" + val RESOLVED_TYPE_DECL = "RESOLVED_TYPE_DECL" + val RESOLVED_MEMBER = "RESOLVED_MEMBER" + val UNKNOWN_METHOD = "UNKNOWN_METHOD" + val UNKNOWN_TYPE_DECL = 
"UNKNOWN_TYPE_DECL" + val UNKNOWN_IMPORT = "UNKNOWN_IMPORT" + + val OPT_FULL_NAME = "FULL_NAME" + val OPT_ALIAS = "ALIAS" + val OPT_RECEIVER = "RECEIVER" + val OPT_BASE_PATH = "BASE_PATH" + val OPT_NAME = "NAME" + + def tagToResolvedImport(tag: Tag): Option[ResolvedImport] = Option(tag.name match + case RESOLVED_METHOD => + val opts = valueToOptions(tag.value) + ResolvedMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) + case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value) + case RESOLVED_MEMBER => + val opts = valueToOptions(tag.value) + ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME)) + case UNKNOWN_METHOD => + val opts = valueToOptions(tag.value) + UnknownMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) + case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value) + case UNKNOWN_IMPORT => UnknownImport(tag.value) + case _ => null + ) + + private def valueToOptions(x: String): Map[String, String] = + x.split(',').grouped(2).map(xs => xs(0) -> xs(1)).toMap + end ResolvedImport + + case class ResolvedMethod( + fullName: String, + alias: String, + receiver: Option[String] = None, + override val label: String = RESOLVED_METHOD + ) extends ResolvedImport: + override def serialize: String = + s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => + s",$OPT_RECEIVER,$r" + ).getOrElse("") + + case class ResolvedTypeDecl(fullName: String, override val label: String = RESOLVED_TYPE_DECL) + extends ResolvedImport: + override def serialize: String = fullName + + case class ResolvedMember( + basePath: String, + memberName: String, + override val label: String = RESOLVED_MEMBER + ) extends ResolvedImport: + override def serialize: String = s"$OPT_BASE_PATH,$basePath,$OPT_NAME,$memberName" + + case class UnknownMethod( + fullName: String, + alias: String, + receiver: Option[String] = None, + override val label: String = UNKNOWN_METHOD + ) extends ResolvedImport: + override def serialize: String = + 
s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => + s",$OPT_RECEIVER,$r" + ).getOrElse("") + + case class UnknownTypeDecl(fullName: String, override val label: String = UNKNOWN_TYPE_DECL) + extends ResolvedImport: + override def serialize: String = fullName + + case class UnknownImport(path: String, override val label: String = UNKNOWN_IMPORT) + extends ResolvedImport: + override def serialize: String = path +end ImportsPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala index 97f1daf0..fdf7fe14 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala @@ -4,27 +4,24 @@ import io.appthreat.x2cpg.Imports.createImportNodeAndLink import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.passes.ConcurrentWriterCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment -abstract class XImportsPass(cpg: Cpg) extends ConcurrentWriterCpgPass[(Call, Assignment)](cpg) { +abstract class XImportsPass(cpg: Cpg) extends ConcurrentWriterCpgPass[(Call, Assignment)](cpg): - protected val importCallName: String + protected val importCallName: String - override def generateParts(): Array[(Call, Assignment)] = cpg - .call(importCallName) - .flatMap(importCallToPart) - .toArray + override def generateParts(): Array[(Call, Assignment)] = cpg + .call(importCallName) + .flatMap(importCallToPart) + .toArray - protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] + protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] - override def runOnPart(diffGraph: 
DiffGraphBuilder, part: (Call, Assignment)): Unit = { - val (call, assignment) = part - val importedEntity = importedEntityFromCall(call) - val importedAs = assignment.target.code - createImportNodeAndLink(importedEntity, importedAs, Some(call), diffGraph) - } + override def runOnPart(diffGraph: DiffGraphBuilder, part: (Call, Assignment)): Unit = + val (call, assignment) = part + val importedEntity = importedEntityFromCall(call) + val importedAs = assignment.target.code + createImportNodeAndLink(importedEntity, importedAs, Some(call), diffGraph) - protected def importedEntityFromCall(call: Call): String - -} + protected def importedEntityFromCall(call: Call): String diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala index fc9e3552..79e338e2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala @@ -2,141 +2,151 @@ package io.appthreat.x2cpg.passes.frontend import io.appthreat.x2cpg.passes.base.TypeDeclStubCreator import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, PropertyNames} import io.shiftleft.passes.ForkJoinParallelCpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import java.io.File import java.nio.file.Paths import java.util.regex.{Matcher, Pattern} -/** Using some basic heuristics, will try to resolve type full names from types found within the CPG. Requires - * ImportPass as a pre-requisite. +/** Using some basic heuristics, will try to resolve type full names from types found within the + * CPG. 
Requires ImportPass as a pre-requisite. */ -abstract class XInheritanceFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[TypeDecl](cpg) { - - protected val pathSep: Char = '.' - protected val fileModuleSep: Char = ':' - protected val moduleName: String - protected val fileExt: String - - private val relativePathPattern = Pattern.compile("^[.]+/?.*") - - override def generateParts(): Array[TypeDecl] = - cpg.typeDecl - .filterNot(t => inheritsNothingOfInterest(t.inheritsFromTypeFullName)) - .toArray - - override def runOnPart(builder: DiffGraphBuilder, source: TypeDecl): Unit = { - val resolvedTypeDecls = resolveInheritedTypeFullName(source, builder) - if (resolvedTypeDecls.nonEmpty) { - val fullNames = resolvedTypeDecls.map(_.fullName) - builder.setNodeProperty(source, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, fullNames) - cpg.typ.fullNameExact(fullNames: _*).foreach(tgt => builder.addEdge(source, tgt, EdgeTypes.INHERITS_FROM)) - } - } - - protected def inheritsNothingOfInterest(inheritedTypes: Seq[String]): Boolean = - inheritedTypes == Seq("ANY") || inheritedTypes == Seq("object") || inheritedTypes.isEmpty - - private def extractTypeDeclFromNode(node: AstNode): Option[String] = node match { - case x: Call if x.isCallForImportOut.nonEmpty => - x.isCallForImportOut.importedEntity.map { - case imp if relativePathPattern.matcher(imp).matches() => - imp.split(pathSep).toList match { - case head :: next => - (Paths.get(head).normalize().toString +: next).mkString(pathSep.toString) - case Nil => - Paths.get(imp).normalize().toString - } - case imp => imp - }.headOption - case x: TypeDecl => Option(x.fullName) - case _ => None - } - - protected def resolveInheritedTypeFullName(td: TypeDecl, builder: DiffGraphBuilder): Seq[TypeDeclBase] = { - val callsOfInterest = td.file.method.flatMap(_._callViaContainsOut) - val typeDeclsOfInterest = td.file.typeDecl - val qualifiedNamesInScope = (callsOfInterest ++ typeDeclsOfInterest) - .flatMap(extractTypeDeclFromNode) - 
.filterNot(_.endsWith(moduleName)) - .l - val matchersInScope = qualifiedNamesInScope.map { - case x if x.contains(pathSep) => - val splitName = x.split(pathSep) - splitName.last - case x => x - }.distinct - val validTypeDecls = cpg.typeDecl.nameExact(matchersInScope: _*).toList - val filteredTypes = validTypeDecls.filter(vt => td.inheritsFromTypeFullName.contains(vt.name)) - if (filteredTypes.isEmpty) { - // Usually in the case of inheriting external types - qualifiedNamesInScope - .flatMap { qn => - td.inheritsFromTypeFullName.find(t => ImportStringHandling.namesIntersect(qn, t, pathSep)) match { - case Some(path) => Option((qn, path)) - case None => None - } - } - .map { case (qualifiedName, inheritedNames) => xTypeFullName(qualifiedName, inheritedNames) } - .map { case (name, fullName) => createTypeStub(name, fullName, builder) } - } else { - filteredTypes - } - } - - private def createTypeStub(name: String, fullName: String, builder: DiffGraphBuilder): TypeDeclBase = - cpg.typeDecl.fullNameExact(fullName).headOption match { - case Some(typeDecl) => typeDecl - case None => - val typeDecl = TypeDeclStubCreator.createTypeDeclStub(name, fullName) - builder.addNode(typeDecl) - typeDecl - } - - /** Converts types in the form `foo.bar.Baz` to `foo/bar.js::program:Baz` and will result in a tuple of name and full - * name. 
- */ - protected def xTypeFullName(importedType: String, importedPath: String): (String, String) = { - val combinedPath = ImportStringHandling.combinedPath(importedType, importedPath, pathSep) - combinedPath.split(pathSep).lastOption match { - case Some(tName) => - ( - tName, - combinedPath - .stripSuffix(s"$pathSep$tName") - .replaceAll(s"${Pattern.quote(pathSep.toString)}", Matcher.quoteReplacement(File.separator)) + - Seq(s"$fileExt$fileModuleSep$moduleName", tName).mkString(pathSep.toString) - ) - case None => (combinedPath, combinedPath) - } - - } -} - -object ImportStringHandling { - def namesIntersect(a: String, b: String, pathSep: Char = '.'): Boolean = { - val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) - intersect.nonEmpty && (as.endsWith(intersect) || bs.endsWith(intersect)) - } - - private def splitAndIntersect(a: String, b: String, pathSep: Char = '.'): (Seq[String], Seq[String], Seq[String]) = { - val as = a.split(pathSep).toIndexedSeq - val bs = b.split(pathSep).toIndexedSeq - (as, bs, as.intersect(bs)) - } - - def combinedPath(importedType: String, importedPath: String, pathSep: Char = '.'): String = { - val (a, b) = - if (importedType.length > importedPath.length) - (importedType, importedPath) - else - (importedPath, importedType) - val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) - - if (a == importedPath) bs.diff(intersect).concat(as).mkString(pathSep.toString) - else as.diff(intersect).concat(bs).mkString(pathSep.toString) - } -} +abstract class XInheritanceFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[TypeDecl](cpg): + + protected val pathSep: Char = '.' 
+ protected val fileModuleSep: Char = ':' + protected val moduleName: String + protected val fileExt: String + + private val relativePathPattern = Pattern.compile("^[.]+/?.*") + + override def generateParts(): Array[TypeDecl] = + cpg.typeDecl + .filterNot(t => inheritsNothingOfInterest(t.inheritsFromTypeFullName)) + .toArray + + override def runOnPart(builder: DiffGraphBuilder, source: TypeDecl): Unit = + val resolvedTypeDecls = resolveInheritedTypeFullName(source, builder) + if resolvedTypeDecls.nonEmpty then + val fullNames = resolvedTypeDecls.map(_.fullName) + builder.setNodeProperty(source, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, fullNames) + cpg.typ.fullNameExact(fullNames*).foreach(tgt => + builder.addEdge(source, tgt, EdgeTypes.INHERITS_FROM) + ) + + protected def inheritsNothingOfInterest(inheritedTypes: Seq[String]): Boolean = + inheritedTypes == Seq("ANY") || inheritedTypes == Seq("object") || inheritedTypes.isEmpty + + private def extractTypeDeclFromNode(node: AstNode): Option[String] = node match + case x: Call if x.isCallForImportOut.nonEmpty => + x.isCallForImportOut.importedEntity.map { + case imp if relativePathPattern.matcher(imp).matches() => + imp.split(pathSep).toList match + case head :: next => + (Paths.get(head).normalize().toString +: next).mkString( + pathSep.toString + ) + case Nil => + Paths.get(imp).normalize().toString + case imp => imp + }.headOption + case x: TypeDecl => Option(x.fullName) + case _ => None + + protected def resolveInheritedTypeFullName( + td: TypeDecl, + builder: DiffGraphBuilder + ): Seq[TypeDeclBase] = + val callsOfInterest = td.file.method.flatMap(_._callViaContainsOut) + val typeDeclsOfInterest = td.file.typeDecl + val qualifiedNamesInScope = (callsOfInterest ++ typeDeclsOfInterest) + .flatMap(extractTypeDeclFromNode) + .filterNot(_.endsWith(moduleName)) + .l + val matchersInScope = qualifiedNamesInScope.map { + case x if x.contains(pathSep) => + val splitName = x.split(pathSep) + splitName.last + case x => x 
+ }.distinct + val validTypeDecls = cpg.typeDecl.nameExact(matchersInScope*).toList + val filteredTypes = + validTypeDecls.filter(vt => td.inheritsFromTypeFullName.contains(vt.name)) + if filteredTypes.isEmpty then + // Usually in the case of inheriting external types + qualifiedNamesInScope + .flatMap { qn => + td.inheritsFromTypeFullName.find(t => + ImportStringHandling.namesIntersect(qn, t, pathSep) + ) match + case Some(path) => Option((qn, path)) + case None => None + } + .map { case (qualifiedName, inheritedNames) => + xTypeFullName(qualifiedName, inheritedNames) + } + .map { case (name, fullName) => createTypeStub(name, fullName, builder) } + else + filteredTypes + end resolveInheritedTypeFullName + + private def createTypeStub( + name: String, + fullName: String, + builder: DiffGraphBuilder + ): TypeDeclBase = + cpg.typeDecl.fullNameExact(fullName).headOption match + case Some(typeDecl) => typeDecl + case None => + val typeDecl = TypeDeclStubCreator.createTypeDeclStub(name, fullName) + builder.addNode(typeDecl) + typeDecl + + /** Converts types in the form `foo.bar.Baz` to `foo/bar.js::program:Baz` and will result in a + * tuple of name and full name. 
+ */ + protected def xTypeFullName(importedType: String, importedPath: String): (String, String) = + val combinedPath = ImportStringHandling.combinedPath(importedType, importedPath, pathSep) + combinedPath.split(pathSep).lastOption match + case Some(tName) => + ( + tName, + combinedPath + .stripSuffix(s"$pathSep$tName") + .replaceAll( + s"${Pattern.quote(pathSep.toString)}", + Matcher.quoteReplacement(File.separator) + ) + + Seq(s"$fileExt$fileModuleSep$moduleName", tName).mkString(pathSep.toString) + ) + case None => (combinedPath, combinedPath) +end XInheritanceFullNamePass + +object ImportStringHandling: + def namesIntersect(a: String, b: String, pathSep: Char = '.'): Boolean = + val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) + intersect.nonEmpty && (as.endsWith(intersect) || bs.endsWith(intersect)) + + private def splitAndIntersect( + a: String, + b: String, + pathSep: Char = '.' + ): (Seq[String], Seq[String], Seq[String]) = + val as = a.split(pathSep).toIndexedSeq + val bs = b.split(pathSep).toIndexedSeq + (as, bs, as.intersect(bs)) + + def combinedPath(importedType: String, importedPath: String, pathSep: Char = '.'): String = + val (a, b) = + if importedType.length > importedPath.length then + (importedType, importedPath) + else + (importedPath, importedType) + val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) + + if a == importedPath then bs.diff(intersect).concat(as).mkString(pathSep.toString) + else as.diff(intersect).concat(bs).mkString(pathSep.toString) +end ImportStringHandling diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala index edb31d7d..ea7fd4fc 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala 
@@ -13,171 +13,186 @@ import io.shiftleft.semanticcpg.language.NoResolve import java.util.regex.Pattern import scala.collection.mutable -/** Attempts to set the methodFullName and link to callees using the recovered type information from - * [[XTypeRecovery]]. Note that some methods may not be present as they could be external and have been dynamically - * discovered, thus the [[io.appthreat.x2cpg.passes.base.MethodStubCreator]] would have missed it. +/** Attempts to set the methodFullName and link to callees using the recovered type + * information from [[XTypeRecovery]]. Note that some methods may not be present as they could be + * external and have been dynamically discovered, thus the + * [[io.appthreat.x2cpg.passes.base.MethodStubCreator]] would have missed it. * * @param cpg * the target code property graph. */ -abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg) { - - implicit protected val resolver: NoResolve.type = NoResolve - private val fileNamePattern = Pattern.compile("^(.*(.py|.js|.rb)).*$") - protected val pathSep: Char = '.' 
- - protected def calls: Iterator[Call] = cpg.call - .nameNot(".*", ".*") - .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) - - protected def calleeNames(c: Call): Seq[String] = - c.dynamicTypeHintFullName.filterNot(_.equals("ANY")).distinct - - protected def callees(names: Seq[String]): List[Method] = cpg.method.fullNameExact(names: _*).toList - - override def run(builder: DiffGraphBuilder): Unit = linkCalls(builder) - - protected def linkCalls(builder: DiffGraphBuilder): Unit = { - val callerAndCallees = calls.map(call => (call, calleeNames(call))).toList - val methodMap = mapMethods(callerAndCallees, builder) - linkCallsToCallees(callerAndCallees, methodMap, builder) - linkSpeculativeNamespaceNodes(methodMap.view.values.collectAll[NewMethod], builder) - } - - protected def mapMethods( - callerAndCallees: List[(Call, Seq[String])], - builder: DiffGraphBuilder - ): Map[String, MethodBase] = { - val methodMap = mutable.HashMap.empty[String, MethodBase] - val newMethods = mutable.HashMap.empty[String, NewMethod] - callerAndCallees.foreach { case (call, methodNames) => - val ms = callees(methodNames) - if (ms.nonEmpty) { - ms.foreach { m => methodMap.put(m.fullName, m) } - } else { - val mNames = ms.map(_.fullName).toSet - methodNames - .filterNot(mNames.contains) - .map(fullName => newMethods.getOrElseUpdate(fullName, createMethodStub(fullName, call, builder))) - .foreach { m => methodMap.put(m.fullName, m) } - } - } - methodMap.toMap - } - - protected def linkCallsToCallees( - callerAndCallees: List[(Call, Seq[String])], - methodMap: Map[String, MethodBase], - builder: DiffGraphBuilder - ): Unit = { - // Link edges to method nodes - callerAndCallees.foreach { case (call, methodNames) => - methodNames - .flatMap(methodMap.get) - .filter(m => call.callee(NoResolve).fullNameExact(m.fullName).isEmpty) - .foreach { m => linkCallToCallee(call, m, builder) } - setCallees(call, methodNames, builder) - } - } - - def linkCallToCallee(call: Call, method: MethodBase, 
builder: DiffGraphBuilder): Unit = { - builder.addEdge(call, method, EdgeTypes.CALL) - method match { - case method: Method => - builder.setNodeProperty(call, PropertyNames.TYPE_FULL_NAME, method.methodReturn.typeFullName) - case _ => - } - } - - protected def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = { - val nonDummyTypes = methodNames.filterNot(isDummyType) - if (methodNames.sizeIs == 1) { - builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, methodNames.head) - builder.setNodeProperty( - call, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - call.dynamicTypeHintFullName.diff(methodNames) - ) - } else if (methodNames.sizeIs > 1 && methodNames != nonDummyTypes) { - setCallees(call, nonDummyTypes, builder) - } - } - - protected def createMethodStub(methodName: String, call: Call, builder: DiffGraphBuilder): NewMethod = { - // In the case of Python/JS we can use name info to check if, despite the method name might be incorrect, that we - // label the method correctly as internal by finding that the method should belong to an internal file - val matcher = fileNamePattern.matcher(methodName) - val basePath = cpg.metaData.root.head - val isExternal = if (matcher.matches()) { - val fileName = matcher.group(1) - cpg.file.nameExact(s"$basePath$fileName").isEmpty - } else { - true - } - val name = - if (methodName.contains(pathSep) && methodName.length > methodName.lastIndexOf(pathSep) + 1) - methodName.substring(methodName.lastIndexOf(pathSep) + 1) - else methodName - createMethodStub(name, methodName, call.argumentOut.size, isExternal, builder) - } - - /** Try to extract a type full name from the method full name, if one exists in the CPG then we are lucky and we use - * it, else we ignore for now. 
- */ - protected def createMethodStub( - name: String, - fullName: String, - argSize: Int, - isExternal: Boolean, - builder: DiffGraphBuilder - ): NewMethod = { - val nameIdx = fullName.lastIndexOf(name) - val default = (NodeTypes.NAMESPACE_BLOCK, XTypeHintCallLinker.namespace) - val (astParentType, astParentFullName) = - if (!fullName.isBlank && !fullName.startsWith(" 0) { - cpg.typeDecl - .fullNameExact(fullName.substring(0, nameIdx - 1)) - .map(t => t.label -> t.fullName) - .headOption - .getOrElse(default) - } else { - default - } - - MethodStubCreator - .createMethodStub( - name, - fullName, - "", - DispatchTypes.DYNAMIC_DISPATCH.name(), - argSize, - builder, - isExternal, - astParentType, - astParentFullName - ) - } - - /** Once we have connected methods that were speculatively generated and managed to correctly link to methods already - * in the CPG, we link the rest to the "speculativeMethods" namespace as a way to show that these may not actually - * exist. - */ - protected def linkSpeculativeNamespaceNodes(newMethods: IterableOnce[NewMethod], builder: DiffGraphBuilder): Unit = { - val speculativeNamespace = - NewNamespaceBlock().name(XTypeHintCallLinker.namespace).fullName(XTypeHintCallLinker.namespace) - - builder.addNode(speculativeNamespace) - newMethods.iterator - .filter(_.astParentFullName == XTypeHintCallLinker.namespace) - .foreach(m => builder.addEdge(speculativeNamespace, m, EdgeTypes.AST)) - } -} - -object XTypeHintCallLinker { - - /** The shared namespace for all methods generated from the type recovery that may not exist with this exact full name - * in reality. - */ - val namespace: String = "" - -} +abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg): + + implicit protected val resolver: NoResolve.type = NoResolve + private val fileNamePattern = Pattern.compile("^(.*(.py|.js|.rb)).*$") + protected val pathSep: Char = '.' 
+ + protected def calls: Iterator[Call] = cpg.call + .nameNot(".*", ".*") + .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) + + protected def calleeNames(c: Call): Seq[String] = + c.dynamicTypeHintFullName.filterNot(_.equals("ANY")).distinct + + protected def callees(names: Seq[String]): List[Method] = + cpg.method.fullNameExact(names*).toList + + override def run(builder: DiffGraphBuilder): Unit = linkCalls(builder) + + protected def linkCalls(builder: DiffGraphBuilder): Unit = + val callerAndCallees = calls.map(call => (call, calleeNames(call))).toList + val methodMap = mapMethods(callerAndCallees, builder) + linkCallsToCallees(callerAndCallees, methodMap, builder) + linkSpeculativeNamespaceNodes(methodMap.view.values.collectAll[NewMethod], builder) + + protected def mapMethods( + callerAndCallees: List[(Call, Seq[String])], + builder: DiffGraphBuilder + ): Map[String, MethodBase] = + val methodMap = mutable.HashMap.empty[String, MethodBase] + val newMethods = mutable.HashMap.empty[String, NewMethod] + callerAndCallees.foreach { case (call, methodNames) => + val ms = callees(methodNames) + if ms.nonEmpty then + ms.foreach { m => methodMap.put(m.fullName, m) } + else + val mNames = ms.map(_.fullName).toSet + methodNames + .filterNot(mNames.contains) + .map(fullName => + newMethods.getOrElseUpdate( + fullName, + createMethodStub(fullName, call, builder) + ) + ) + .foreach { m => methodMap.put(m.fullName, m) } + } + methodMap.toMap + end mapMethods + + protected def linkCallsToCallees( + callerAndCallees: List[(Call, Seq[String])], + methodMap: Map[String, MethodBase], + builder: DiffGraphBuilder + ): Unit = + // Link edges to method nodes + callerAndCallees.foreach { case (call, methodNames) => + methodNames + .flatMap(methodMap.get) + .filter(m => call.callee(NoResolve).fullNameExact(m.fullName).isEmpty) + .foreach { m => linkCallToCallee(call, m, builder) } + setCallees(call, methodNames, builder) + } + + def linkCallToCallee(call: Call, method: 
MethodBase, builder: DiffGraphBuilder): Unit = + builder.addEdge(call, method, EdgeTypes.CALL) + method match + case method: Method => + builder.setNodeProperty( + call, + PropertyNames.TYPE_FULL_NAME, + method.methodReturn.typeFullName + ) + case _ => + + protected def setCallees( + call: Call, + methodNames: Seq[String], + builder: DiffGraphBuilder + ): Unit = + val nonDummyTypes = methodNames.filterNot(isDummyType) + if methodNames.sizeIs == 1 then + builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, methodNames.head) + builder.setNodeProperty( + call, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + call.dynamicTypeHintFullName.diff(methodNames) + ) + else if methodNames.sizeIs > 1 && methodNames != nonDummyTypes then + setCallees(call, nonDummyTypes, builder) + + protected def createMethodStub( + methodName: String, + call: Call, + builder: DiffGraphBuilder + ): NewMethod = + // In the case of Python/JS we can use name info to check if, despite the method name might be incorrect, that we + // label the method correctly as internal by finding that the method should belong to an internal file + val matcher = fileNamePattern.matcher(methodName) + val basePath = cpg.metaData.root.head + val isExternal = if matcher.matches() then + val fileName = matcher.group(1) + cpg.file.nameExact(s"$basePath$fileName").isEmpty + else + true + val name = + if methodName.contains(pathSep) && methodName.length > methodName.lastIndexOf( + pathSep + ) + 1 + then + methodName.substring(methodName.lastIndexOf(pathSep) + 1) + else methodName + createMethodStub(name, methodName, call.argumentOut.size, isExternal, builder) + end createMethodStub + + /** Try to extract a type full name from the method full name, if one exists in the CPG then we + * are lucky and we use it, else we ignore for now. 
+ */ + protected def createMethodStub( + name: String, + fullName: String, + argSize: Int, + isExternal: Boolean, + builder: DiffGraphBuilder + ): NewMethod = + val nameIdx = fullName.lastIndexOf(name) + val default = (NodeTypes.NAMESPACE_BLOCK, XTypeHintCallLinker.namespace) + val (astParentType, astParentFullName) = + if !fullName.isBlank && !fullName.startsWith(" 0 then + cpg.typeDecl + .fullNameExact(fullName.substring(0, nameIdx - 1)) + .map(t => t.label -> t.fullName) + .headOption + .getOrElse(default) + else + default + + MethodStubCreator + .createMethodStub( + name, + fullName, + "", + DispatchTypes.DYNAMIC_DISPATCH.name(), + argSize, + builder, + isExternal, + astParentType, + astParentFullName + ) + end createMethodStub + + /** Once we have connected methods that were speculatively generated and managed to correctly + * link to methods already in the CPG, we link the rest to the "speculativeMethods" namespace + * as a way to show that these may not actually exist. + */ + protected def linkSpeculativeNamespaceNodes( + newMethods: IterableOnce[NewMethod], + builder: DiffGraphBuilder + ): Unit = + val speculativeNamespace = + NewNamespaceBlock().name(XTypeHintCallLinker.namespace).fullName( + XTypeHintCallLinker.namespace + ) + + builder.addNode(speculativeNamespace) + newMethods.iterator + .filter(_.astParentFullName == XTypeHintCallLinker.namespace) + .foreach(m => builder.addEdge(speculativeNamespace, m, EdgeTypes.AST)) +end XTypeHintCallLinker + +object XTypeHintCallLinker: + + /** The shared namespace for all methods generated from the type recovery that may not exist + * with this exact full name in reality. 
+ */ + val namespace: String = "" diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala index ac2dae71..8ef36bea 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala @@ -43,18 +43,18 @@ case class XTypeRecoveryState( isFieldCache: TrieMap[Long, Boolean] = TrieMap.empty[Long, Boolean], changesWereMade: AtomicBoolean = new AtomicBoolean(false), stopEarly: AtomicBoolean -) { - lazy val isFinalIteration: Boolean = currentIteration == config.iterations - 1 +): + lazy val isFinalIteration: Boolean = currentIteration == config.iterations - 1 - lazy val isFirstIteration: Boolean = currentIteration == 0 + lazy val isFirstIteration: Boolean = currentIteration == 0 - def clear(): Unit = isFieldCache.clear() + def clear(): Unit = isFieldCache.clear() +end XTypeRecoveryState -} - -/** In order to propagate types across compilation units, but avoid the poor scalability of a fixed-point algorithm, the - * number of iterations can be configured using the iterations parameter. Note that iterations < 2 will not provide any - * interprocedural type recovery capabilities. +/** In order to propagate types across compilation units, but avoid the poor scalability of a + * fixed-point algorithm, the number of iterations can be configured using the iterations + * parameter. Note that iterations < 2 will not provide any interprocedural type recovery + * capabilities. * @param cpg * the CPG to recovery types for. 
* @@ -64,144 +64,145 @@ case class XTypeRecoveryState( abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig() -) extends CpgPass(cpg) { - - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = - if (config.iterations > 0) { - val stopEarly = new AtomicBoolean(false) - val state = XTypeRecoveryState(config, stopEarly = stopEarly) - try { - Iterator.from(0).takeWhile(_ < config.iterations).foreach { i => - val newState = state.copy(currentIteration = i) - generateRecoveryPass(newState).createAndApply() - } - // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values - if (stopEarly.get() && config.enabledDummyTypes) - generateRecoveryPass(state.copy(currentIteration = config.iterations - 1)).createAndApply() - } finally { - state.clear() - } - } - - protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[CompilationUnitType] - -} - -trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]] { this: R => - - var disableDummyTypes: Boolean = false - var typePropagationIterations: Int = 2 - - def withDisableDummyTypes(value: Boolean): R = { - this.disableDummyTypes = value - this - } - - def withTypePropagationIterations(value: Int): R = { - typePropagationIterations = value - this - } - -} - -/** Based on a flow-insensitive static single-assignment symbol-table-style approach. This pass aims to be fast and - * deterministic and does not try to converge to some fixed point but rather iterates a fixed number of times. This - * will help recover:
  1. Imported call signatures from external dependencies
  2. Dynamic type hints for - * mutable variables in a compilation unit.
+) extends CpgPass(cpg): + + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = + if config.iterations > 0 then + val stopEarly = new AtomicBoolean(false) + val state = XTypeRecoveryState(config, stopEarly = stopEarly) + try + Iterator.from(0).takeWhile(_ < config.iterations).foreach { i => + val newState = state.copy(currentIteration = i) + generateRecoveryPass(newState).createAndApply() + } + // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values + if stopEarly.get() && config.enabledDummyTypes then + generateRecoveryPass( + state.copy(currentIteration = config.iterations - 1) + ).createAndApply() + finally + state.clear() + + protected def generateRecoveryPass(state: XTypeRecoveryState) + : XTypeRecovery[CompilationUnitType] +end XTypeRecoveryPass + +trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]]: + this: R => + + var disableDummyTypes: Boolean = false + var typePropagationIterations: Int = 2 + + def withDisableDummyTypes(value: Boolean): R = + this.disableDummyTypes = value + this + + def withTypePropagationIterations(value: Int): R = + typePropagationIterations = value + this + +/** Based on a flow-insensitive static single-assignment symbol-table-style approach. This pass aims + * to be fast and deterministic and does not try to converge to some fixed point but rather + * iterates a fixed number of times. This will help recover:
  1. Imported call signatures from + * external dependencies
  2. Dynamic type hints for mutable variables in a compilation + * unit.
* - * The algorithm flows roughly as follows:
  1. Scan for method signatures of methods for each compilation unit, - * either by internally defined methods or by reading import signatures. This includes looking for aliases, e.g. import - * foo as bar.
  2. (Optionally) Prune these method signatures by checking their validity against the - * CPG.
  3. Visit assignments to populate where variables are assigned a value to extrapolate its type. Store these - * values in a local symbol table. If a field is assigned a value, store this in the global table
  4. Find - * instances of where these fields and variables are used and update their type information.
  5. If this variable - * is the receiver of a call, make sure to set the type of the call accordingly.
+ * The algorithm flows roughly as follows:
  1. Scan for method signatures of methods for each + * compilation unit, either by internally defined methods or by reading import signatures. This + * includes looking for aliases, e.g. import foo as bar.
  2. (Optionally) Prune these method + * signatures by checking their validity against the CPG.
  3. Visit assignments to populate + * where variables are assigned a value to extrapolate its type. Store these values in a local + * symbol table. If a field is assigned a value, store this in the global table
  4. Find + * instances of where these fields and variables are used and update their type + * information.
  5. If this variable is the receiver of a call, make sure to set the type of + * the call accordingly.
* - * The symbol tables use the [[SymbolTable]] class to track possible type information.
Note: Local symbols - * are cleared once a compilation unit is complete. This is to keep memory usage down while maximizing - * concurrency. + * The symbol tables use the [[SymbolTable]] class to track possible type information.
+ * Note: Local symbols are cleared once a compilation unit is complete. This is to keep + * memory usage down while maximizing concurrency. * * @param cpg * the CPG to recovery types for. * @tparam CompilationUnitType * the AstNode type used to represent a compilation unit of the language. */ -abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, state: XTypeRecoveryState) extends CpgPass(cpg) { - - override def run(builder: DiffGraphBuilder): Unit = { - val changesWereMade = compilationUnit - .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) - .map(_.get) - .reduceOption((a, b) => a || b) - .getOrElse(false) - if (!changesWereMade) state.stopEarly.set(true) - } - - /** @return - * the compilation units as per how the language is compiled. e.g. file. - */ - def compilationUnit: Iterator[CompilationUnitType] - - /** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given parameters. - * - * @param unit - * the compilation unit. - * @param builder - * the graph builder. - * @return - * a forkable [[RecoverForXCompilationUnit]] task. - */ - def generateRecoveryForCompilationUnitTask( - unit: CompilationUnitType, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[CompilationUnitType] - -} - -object XTypeRecovery { - - private val logger = LoggerFactory.getLogger(getClass) - - val DummyReturnType = "" - val DummyMemberLoad = "" - val DummyIndexAccess = "" - private lazy val DummyTokens: Set[String] = Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) - - def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = - s"$prefix$sep$DummyMemberLoad($memberName)" - - /** Scans the type for placeholder/dummy types. - */ - def isDummyType(typ: String): Boolean = DummyTokens.exists(typ.contains) - - /** Parser options for languages implementing this pass. 
- */ - def parserOptions[R <: X2CpgConfig[R] with TypeRecoveryParserConfig[R]]: OParser[_, R] = { - val builder = OParser.builder[R] - import builder.* - OParser.sequence( - opt[Unit]("no-dummyTypes") - .hidden() - .action((_, c) => c.withDisableDummyTypes(true)) - .text("disable generation of dummy types during type propagation"), - opt[Int]("type-prop-iterations") - .hidden() - .action((x, c) => c.withTypePropagationIterations(x)) - .text("maximum iterations of type propagation") - .validate { x => - if (x <= 0) { - logger.debug("Disabling type propagation as the given iteration count is <= 0") - } else if (x == 1) { - logger.debug("Intra-procedural type propagation enabled") - } else if (x > 5) { - logger.debug(s"Large iteration count of $x will take a while to terminate") - } - success - } - ) - } - -} +abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, state: XTypeRecoveryState) + extends CpgPass(cpg): + + override def run(builder: DiffGraphBuilder): Unit = + val changesWereMade = compilationUnit + .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) + .map(_.get) + .reduceOption((a, b) => a || b) + .getOrElse(false) + if !changesWereMade then state.stopEarly.set(true) + + /** @return + * the compilation units as per how the language is compiled. e.g. file. + */ + def compilationUnit: Iterator[CompilationUnitType] + + /** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given + * parameters. + * + * @param unit + * the compilation unit. + * @param builder + * the graph builder. + * @return + * a forkable [[RecoverForXCompilationUnit]] task. 
+ */ + def generateRecoveryForCompilationUnitTask( + unit: CompilationUnitType, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[CompilationUnitType] +end XTypeRecovery + +object XTypeRecovery: + + private val logger = LoggerFactory.getLogger(getClass) + + val DummyReturnType = "" + val DummyMemberLoad = "" + val DummyIndexAccess = "" + private lazy val DummyTokens: Set[String] = + Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) + + def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = + s"$prefix$sep$DummyMemberLoad($memberName)" + + /** Scans the type for placeholder/dummy types. + */ + def isDummyType(typ: String): Boolean = DummyTokens.exists(typ.contains) + + /** Parser options for languages implementing this pass. + */ + def parserOptions[R <: X2CpgConfig[R] with TypeRecoveryParserConfig[R]]: OParser[?, R] = + val builder = OParser.builder[R] + import builder.* + OParser.sequence( + opt[Unit]("no-dummyTypes") + .hidden() + .action((_, c) => c.withDisableDummyTypes(true)) + .text("disable generation of dummy types during type propagation"), + opt[Int]("type-prop-iterations") + .hidden() + .action((x, c) => c.withTypePropagationIterations(x)) + .text("maximum iterations of type propagation") + .validate { x => + if x <= 0 then + logger.debug( + "Disabling type propagation as the given iteration count is <= 0" + ) + else if x == 1 then + logger.debug("Intra-procedural type propagation enabled") + else if x > 5 then + logger.debug(s"Large iteration count of $x will take a while to terminate") + success + } + ) + end parserOptions +end XTypeRecovery /** Performs type recovery from the root of a compilation unit level * @@ -219,979 +220,1049 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( cu: CompilationUnitType, builder: DiffGraphBuilder, state: XTypeRecoveryState -) extends RecursiveTask[Boolean] { - - protected val logger: Logger = LoggerFactory.getLogger(getClass) - - /** Stores type 
information for local structures that live within this compilation unit, e.g. local variables. - */ - protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) - - /** The root of the target codebase. - */ - protected val codeRoot: String = cpg.metaData.root.headOption.getOrElse("") + java.io.File.separator - - /** The delimiter used to separate methods/functions in qualified names. - */ - protected val pathSep = '.' - - /** New node tracking set. - */ - protected val addedNodes = mutable.HashSet.empty[String] - - /** For tracking members and the type operations that need to be performed. Since these are mostly out of scope - * locally it helps to track these separately. - * - * // TODO: Potentially a new use for a global table or modification to the symbol table? - */ - protected val newTypesForMembers = mutable.HashMap.empty[Member, Set[String]] - - /** Provides an entrypoint to add known symbols and their possible types. - */ - protected def prepopulateSymbolTable(): Unit = { - (cu.ast.isIdentifier ++ cu.ast.isCall ++ cu.ast.isLocal ++ cu.ast.isParameter) - .filter(hasTypes) - .foreach(prepopulateSymbolTableEntry) - } - - protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match { - case x @ (_: Identifier | _: Local | _: MethodParameterIn) => symbolTable.append(x, x.getKnownTypes) - case x: Call => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) - case _ => - } - - protected def hasTypes(node: AstNode): Boolean = node match { - case x: Call if !x.methodFullName.startsWith("") => - !x.methodFullName.toLowerCase().matches("(|any)") - case x => x.getKnownTypes.nonEmpty - } - - protected def assignments: Iterator[Assignment] = cu match { - case x: File => - x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) - case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) - case x => 
x.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) - } - - protected def members: Iterator[Member] = cu.ast.isMember - - protected def returns: Iterator[Return] = cu match { - case x: File => x.method.flatMap(_._returnViaContainsOut) - case x: Method => x._returnViaContainsOut - case x => x.ast.isReturn - } - - protected def importNodes: Iterator[Import] = cu match { - case x: File => x.method.flatMap(_._callViaContainsOut).referencedImports - case x: Method => x.file.method.flatMap(_._callViaContainsOut).referencedImports - case x => x.ast.isCall.referencedImports - } - - override def compute(): Boolean = try { - // Set known aliases that point to imports for local and external methods/modules - importNodes.foreach(visitImport) - // Look at symbols with existing type info - prepopulateSymbolTable() - // Prune import names if the methods exist in the CPG - postVisitImports() - // Populate local symbol table with assignments - assignments.foreach(visitAssignments) - // See if any new information are in the parameters of methods - returns.foreach(visitReturns) - // Persist findings - setTypeInformation() - // Entrypoint for any final changes - postSetTypeInformation() - // Return number of changes - state.changesWereMade.get() - } finally { - symbolTable.clear() - } - - private def debugLocation(n: AstNode): String = { - val fileName = n.file.name.headOption.getOrElse("").stripPrefix(codeRoot) - val lineNo = n.lineNumber.getOrElse("") - s"$fileName#L$lineNo" - } - - /** Visits an import and stores references in the symbol table as both an identifier and call. 
- */ - protected def visitImport(i: Import): Unit = for { - resolvedImport <- i.call.tag - alias <- i.importedAs - } { - import io.appthreat.x2cpg.passes.frontend.ImportsPass.* - - ResolvedImport.tagToResolvedImport(resolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case ResolvedMember(basePath, memberName, _) => - val matchingIdentifiers = cpg.method.fullNameExact(basePath).local - val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member - val memberTypes = (matchingMembers ++ matchingIdentifiers) - .nameExact(memberName) - .getKnownTypes - symbolTable.append(LocalVar(alias), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case UnknownImport(path, _) => - symbolTable.append(CallAlias(alias), path) - symbolTable.append(LocalVar(alias), path) - } - } - - /** The initial import setting is over-approximated, so this step checks the CPG for any matches and prunes against - * these findings. If there are no findings, it will leave the table as is. The latter is significant for external - * types or methods. - */ - protected def postVisitImports(): Unit = {} - - /** Using assignment and import information (in the global symbol table), will propagate these types in the symbol - * table. - * - * @param a - * assignment call pointer. 
- */ - protected def visitAssignments(a: Assignment): Set[String] = { - a.argumentOut.l match { - case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) - case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) - case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) - case List(i: Identifier, l: Literal) if state.isFirstIteration => visitIdentifierAssignedToLiteral(i, l) - case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) - case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) - case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) - case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) - case List(c: Call, l: Literal) if state.isFirstIteration => visitCallAssignedToLiteral(c, l) - case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) - case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) - case _ => Set.empty - } - } - - /** Visits an identifier being assigned to the result of some operation. - */ - protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = { - val blockTypes = visitStatementsInBlock(b, Some(i)) - if (blockTypes.nonEmpty) associateTypes(i, blockTypes) - else Set.empty - } - - /** Visits a call operation being assigned to the result of some operation. 
- */ - protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = { - val blockTypes = visitStatementsInBlock(b) - assignTypesToCall(c, blockTypes) - } - - /** Process each statement but only assign the type of the last statement to the identifier - */ - protected def visitStatementsInBlock(b: Block, assignmentTarget: Option[Identifier] = None): Set[String] = - b.astChildren - .map { - case x: Call if x.name.startsWith(Operators.assignment) => visitAssignments(new Assignment(x)) - case x: Call if x.name.startsWith("") && assignmentTarget.isDefined => - visitIdentifierAssignedToOperator(assignmentTarget.get, x, x.name) - case x: Identifier if symbolTable.contains(x) => symbolTable.get(x) - case x: Call if symbolTable.contains(x) => symbolTable.get(x) - case x: Call if x.argument.headOption.exists(symbolTable.contains) => setCallMethodFullNameFromBase(x) - case x: Block => visitStatementsInBlock(x) - case x: Local => symbolTable.get(x) - case _: ControlStructure => Set.empty[String] - case x => logger.debug(s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}"); Set.empty[String] - } - .lastOption - .getOrElse(Set.empty[String]) - - /** Visits an identifier being assigned to a call. This call could be an operation, function invocation, or - * constructor invocation. 
- */ - protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { - if (c.name.startsWith("")) { - visitIdentifierAssignedToOperator(i, c, c.name) - } else if (symbolTable.contains(c) && isConstructor(c)) { - visitIdentifierAssignedToConstructor(i, c) - } else if (symbolTable.contains(c)) { - visitIdentifierAssignedToCallRetVal(i, c) - } else if (c.argument.headOption.exists(symbolTable.contains)) { - setCallMethodFullNameFromBase(c) - // Repeat this method now that the call has a type - visitIdentifierAssignedToCall(i, c) - } else { - // We can try obtain a return type for this call - visitIdentifierAssignedToCallRetVal(i, c) - } - } - - /** Visits an identifier being assigned to the value held by another identifier. This is a weak copy. - */ - protected def visitIdentifierAssignedToIdentifier(x: Identifier, y: Identifier): Set[String] = - if (symbolTable.contains(y)) associateTypes(x, symbolTable.get(y)) - else Set.empty - - /** Will build a call full path using the call base node. This method assumes the base node is in the symbol table. - */ - protected def setCallMethodFullNameFromBase(c: Call): Set[String] = { - val recTypes = c.argument.headOption - .map { - case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) - case x: Call => - cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot("ANY").typeFullName.toSet match { - case xs if xs.nonEmpty => xs - case _ => symbolTable.get(x).map(t => Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep.toString)) - } - case x => symbolTable.get(x) - } - .getOrElse(Set.empty[String]) - val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) - symbolTable.append(c, callTypes) - } - - /** A heuristic method to determine if a call is a constructor or not. - */ - protected def isConstructor(c: Call): Boolean - - /** A heuristic method to determine if a call name is a constructor or not. 
- */ - protected def isConstructor(name: String): Boolean - - /** A heuristic method to determine if an identifier may be a field or not. The result means that it would be stored - * in the global symbol table. By default this checks if the identifier name matches a member name. - * - * This has found to be an expensive operation accessed often so we have memoized this step. - */ - protected def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate(i.id(), i.method.typeDecl.member.nameExact(i.name).nonEmpty) - - /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field - * which this method uses [[isField]] to determine. - */ - protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = - symbolTable.append(i, types) - - /** Returns the appropriate field parent scope. - */ - protected def getFieldParents(fa: FieldAccess): Set[String] = { - val fieldName = getFieldName(fa).split(pathSep).last - cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet - } - - /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field - * which this method uses [[isField]] to determine. - */ - protected def associateTypes(symbol: LocalVar, fa: FieldAccess, types: Set[String]): Set[String] = { - fa.astChildren.filterNot(_.code.matches("(this|self)")).headOption.collect { - case fi: FieldIdentifier => - getFieldParents(fa).foreach(t => persistMemberWithTypeDecl(t, fi.canonicalName, types)) - case i: Identifier if isField(i) => - getFieldParents(fa).foreach(t => persistMemberWithTypeDecl(t, i.name, types)) - } - symbolTable.append(symbol, types) - } - - /** Similar to associateTypes but used in the case where there is some kind of field load. 
- */ - protected def associateInterproceduralTypes( - i: Identifier, - base: Identifier, - fi: FieldIdentifier, - fieldName: String, - baseTypes: Set[String] - ): Set[String] = { - val globalTypes = getFieldBaseType(base, fi) - associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes) - } - - protected def associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] - ): Set[String] = { - if (globalTypes.nonEmpty) { - // We have been able to resolve the type inter-procedurally - associateTypes(i, globalTypes) - } else if (baseTypes.nonEmpty) { - if (baseTypes.equals(symbolTable.get(LocalVar(fieldFullName)))) { - associateTypes(i, baseTypes) - } else { - // If not available, use a dummy variable that can be useful for call matching - associateTypes(i, baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep))) - } - } else { - // Assign dummy - val dummyTypes = Set( - XTypeRecovery.dummyMemberType(fieldFullName.stripSuffix(s"$pathSep$fieldName"), fieldName, pathSep) - ) - associateTypes(i, dummyTypes) - } - } - - /** Visits an identifier being assigned to an operator call. - */ - protected def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { - operation match { - case Operators.alloc => visitIdentifierAssignedToConstructor(i, c) - case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c)) - case Operators.indexAccess => visitIdentifierAssignedToIndexAccess(i, c) - case Operators.cast => visitIdentifierAssignedToCast(i, c) - case x => logger.debug(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty - } - } - - /** Visits an identifier being assigned to a constructor and attempts to speculate the constructor path. 
- */ - protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { - val constructorPaths = symbolTable.get(c).map(t => t.concat(s"$pathSep${Defines.ConstructorMethodName}")) - associateTypes(i, constructorPaths) - } - - /** Visits an identifier being assigned to a call's return value. - */ - protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = { - if (symbolTable.contains(c)) { - val callReturns = methodReturnValues(symbolTable.get(c).toSeq) - associateTypes(i, callReturns) - } else if (c.argument.exists(_.argumentIndex == 0)) { - val callFullNames = (c.argument(0) match { - case i: Identifier if symbolTable.contains(LocalVar(i.name)) => symbolTable.get(LocalVar(i.name)) - case i: Identifier if symbolTable.contains(CallAlias(i.name)) => symbolTable.get(CallAlias(i.name)) - case _ => Set.empty - }).map(_.concat(s"$pathSep${c.name}")).toSeq - val callReturns = methodReturnValues(callFullNames) - associateTypes(i, callReturns) - } else { - // Assign dummy value - associateTypes(i, Set(s"${c.name}$pathSep${XTypeRecovery.DummyReturnType}")) - } - } - - /** Will attempt to find the return values of a method if in the CPG, otherwise will give a dummy value. - */ - protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { - val rs = cpg.method - .fullNameExact(methodFullNames: _*) - .methodReturn - .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) - .filterNot(_.equals("ANY")) - .toSet - if (rs.isEmpty) methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet - else rs - } - - /** Will handle literal value assignments. Override if special handling is required. - */ - protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] = - associateTypes(i, getLiteralType(l)) - - /** Not all frontends populate typeFullName for literals so we allow this to be overridden. 
- */ - protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName) - - /** Will handle an identifier holding a function pointer. - */ - protected def visitIdentifierAssignedToMethodRef( - i: Identifier, - m: MethodRef, - rec: Option[String] = None - ): Set[String] = - symbolTable.append(CallAlias(i.name, rec), Set(m.methodFullName)) - - /** Will handle an identifier holding a type pointer. - */ - protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef, rec: Option[String] = None): Set[String] = - symbolTable.append(CallAlias(i.name, rec), Set(t.typeFullName)) - - /** Visits a call assigned to an identifier. This is often when there are operators involved. - */ - protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = { - val rhsTypes = symbolTable.get(i) - assignTypesToCall(c, rhsTypes) - } - - /** Visits a call assigned to the return value of a call. This is often when there are operators involved. - */ - protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = - assignTypesToCall(x, getTypesFromCall(y)) - - /** Given a call operation, will attempt to retrieve types from it. - */ - protected def getTypesFromCall(c: Call): Set[String] = c.name match { - case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) - case _ if symbolTable.contains(c) => symbolTable.get(c) - case Operators.indexAccess => getIndexAccessTypes(c) - case n => - logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") - Set.empty[String] - } - - /** Given a LHS call, will retrieve its symbol to the given types. 
- */ - protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = { - if (types.nonEmpty) { - getSymbolFromCall(x) match { - case (lhs, globalKeys) if globalKeys.nonEmpty => - globalKeys.foreach { (fieldVar: FieldPath) => - persistMemberWithTypeDecl(fieldVar.compUnitFullName, fieldVar.identifier, types) - } - symbolTable.append(lhs, types) - case (lhs, _) => symbolTable.append(lhs, types) - } - } else Set.empty - } - - /** Will attempt to retrieve index access types otherwise will return dummy value. - */ - protected def getIndexAccessTypes(ia: Call): Set[String] = indexAccessToCollectionVar(ia) match { - case Some(cVar) if symbolTable.contains(cVar) => - symbolTable.get(cVar) - case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) => - symbolTable.get(LocalVar(cVar.identifier)).map(x => s"$x$pathSep${XTypeRecovery.DummyIndexAccess}") - case _ => Set.empty - } - - /** Convenience class for transporting field names. - * @param compUnitFullName - * qualified path to base type holding the member. - * @param identifier - * the member name. - */ - case class FieldPath(compUnitFullName: String, identifier: String) - - /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of an assignment. The - * second element is a list of any associated global keys if applicable. - */ - protected def getSymbolFromCall(c: Call): (LocalKey, Set[FieldPath]) = c.name match { - case Operators.fieldAccess => - val fa = new FieldAccess(c) - val fieldName = getFieldName(fa) - (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldPath(fp, fieldName))) - case Operators.indexAccess => (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty) - case x => - logger.debug(s"Using default LHS call name '$x' @ ${debugLocation(c)}") - (LocalVar(c.name), Set.empty) - } - - /** Extracts a string representation of the name of the field within this field access. 
- */ - protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = { - def wrapName(n: String) = { - val sb = new mutable.StringBuilder() - if (prefix.nonEmpty) sb.append(s"$prefix$pathSep") - sb.append(n) - if (suffix.nonEmpty) sb.append(s"$pathSep$suffix") - sb.toString() - } - - lazy val typesFromBaseCall = fa.argumentOut.headOption match - case Some(call: Call) => getTypesFromCall(call) - case _ => Set.empty[String] - - fa.argumentOut.l match { - case ::(i: Identifier, ::(f: FieldIdentifier, _)) if i.name.matches("(self|this)") => wrapName(f.canonicalName) - case ::(i: Identifier, ::(f: FieldIdentifier, _)) => wrapName(s"${i.name}$pathSep${f.canonicalName}") - case ::(c: Call, ::(f: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => - wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName)) - case ::(_: Call, ::(f: FieldIdentifier, _)) if typesFromBaseCall.nonEmpty => - // TODO: Handle this case better - wrapName(s"${typesFromBaseCall.head}$pathSep${f.canonicalName}") - case ::(f: FieldIdentifier, ::(c: Call, _)) if c.name.equals(Operators.fieldAccess) => - wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName)) - case ::(c: Call, ::(f: FieldIdentifier, _)) => - // TODO: Handle this case better - val callCode = if (c.code.contains("(")) c.code.substring(c.code.indexOf("(")) else c.code - XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) - case xs => - logger.debug( - s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}" - ) - wrapName("") - } - } - - protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = { - if (c.name.equals(Operators.indexAccess)) { - // For now, we will just handle this on a very basic level - c.argumentOut.l match { - case List(_: Identifier, _: Literal) => - indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, getLiteralType(l))).getOrElse(Set.empty) - case List(_: 
Identifier, idx: Identifier) if symbolTable.contains(idx) => - // Imprecise but sound! - indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, symbolTable.get(idx))).getOrElse(Set.empty) - case List(i: Identifier, c: Call) => - // This is an expensive level of precision to support - symbolTable.append(CollectionVar(i.name, "*"), getTypesFromCall(c)) - case List(c: Call, l: Literal) => assignTypesToCall(c, getLiteralType(l)) - case xs => - logger.debug( - s"Unhandled index access point assigned to literal ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}" - ) - Set.empty - } - } else if (c.name.equals(Operators.fieldAccess)) { - val fa = new FieldAccess(c) - val fieldName = getFieldName(fa) - associateTypes(LocalVar(fieldName), fa, getLiteralType(l)) - } else { - logger.debug(s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}") - Set.empty - } - } - - /** Handles a call operation assigned to a method/function pointer. - */ - protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] = - assignTypesToCall(c, Set(m.methodFullName)) - - /** Generates an identifier for collection/index-access operations in the symbol table. 
- */ - protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = { - def callName(x: Call) = - if (x.name.equals(Operators.fieldAccess)) - getFieldName(new FieldAccess(x)) - else if (x.name.equals(Operators.indexAccess)) - indexAccessToCollectionVar(x) - .map(cv => s"${cv.identifier}[${cv.idx}]") - .getOrElse(XTypeRecovery.DummyIndexAccess) - else x.name - - Option(c.argumentOut.l match { - case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) - case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) - case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) - case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code) - case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code) - case xs => - logger.debug(s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}") - null - }) - } - - /** Will handle an identifier being assigned to a field value. 
- */ - protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = { - val fieldName = getFieldName(fa) - fa.argumentOut.l match { - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) if symbolTable.contains(LocalVar(base.name)) => - // Get field from global table if referenced as a variable - val localTypes = symbolTable.get(LocalVar(base.name)) - associateInterproceduralTypes(i, base, fi, fieldName, localTypes) - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) if symbolTable.contains(LocalVar(fieldName)) => - val localTypes = symbolTable.get(LocalVar(fieldName)) - associateInterproceduralTypes(i, base, fi, fieldName, localTypes) - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) => - val dummyTypes = Set(s"$fieldName$pathSep${XTypeRecovery.DummyReturnType}") - associateInterproceduralTypes(i, base, fi, fieldName, dummyTypes) - case ::(c: Call, ::(fi: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => - val baseName = getFieldName(new FieldAccess(c)) - // Build type regardless of length - // TODO: This is more prone to giving dummy values as it does not do global look-ups - // but this is okay for now - val buf = mutable.ArrayBuffer.empty[String] - for (segment <- baseName.split(pathSep) ++ Array(fi.canonicalName)) { - val types = - if (buf.isEmpty) symbolTable.get(LocalVar(segment)) - else buf.flatMap(t => symbolTable.get(LocalVar(s"$t$pathSep$segment"))).toSet - if (types.nonEmpty) { - buf.clear() - buf.addAll(types) - } else { - val bufCopy = Array.from(buf) - buf.clear() - bufCopy.foreach { - case t if isConstructor(segment) => buf.addOne(s"$t$pathSep$segment") - case t => buf.addOne(XTypeRecovery.dummyMemberType(t, segment, pathSep)) - } - } +) extends RecursiveTask[Boolean]: + + protected val logger: Logger = LoggerFactory.getLogger(getClass) + + /** Stores type information for local structures that live within this compilation unit, e.g. + * local variables. 
+ */ + protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) + + /** The root of the target codebase. + */ + protected val codeRoot: String = + cpg.metaData.root.headOption.getOrElse("") + java.io.File.separator + + /** The delimiter used to separate methods/functions in qualified names. + */ + protected val pathSep = '.' + + /** New node tracking set. + */ + protected val addedNodes = mutable.HashSet.empty[String] + + /** For tracking members and the type operations that need to be performed. Since these are + * mostly out of scope locally it helps to track these separately. + * + * // TODO: Potentially a new use for a global table or modification to the symbol table? + */ + protected val newTypesForMembers = mutable.HashMap.empty[Member, Set[String]] + + /** Provides an entrypoint to add known symbols and their possible types. + */ + protected def prepopulateSymbolTable(): Unit = + (cu.ast.isIdentifier ++ cu.ast.isCall ++ cu.ast.isLocal ++ cu.ast.isParameter) + .filter(hasTypes) + .foreach(prepopulateSymbolTableEntry) + + protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match + case x @ (_: Identifier | _: Local | _: MethodParameterIn) => + symbolTable.append(x, x.getKnownTypes) + case x: Call => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case _ => + + protected def hasTypes(node: AstNode): Boolean = node match + case x: Call if !x.methodFullName.startsWith("") => + !x.methodFullName.toLowerCase().matches("(|any)") + case x => x.getKnownTypes.nonEmpty + + protected def assignments: Iterator[Assignment] = cu match + case x: File => + x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( + new OpNodes.Assignment(_) + ) + case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( + new OpNodes.Assignment(_) + ) + case x => x.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) + + protected def members: 
Iterator[Member] = cu.ast.isMember + + protected def returns: Iterator[Return] = cu match + case x: File => x.method.flatMap(_._returnViaContainsOut) + case x: Method => x._returnViaContainsOut + case x => x.ast.isReturn + + protected def importNodes: Iterator[Import] = cu match + case x: File => x.method.flatMap(_._callViaContainsOut).referencedImports + case x: Method => x.file.method.flatMap(_._callViaContainsOut).referencedImports + case x => x.ast.isCall.referencedImports + + override def compute(): Boolean = + try + // Set known aliases that point to imports for local and external methods/modules + importNodes.foreach(visitImport) + // Look at symbols with existing type info + prepopulateSymbolTable() + // Prune import names if the methods exist in the CPG + postVisitImports() + // Populate local symbol table with assignments + assignments.foreach(visitAssignments) + // See if any new information are in the parameters of methods + returns.foreach(visitReturns) + // Persist findings + setTypeInformation() + // Entrypoint for any final changes + postSetTypeInformation() + // Return number of changes + state.changesWereMade.get() + finally + symbolTable.clear() + + private def debugLocation(n: AstNode): String = + val fileName = n.file.name.headOption.getOrElse("").stripPrefix(codeRoot) + val lineNo = n.lineNumber.getOrElse("") + s"$fileName#L$lineNo" + + /** Visits an import and stores references in the symbol table as both an identifier and call. 
+ */ + protected def visitImport(i: Import): Unit = for + resolvedImport <- i.call.tag + alias <- i.importedAs + do + import io.appthreat.x2cpg.passes.frontend.ImportsPass.* + + ResolvedImport.tagToResolvedImport(resolvedImport).foreach { + case ResolvedMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case ResolvedTypeDecl(fullName, _) => + symbolTable.append(LocalVar(alias), fullName) + case ResolvedMember(basePath, memberName, _) => + val matchingIdentifiers = cpg.method.fullNameExact(basePath).local + val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member + val memberTypes = (matchingMembers ++ matchingIdentifiers) + .nameExact(memberName) + .getKnownTypes + symbolTable.append(LocalVar(alias), memberTypes) + case UnknownMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case UnknownTypeDecl(fullName, _) => + symbolTable.append(LocalVar(alias), fullName) + case UnknownImport(path, _) => + symbolTable.append(CallAlias(alias), path) + symbolTable.append(LocalVar(alias), path) } - associateTypes(i, buf.toSet) - case ::(call: Call, ::(fi: FieldIdentifier, _)) => - assignTypesToCall( - call, - Set(fieldName.stripSuffix(s"${XTypeRecovery.DummyMemberLoad}$pathSep${fi.canonicalName}")) + + /** The initial import setting is over-approximated, so this step checks the CPG for any matches + * and prunes against these findings. If there are no findings, it will leave the table as is. + * The latter is significant for external types or methods. + */ + protected def postVisitImports(): Unit = {} + + /** Using assignment and import information (in the global symbol table), will propagate these + * types in the symbol table. + * + * @param a + * assignment call pointer. 
+ */ + protected def visitAssignments(a: Assignment): Set[String] = + a.argumentOut.l match + case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) + case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) + case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) + case List(i: Identifier, l: Literal) if state.isFirstIteration => + visitIdentifierAssignedToLiteral(i, l) + case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) + case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) + case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) + case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) + case List(c: Call, l: Literal) if state.isFirstIteration => + visitCallAssignedToLiteral(c, l) + case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) + case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) + case _ => Set.empty + + /** Visits an identifier being assigned to the result of some operation. + */ + protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = + val blockTypes = visitStatementsInBlock(b, Some(i)) + if blockTypes.nonEmpty then associateTypes(i, blockTypes) + else Set.empty + + /** Visits a call operation being assigned to the result of some operation. 
+ */ + protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = + val blockTypes = visitStatementsInBlock(b) + assignTypesToCall(c, blockTypes) + + /** Process each statement but only assign the type of the last statement to the identifier + */ + protected def visitStatementsInBlock( + b: Block, + assignmentTarget: Option[Identifier] = None + ): Set[String] = + b.astChildren + .map { + case x: Call if x.name.startsWith(Operators.assignment) => + visitAssignments(new Assignment(x)) + case x: Call if x.name.startsWith("") && assignmentTarget.isDefined => + visitIdentifierAssignedToOperator(assignmentTarget.get, x, x.name) + case x: Identifier if symbolTable.contains(x) => symbolTable.get(x) + case x: Call if symbolTable.contains(x) => symbolTable.get(x) + case x: Call if x.argument.headOption.exists(symbolTable.contains) => + setCallMethodFullNameFromBase(x) + case x: Block => visitStatementsInBlock(x) + case x: Local => symbolTable.get(x) + case _: ControlStructure => Set.empty[String] + case x => + logger.debug( + s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}" + ); Set.empty[String] + } + .lastOption + .getOrElse(Set.empty[String]) + + /** Visits an identifier being assigned to a call. This call could be an operation, function + * invocation, or constructor invocation. 
+ */ + protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + if c.name.startsWith("") then + visitIdentifierAssignedToOperator(i, c, c.name) + else if symbolTable.contains(c) && isConstructor(c) then + visitIdentifierAssignedToConstructor(i, c) + else if symbolTable.contains(c) then + visitIdentifierAssignedToCallRetVal(i, c) + else if c.argument.headOption.exists(symbolTable.contains) then + setCallMethodFullNameFromBase(c) + // Repeat this method now that the call has a type + visitIdentifierAssignedToCall(i, c) + else + // We can try obtain a return type for this call + visitIdentifierAssignedToCallRetVal(i, c) + + /** Visits an identifier being assigned to the value held by another identifier. This is a weak + * copy. + */ + protected def visitIdentifierAssignedToIdentifier(x: Identifier, y: Identifier): Set[String] = + if symbolTable.contains(y) then associateTypes(x, symbolTable.get(y)) + else Set.empty + + /** Will build a call full path using the call base node. This method assumes the base node is + * in the symbol table. + */ + protected def setCallMethodFullNameFromBase(c: Call): Set[String] = + val recTypes = c.argument.headOption + .map { + case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) + case x: Call => + cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot( + "ANY" + ).typeFullName.toSet match + case xs if xs.nonEmpty => xs + case _ => symbolTable.get(x).map(t => + Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) + ) + case x => symbolTable.get(x) + } + .getOrElse(Set.empty[String]) + val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) + symbolTable.append(c, callTypes) + + /** A heuristic method to determine if a call is a constructor or not. + */ + protected def isConstructor(c: Call): Boolean + + /** A heuristic method to determine if a call name is a constructor or not. 
+ */ + protected def isConstructor(name: String): Boolean + + /** A heuristic method to determine if an identifier may be a field or not. The result means + * that it would be stored in the global symbol table. By default this checks if the identifier + * name matches a member name. + * + * This has found to be an expensive operation accessed often so we have memoized this step. + */ + protected def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + i.method.typeDecl.member.nameExact(i.name).nonEmpty ) - case _ => - logger.debug(s"Unable to assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}") - Set.empty - } - } - - /** Visits an identifier being assigned to the result of an index access operation. - */ - protected def visitIdentifierAssignedToIndexAccess(i: Identifier, c: Call): Set[String] = - associateTypes(i, getTypesFromCall(c)) - - /** Visits an identifier that is the target of a cast operation. - */ - protected def visitIdentifierAssignedToCast(i: Identifier, c: Call): Set[String] = - associateTypes(i, (c.typeFullName +: c.dynamicTypeHintFullName).filterNot(_ == "ANY").toSet) - - protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = - getFieldBaseType(base.name, fi.canonicalName) - - protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = - symbolTable - .get(LocalVar(baseName)) - .flatMap(t => typeDeclIterator(t).member.nameExact(fieldName)) - .typeFullNameNot("ANY") - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .toSet - - protected def visitReturns(ret: Return): Unit = { - val m = ret.method - val existingTypes = mutable.HashSet.from( - (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) - .filterNot(_ == "ANY") - ) - @tailrec - def extractTypes(xs: List[CfgNode]): Set[String] = xs match { - case ::(head: Literal, Nil) if head.typeFullName != "ANY" => - Set(head.typeFullName) - case ::(head: Call, Nil) 
if head.name == Operators.fieldAccess => - val fieldAccess = new FieldAccess(head) - val (sym, ts) = getSymbolFromCall(fieldAccess) - val cpgTypes = cpg.typeDecl - .fullNameExact(ts.map(_.compUnitFullName).toSeq: _*) - .member - .nameExact(sym.identifier) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot { x => x == "ANY" || x == "this" } - .toSet - if (cpgTypes.nonEmpty) cpgTypes - else symbolTable.get(sym) - case ::(head: Call, Nil) if symbolTable.contains(head) => - val callPaths = symbolTable.get(head) - val returnValues = methodReturnValues(callPaths.toSeq) - if (returnValues.isEmpty) - callPaths.map(c => s"$c$pathSep${XTypeRecovery.DummyReturnType}") + + /** Associates the types with the identifier. This may sometimes be an identifier that should be + * considered a field which this method uses [[isField]] to determine. + */ + protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = + symbolTable.append(i, types) + + /** Returns the appropriate field parent scope. + */ + protected def getFieldParents(fa: FieldAccess): Set[String] = + val fieldName = getFieldName(fa).split(pathSep).last + cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet + + /** Associates the types with the identifier. This may sometimes be an identifier that should be + * considered a field which this method uses [[isField]] to determine. + */ + protected def associateTypes( + symbol: LocalVar, + fa: FieldAccess, + types: Set[String] + ): Set[String] = + fa.astChildren.filterNot(_.code.matches("(this|self)")).headOption.collect { + case fi: FieldIdentifier => + getFieldParents(fa).foreach(t => + persistMemberWithTypeDecl(t, fi.canonicalName, types) + ) + case i: Identifier if isField(i) => + getFieldParents(fa).foreach(t => persistMemberWithTypeDecl(t, i.name, types)) + } + symbolTable.append(symbol, types) + + /** Similar to associateTypes but used in the case where there is some kind of field load. 
+ */ + protected def associateInterproceduralTypes( + i: Identifier, + base: Identifier, + fi: FieldIdentifier, + fieldName: String, + baseTypes: Set[String] + ): Set[String] = + val globalTypes = getFieldBaseType(base, fi) + associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes) + + protected def associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] + ): Set[String] = + if globalTypes.nonEmpty then + // We have been able to resolve the type inter-procedurally + associateTypes(i, globalTypes) + else if baseTypes.nonEmpty then + if baseTypes.equals(symbolTable.get(LocalVar(fieldFullName))) then + associateTypes(i, baseTypes) + else + // If not available, use a dummy variable that can be useful for call matching + associateTypes( + i, + baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep)) + ) + else + // Assign dummy + val dummyTypes = Set( + XTypeRecovery.dummyMemberType( + fieldFullName.stripSuffix(s"$pathSep$fieldName"), + fieldName, + pathSep + ) + ) + associateTypes(i, dummyTypes) + + /** Visits an identifier being assigned to an operator call. + */ + protected def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case Operators.alloc => visitIdentifierAssignedToConstructor(i, c) + case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c)) + case Operators.indexAccess => visitIdentifierAssignedToIndexAccess(i, c) + case Operators.cast => visitIdentifierAssignedToCast(i, c) + case x => + logger.debug(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty + + /** Visits an identifier being assigned to a constructor and attempts to speculate the + * constructor path. 
+ */ + protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = + val constructorPaths = + symbolTable.get(c).map(t => t.concat(s"$pathSep${Defines.ConstructorMethodName}")) + associateTypes(i, constructorPaths) + + /** Visits an identifier being assigned to a call's return value. + */ + protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = + if symbolTable.contains(c) then + val callReturns = methodReturnValues(symbolTable.get(c).toSeq) + associateTypes(i, callReturns) + else if c.argument.exists(_.argumentIndex == 0) then + val callFullNames = (c.argument(0) match + case i: Identifier if symbolTable.contains(LocalVar(i.name)) => + symbolTable.get(LocalVar(i.name)) + case i: Identifier if symbolTable.contains(CallAlias(i.name)) => + symbolTable.get(CallAlias(i.name)) + case _ => Set.empty + ).map(_.concat(s"$pathSep${c.name}")).toSeq + val callReturns = methodReturnValues(callFullNames) + associateTypes(i, callReturns) + else + // Assign dummy value + associateTypes(i, Set(s"${c.name}$pathSep${XTypeRecovery.DummyReturnType}")) + + /** Will attempt to find the return values of a method if in the CPG, otherwise will give a + * dummy value. + */ + protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = + val rs = cpg.method + .fullNameExact(methodFullNames*) + .methodReturn + .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) + .filterNot(_.equals("ANY")) + .toSet + if rs.isEmpty then + methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet + else rs + + /** Will handle literal value assignments. Override if special handling is required. + */ + protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] = + associateTypes(i, getLiteralType(l)) + + /** Not all frontends populate typeFullName for literals so we allow this to be + * overridden. 
+ */ + protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName) + + /** Will handle an identifier holding a function pointer. + */ + protected def visitIdentifierAssignedToMethodRef( + i: Identifier, + m: MethodRef, + rec: Option[String] = None + ): Set[String] = + symbolTable.append(CallAlias(i.name, rec), Set(m.methodFullName)) + + /** Will handle an identifier holding a type pointer. + */ + protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] = None + ): Set[String] = + symbolTable.append(CallAlias(i.name, rec), Set(t.typeFullName)) + + /** Visits a call assigned to an identifier. This is often when there are operators involved. + */ + protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = + val rhsTypes = symbolTable.get(i) + assignTypesToCall(c, rhsTypes) + + /** Visits a call assigned to the return value of a call. This is often when there are operators + * involved. + */ + protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = + assignTypesToCall(x, getTypesFromCall(y)) + + /** Given a call operation, will attempt to retrieve types from it. + */ + protected def getTypesFromCall(c: Call): Set[String] = c.name match + case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) + case _ if symbolTable.contains(c) => symbolTable.get(c) + case Operators.indexAccess => getIndexAccessTypes(c) + case n => + logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") + Set.empty[String] + + /** Given a LHS call, will retrieve its symbol to the given types. 
+ */ + protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = + if types.nonEmpty then + getSymbolFromCall(x) match + case (lhs, globalKeys) if globalKeys.nonEmpty => + globalKeys.foreach { (fieldVar: FieldPath) => + persistMemberWithTypeDecl( + fieldVar.compUnitFullName, + fieldVar.identifier, + types + ) + } + symbolTable.append(lhs, types) + case (lhs, _) => symbolTable.append(lhs, types) + else Set.empty + + /** Will attempt to retrieve index access types otherwise will return dummy value. + */ + protected def getIndexAccessTypes(ia: Call): Set[String] = indexAccessToCollectionVar(ia) match + case Some(cVar) if symbolTable.contains(cVar) => + symbolTable.get(cVar) + case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) => + symbolTable.get(LocalVar(cVar.identifier)).map(x => + s"$x$pathSep${XTypeRecovery.DummyIndexAccess}" + ) + case _ => Set.empty + + /** Convenience class for transporting field names. + * @param compUnitFullName + * qualified path to base type holding the member. + * @param identifier + * the member name. + */ + case class FieldPath(compUnitFullName: String, identifier: String) + + /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of + * an assignment. The second element is a list of any associated global keys if applicable. + */ + protected def getSymbolFromCall(c: Call): (LocalKey, Set[FieldPath]) = c.name match + case Operators.fieldAccess => + val fa = new FieldAccess(c) + val fieldName = getFieldName(fa) + (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldPath(fp, fieldName))) + case Operators.indexAccess => + (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty) + case x => + logger.debug(s"Using default LHS call name '$x' @ ${debugLocation(c)}") + (LocalVar(c.name), Set.empty) + + /** Extracts a string representation of the name of the field within this field access. 
+ */ + protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = + def wrapName(n: String) = + val sb = new mutable.StringBuilder() + if prefix.nonEmpty then sb.append(s"$prefix$pathSep") + sb.append(n) + if suffix.nonEmpty then sb.append(s"$pathSep$suffix") + sb.toString() + + lazy val typesFromBaseCall = fa.argumentOut.headOption match + case Some(call: Call) => getTypesFromCall(call) + case _ => Set.empty[String] + + fa.argumentOut.l match + case ::(i: Identifier, ::(f: FieldIdentifier, _)) if i.name.matches("(self|this)") => + wrapName(f.canonicalName) + case ::(i: Identifier, ::(f: FieldIdentifier, _)) => + wrapName(s"${i.name}$pathSep${f.canonicalName}") + case ::(c: Call, ::(f: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName)) + case ::(_: Call, ::(f: FieldIdentifier, _)) if typesFromBaseCall.nonEmpty => + // TODO: Handle this case better + wrapName(s"${typesFromBaseCall.head}$pathSep${f.canonicalName}") + case ::(f: FieldIdentifier, ::(c: Call, _)) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName)) + case ::(c: Call, ::(f: FieldIdentifier, _)) => + // TODO: Handle this case better + val callCode = + if c.code.contains("(") then c.code.substring(c.code.indexOf("(")) else c.code + XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) + case xs => + logger.debug( + s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}" + ) + wrapName("") + end match + end getFieldName + + protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = + if c.name.equals(Operators.indexAccess) then + // For now, we will just handle this on a very basic level + c.argumentOut.l match + case List(_: Identifier, _: Literal) => + indexAccessToCollectionVar(c).map(cv => + symbolTable.append(cv, getLiteralType(l)) + 
).getOrElse(Set.empty) + case List(_: Identifier, idx: Identifier) if symbolTable.contains(idx) => + // Imprecise but sound! + indexAccessToCollectionVar(c).map(cv => + symbolTable.append(cv, symbolTable.get(idx)) + ).getOrElse(Set.empty) + case List(i: Identifier, c: Call) => + // This is an expensive level of precision to support + symbolTable.append(CollectionVar(i.name, "*"), getTypesFromCall(c)) + case List(c: Call, l: Literal) => assignTypesToCall(c, getLiteralType(l)) + case xs => + logger.debug( + s"Unhandled index access point assigned to literal ${xs.map(x => + (x.label, x.code) + ).mkString(",")} @ ${debugLocation(c)}" + ) + Set.empty + else if c.name.equals(Operators.fieldAccess) then + val fa = new FieldAccess(c) + val fieldName = getFieldName(fa) + associateTypes(LocalVar(fieldName), fa, getLiteralType(l)) else - returnValues - case ::(head: Call, Nil) if head.argumentOut.headOption.exists(symbolTable.contains) => + logger.debug( + s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}" + ) + Set.empty + + /** Handles a call operation assigned to a method/function pointer. + */ + protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] = + assignTypesToCall(c, Set(m.methodFullName)) + + /** Generates an identifier for collection/index-access operations in the symbol table. 
+ */ + protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = + def callName(x: Call) = + if x.name.equals(Operators.fieldAccess) then + getFieldName(new FieldAccess(x)) + else if x.name.equals(Operators.indexAccess) then + indexAccessToCollectionVar(x) + .map(cv => s"${cv.identifier}[${cv.idx}]") + .getOrElse(XTypeRecovery.DummyIndexAccess) + else x.name + + Option(c.argumentOut.l match + case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) + case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) + case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) + case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code) + case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code) + case xs => + logger.debug( + s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}" + ) + null + ) + end indexAccessToCollectionVar + + /** Will handle an identifier being assigned to a field value. 
+ */ + protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = + val fieldName = getFieldName(fa) + fa.argumentOut.l match + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) + if symbolTable.contains(LocalVar(base.name)) => + // Get field from global table if referenced as a variable + val localTypes = symbolTable.get(LocalVar(base.name)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) + if symbolTable.contains(LocalVar(fieldName)) => + val localTypes = symbolTable.get(LocalVar(fieldName)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) => + val dummyTypes = Set(s"$fieldName$pathSep${XTypeRecovery.DummyReturnType}") + associateInterproceduralTypes(i, base, fi, fieldName, dummyTypes) + case ::(c: Call, ::(fi: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => + val baseName = getFieldName(new FieldAccess(c)) + // Build type regardless of length + // TODO: This is more prone to giving dummy values as it does not do global look-ups + // but this is okay for now + val buf = mutable.ArrayBuffer.empty[String] + for segment <- baseName.split(pathSep) ++ Array(fi.canonicalName) do + val types = + if buf.isEmpty then symbolTable.get(LocalVar(segment)) + else + buf.flatMap(t => symbolTable.get(LocalVar(s"$t$pathSep$segment"))).toSet + if types.nonEmpty then + buf.clear() + buf.addAll(types) + else + val bufCopy = Array.from(buf) + buf.clear() + bufCopy.foreach { + case t if isConstructor(segment) => buf.addOne(s"$t$pathSep$segment") + case t => buf.addOne(XTypeRecovery.dummyMemberType(t, segment, pathSep)) + } + associateTypes(i, buf.toSet) + case ::(call: Call, ::(fi: FieldIdentifier, _)) => + assignTypesToCall( + call, + Set(fieldName.stripSuffix( + s"${XTypeRecovery.DummyMemberLoad}$pathSep${fi.canonicalName}" + )) + ) + case _ => + logger.debug( + s"Unable to 
assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}" + ) + Set.empty + end match + end visitIdentifierAssignedToFieldLoad + + /** Visits an identifier being assigned to the result of an index access operation. + */ + protected def visitIdentifierAssignedToIndexAccess(i: Identifier, c: Call): Set[String] = + associateTypes(i, getTypesFromCall(c)) + + /** Visits an identifier that is the target of a cast operation. + */ + protected def visitIdentifierAssignedToCast(i: Identifier, c: Call): Set[String] = + associateTypes(i, (c.typeFullName +: c.dynamicTypeHintFullName).filterNot(_ == "ANY").toSet) + + protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = + getFieldBaseType(base.name, fi.canonicalName) + + protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = symbolTable - .get(head.argumentOut.head) - .map(t => Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep.toString)) - case ::(identifier: Identifier, Nil) if symbolTable.contains(identifier) => - symbolTable.get(identifier) - case ::(head: Call, Nil) => - extractTypes(head.argument.l) - case _ => Set.empty - } - val returnTypes = extractTypes(ret.argumentOut.l) - existingTypes.addAll(returnTypes) - builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes) - } - - /** Using an entry from the symbol table, will queue the CPG modification to persist the recovered type information. 
- */ - protected def setTypeInformation(): Unit = { - cu.ast - .collect { - case n: Local => n - case n: Call => n - case n: Expression => n - case n: MethodParameterIn if state.isFinalIteration => n - case n: MethodReturn if state.isFinalIteration => n - } - .foreach { - case x: Local if symbolTable.contains(x) => storeNodeTypeInfo(x, symbolTable.get(x).toSeq) - case x: MethodParameterIn => setTypeFromTypeHints(x) - case x: MethodReturn => setTypeFromTypeHints(x) - case x: Identifier if symbolTable.contains(x) => - setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) - case x: Call if symbolTable.contains(x) => - val typs = - if (state.config.enabledDummyTypes) symbolTable.get(x).toSeq - else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq - storeCallTypeInfo(x, typs) - case x: Identifier if symbolTable.contains(CallAlias(x.name)) && x.inCall.nonEmpty => - setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) - case x: Call if x.argument.headOption.exists(symbolTable.contains) => - setTypeInformationForRecCall(x, Option(x), x.argument.l) - case _ => - } - // Set types in an atomic way - newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } - } - - protected def createCallFromIdentifierTypeFullName(typeFullName: String, callName: String): String = - s"$typeFullName$pathSep$callName" - - /** Sets type information for a receiver/call pattern. 
- */ - private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = { - (n, ms) match { - // Case 1: 'call' is an assignment from some dynamic dispatch call - case (Some(call: Call), ::(i: Identifier, ::(c: Call, _))) if call.name == Operators.assignment => - setTypeForIdentifierAssignedToCall(call, i, c) - // Case 1: 'call' is an assignment from some other data structure - case (Some(call: Call), ::(i: Identifier, _)) if call.name == Operators.assignment => - setTypeForIdentifierAssignedToDefault(call, i) - // Case 2: 'i' is the receiver of 'call' - case (Some(call: Call), ::(i: Identifier, _)) if call.name != Operators.fieldAccess => - setTypeForDynamicDispatchCall(call, i) - // Case 3: 'i' is the receiver for a field access on member 'f' - case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _))) - if fieldAccess.name == Operators.fieldAccess => - setTypeForFieldAccess(new FieldAccess(fieldAccess), i, f) - case _ => - } - // Handle the node itself - x match { - case c: Call if c.name.startsWith(" - case _ => persistType(x, symbolTable.get(x)) - } - } - - protected def setTypeForFieldAccess(fieldAccess: Call, i: Identifier, f: FieldIdentifier): Unit = { - val idHints = if (symbolTable.contains(i)) symbolTable.get(i) else symbolTable.get(CallAlias(i.name)) - val callTypes = symbolTable.get(fieldAccess) - persistType(i, idHints) - persistType(fieldAccess, callTypes) - fieldAccess.astParent.iterator.isCall.headOption match { - case Some(callFromFieldName) if symbolTable.contains(callFromFieldName) => - persistType(callFromFieldName, symbolTable.get(callFromFieldName)) - case _ => - } - // This field may be a function pointer - handlePotentialFunctionPointer(fieldAccess, idHints, f.canonicalName, Option(i.name)) - } - - protected def setTypeForDynamicDispatchCall(call: Call, i: Identifier): Unit = { - val idHints = symbolTable.get(i) - val callTypes = symbolTable.get(call) - persistType(i, idHints) - if 
(callTypes.isEmpty && !call.name.startsWith("")) - // For now, calls are treated as function pointers and thus the type should point to the method - persistType(call, idHints.map(t => createCallFromIdentifierTypeFullName(t, call.name))) - else { - persistType(call, callTypes) - } - } - - protected def setTypeForIdentifierAssignedToDefault(call: Call, i: Identifier): Unit = { - val idHints = symbolTable.get(i) - persistType(i, idHints) - persistType(call, idHints) - } - - protected def setTypeForIdentifierAssignedToCall(call: Call, i: Identifier, c: Call): Unit = { - val idTypes = if (symbolTable.contains(i)) symbolTable.get(i) else symbolTable.get(CallAlias(i.name)) - val callTypes = symbolTable.get(c) - persistType(call, callTypes) - if (idTypes.nonEmpty || callTypes.nonEmpty) { - if (idTypes.equals(callTypes)) - // Case 1.1: This is a function pointer or constructor - persistType(i, callTypes) - else - // Case 1.2: This is the return value of the function - persistType(i, idTypes) - } - } - - protected def setTypeFromTypeHints(n: StoredNode): Unit = { - val types = n.getKnownTypes.filterNot(XTypeRecovery.isDummyType) - if (types.nonEmpty) setTypes(n, types.toSeq) - } - - /** In the case this field access is a function pointer, we would want to make sure this has a method ref. - */ - private def handlePotentialFunctionPointer( - funcPtr: Expression, - baseTypes: Set[String], - funcName: String, - baseName: Option[String] = None - ): Unit = { - // Sometimes the function identifier is an argument to the call itself as a "base". In this case we don't need - // a method ref. 
This happens in jssrc2cpg - if (!funcPtr.astParent.iterator.collectAll[Call].exists(_.name == funcName)) { - baseTypes - .map(t => if (t.endsWith(funcName)) t else s"$t$pathSep$funcName") - .flatMap(cpg.method.fullNameExact) - .filterNot(m => addedNodes.contains(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${m.fullName}")) - .map(m => m -> createMethodRef(baseName, funcName, m.fullName, funcPtr.lineNumber, funcPtr.columnNumber)) - .foreach { case (m, mRef) => - funcPtr.astParent - .filterNot(_.astChildren.isMethodRef.exists(_.methodFullName == mRef.methodFullName)) - .foreach { inCall => - state.changesWereMade.compareAndSet(false, true) - integrateMethodRef(funcPtr, m, mRef, inCall) + .get(LocalVar(baseName)) + .flatMap(t => typeDeclIterator(t).member.nameExact(fieldName)) + .typeFullNameNot("ANY") + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .toSet + + protected def visitReturns(ret: Return): Unit = + val m = ret.method + val existingTypes = mutable.HashSet.from( + (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) + .filterNot(_ == "ANY") + ) + @tailrec + def extractTypes(xs: List[CfgNode]): Set[String] = xs match + case ::(head: Literal, Nil) if head.typeFullName != "ANY" => + Set(head.typeFullName) + case ::(head: Call, Nil) if head.name == Operators.fieldAccess => + val fieldAccess = new FieldAccess(head) + val (sym, ts) = getSymbolFromCall(fieldAccess) + val cpgTypes = cpg.typeDecl + .fullNameExact(ts.map(_.compUnitFullName).toSeq*) + .member + .nameExact(sym.identifier) + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .filterNot { x => x == "ANY" || x == "this" } + .toSet + if cpgTypes.nonEmpty then cpgTypes + else symbolTable.get(sym) + case ::(head: Call, Nil) if symbolTable.contains(head) => + val callPaths = symbolTable.get(head) + val returnValues = methodReturnValues(callPaths.toSeq) + if returnValues.isEmpty then + callPaths.map(c => s"$c$pathSep${XTypeRecovery.DummyReturnType}") + else + 
returnValues + case ::(head: Call, Nil) if head.argumentOut.headOption.exists(symbolTable.contains) => + symbolTable + .get(head.argumentOut.head) + .map(t => + Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) + ) + case ::(identifier: Identifier, Nil) if symbolTable.contains(identifier) => + symbolTable.get(identifier) + case ::(head: Call, Nil) => + extractTypes(head.argument.l) + case _ => Set.empty + val returnTypes = extractTypes(ret.argumentOut.l) + existingTypes.addAll(returnTypes) + builder.setNodeProperty( + ret.method.methodReturn, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + existingTypes + ) + end visitReturns + + /** Using an entry from the symbol table, will queue the CPG modification to persist the + * recovered type information. + */ + protected def setTypeInformation(): Unit = + cu.ast + .collect { + case n: Local => n + case n: Call => n + case n: Expression => n + case n: MethodParameterIn if state.isFinalIteration => n + case n: MethodReturn if state.isFinalIteration => n } + .foreach { + case x: Local if symbolTable.contains(x) => + storeNodeTypeInfo(x, symbolTable.get(x).toSeq) + case x: MethodParameterIn => setTypeFromTypeHints(x) + case x: MethodReturn => setTypeFromTypeHints(x) + case x: Identifier if symbolTable.contains(x) => + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) + case x: Call if symbolTable.contains(x) => + val typs = + if state.config.enabledDummyTypes then symbolTable.get(x).toSeq + else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq + storeCallTypeInfo(x, typs) + case x: Identifier + if symbolTable.contains(CallAlias(x.name)) && x.inCall.nonEmpty => + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) + case x: Call if x.argument.headOption.exists(symbolTable.contains) => + setTypeInformationForRecCall(x, Option(x), x.argument.l) + case _ => + } + // Set types in an atomic way + newTypesForMembers.foreach { case (m, ts) => 
storeDefaultTypeInfo(m, ts.toSeq) } + end setTypeInformation + + protected def createCallFromIdentifierTypeFullName( + typeFullName: String, + callName: String + ): String = + s"$typeFullName$pathSep$callName" + + /** Sets type information for a receiver/call pattern. + */ + private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = + (n, ms) match + // Case 1: 'call' is an assignment from some dynamic dispatch call + case (Some(call: Call), ::(i: Identifier, ::(c: Call, _))) + if call.name == Operators.assignment => + setTypeForIdentifierAssignedToCall(call, i, c) + // Case 1: 'call' is an assignment from some other data structure + case (Some(call: Call), ::(i: Identifier, _)) if call.name == Operators.assignment => + setTypeForIdentifierAssignedToDefault(call, i) + // Case 2: 'i' is the receiver of 'call' + case (Some(call: Call), ::(i: Identifier, _)) if call.name != Operators.fieldAccess => + setTypeForDynamicDispatchCall(call, i) + // Case 3: 'i' is the receiver for a field access on member 'f' + case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _))) + if fieldAccess.name == Operators.fieldAccess => + setTypeForFieldAccess(new FieldAccess(fieldAccess), i, f) + case _ => + // Handle the node itself + x match + case c: Call if c.name.startsWith(" + case _ => persistType(x, symbolTable.get(x)) + end setTypeInformationForRecCall + + protected def setTypeForFieldAccess( + fieldAccess: Call, + i: Identifier, + f: FieldIdentifier + ): Unit = + val idHints = if symbolTable.contains(i) then symbolTable.get(i) + else symbolTable.get(CallAlias(i.name)) + val callTypes = symbolTable.get(fieldAccess) + persistType(i, idHints) + persistType(fieldAccess, callTypes) + fieldAccess.astParent.iterator.isCall.headOption match + case Some(callFromFieldName) if symbolTable.contains(callFromFieldName) => + persistType(callFromFieldName, symbolTable.get(callFromFieldName)) + case _ => + // This field may be a function 
pointer + handlePotentialFunctionPointer(fieldAccess, idHints, f.canonicalName, Option(i.name)) + + protected def setTypeForDynamicDispatchCall(call: Call, i: Identifier): Unit = + val idHints = symbolTable.get(i) + val callTypes = symbolTable.get(call) + persistType(i, idHints) + if callTypes.isEmpty && !call.name.startsWith("") then + // For now, calls are treated as function pointers and thus the type should point to the method + persistType(call, idHints.map(t => createCallFromIdentifierTypeFullName(t, call.name))) + else + persistType(call, callTypes) + + protected def setTypeForIdentifierAssignedToDefault(call: Call, i: Identifier): Unit = + val idHints = symbolTable.get(i) + persistType(i, idHints) + persistType(call, idHints) + + protected def setTypeForIdentifierAssignedToCall(call: Call, i: Identifier, c: Call): Unit = + val idTypes = if symbolTable.contains(i) then symbolTable.get(i) + else symbolTable.get(CallAlias(i.name)) + val callTypes = symbolTable.get(c) + persistType(call, callTypes) + if idTypes.nonEmpty || callTypes.nonEmpty then + if idTypes.equals(callTypes) then + // Case 1.1: This is a function pointer or constructor + persistType(i, callTypes) + else + // Case 1.2: This is the return value of the function + persistType(i, idTypes) + + protected def setTypeFromTypeHints(n: StoredNode): Unit = + val types = n.getKnownTypes.filterNot(XTypeRecovery.isDummyType) + if types.nonEmpty then setTypes(n, types.toSeq) + + /** In the case this field access is a function pointer, we would want to make sure this has a + * method ref. + */ + private def handlePotentialFunctionPointer( + funcPtr: Expression, + baseTypes: Set[String], + funcName: String, + baseName: Option[String] = None + ): Unit = + // Sometimes the function identifier is an argument to the call itself as a "base". In this case we don't need + // a method ref. 
This happens in jssrc2cpg + if !funcPtr.astParent.iterator.collectAll[Call].exists(_.name == funcName) then + baseTypes + .map(t => if t.endsWith(funcName) then t else s"$t$pathSep$funcName") + .flatMap(cpg.method.fullNameExact) + .filterNot(m => + addedNodes.contains( + s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${m.fullName}" + ) + ) + .map(m => + m -> createMethodRef( + baseName, + funcName, + m.fullName, + funcPtr.lineNumber, + funcPtr.columnNumber + ) + ) + .foreach { case (m, mRef) => + funcPtr.astParent + .filterNot( + _.astChildren.isMethodRef.exists(_.methodFullName == mRef.methodFullName) + ) + .foreach { inCall => + state.changesWereMade.compareAndSet(false, true) + integrateMethodRef(funcPtr, m, mRef, inCall) + } + } + + private def createMethodRef( + baseName: Option[String], + funcName: String, + methodFullName: String, + lineNo: Option[Integer], + columnNo: Option[Integer] + ): NewMethodRef = + NewMethodRef() + .code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName") + .methodFullName(methodFullName) + .lineNumber(lineNo) + .columnNumber(columnNo) + + /** Integrate this method ref node into the CPG according to schema rules. Since we're adding + * this after the base passes, we need to add the necessary linking manually. 
+ */ + private def integrateMethodRef( + funcPtr: Expression, + m: Method, + mRef: NewMethodRef, + inCall: AstNode + ) = + builder.addNode(mRef) + builder.addEdge(mRef, m, EdgeTypes.REF) + builder.addEdge(inCall, mRef, EdgeTypes.AST) + builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS) + inCall match + case x: Call => + builder.addEdge(x, mRef, EdgeTypes.ARGUMENT) + mRef.argumentIndex(x.argumentOut.size + 1) + case x => + mRef.argumentIndex(x.astChildren.size + 1) + addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}") + + protected def persistType(x: StoredNode, types: Set[String]): Unit = + val filteredTypes = if state.config.enabledDummyTypes then types + else types.filterNot(XTypeRecovery.isDummyType) + if filteredTypes.nonEmpty then + storeNodeTypeInfo(x, filteredTypes.toSeq) + x match + case i: Identifier if symbolTable.contains(i) => + if isField(i) then persistMemberType(i, filteredTypes) + handlePotentialFunctionPointer(i, filteredTypes, i.name) + case _ => + + private def persistMemberType(i: Identifier, types: Set[String]): Unit = + getLocalMember(i) match + case Some(m) => storeNodeTypeInfo(m, types.toSeq) + case None => + + /** Type decls where member access are required need to point to the correct type decl that + * holds said members. This allows implementations to use the type names to find the correct + * type holding members. + * @param typeFullName + * the type full name. + * @return + * the type full name that has member children. + */ + protected def typeDeclIterator(typeFullName: String): Iterator[TypeDecl] = + cpg.typeDecl.fullNameExact(typeFullName) + + /** Given a type full name and member name, will persist the given types to the member. + * @param typeFullName + * the type full name. + * @param memberName + * the member name. + * @param types + * the types to associate. 
+ */ + protected def persistMemberWithTypeDecl( + typeFullName: String, + memberName: String, + types: Set[String] + ): Unit = + typeDeclIterator(typeFullName).member.nameExact(memberName).headOption.foreach { m => + storeNodeTypeInfo(m, types.toSeq) } - } - } - - private def createMethodRef( - baseName: Option[String], - funcName: String, - methodFullName: String, - lineNo: Option[Integer], - columnNo: Option[Integer] - ): NewMethodRef = - NewMethodRef() - .code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName") - .methodFullName(methodFullName) - .lineNumber(lineNo) - .columnNumber(columnNo) - - /** Integrate this method ref node into the CPG according to schema rules. Since we're adding this after the base - * passes, we need to add the necessary linking manually. - */ - private def integrateMethodRef(funcPtr: Expression, m: Method, mRef: NewMethodRef, inCall: AstNode) = { - builder.addNode(mRef) - builder.addEdge(mRef, m, EdgeTypes.REF) - builder.addEdge(inCall, mRef, EdgeTypes.AST) - builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS) - inCall match { - case x: Call => - builder.addEdge(x, mRef, EdgeTypes.ARGUMENT) - mRef.argumentIndex(x.argumentOut.size + 1) - case x => - mRef.argumentIndex(x.astChildren.size + 1) - } - addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}") - } - - protected def persistType(x: StoredNode, types: Set[String]): Unit = { - val filteredTypes = if (state.config.enabledDummyTypes) types else types.filterNot(XTypeRecovery.isDummyType) - if (filteredTypes.nonEmpty) { - storeNodeTypeInfo(x, filteredTypes.toSeq) - x match { - case i: Identifier if symbolTable.contains(i) => - if (isField(i)) persistMemberType(i, filteredTypes) - handlePotentialFunctionPointer(i, filteredTypes, i.name) - case _ => - } - } - } - - private def persistMemberType(i: Identifier, types: Set[String]): Unit = { - getLocalMember(i) match { - case Some(m) => storeNodeTypeInfo(m, types.toSeq) - case None => - } - 
} - - /** Type decls where member access are required need to point to the correct type decl that holds said members. This - * allows implementations to use the type names to find the correct type holding members. - * @param typeFullName - * the type full name. - * @return - * the type full name that has member children. - */ - protected def typeDeclIterator(typeFullName: String): Iterator[TypeDecl] = cpg.typeDecl.fullNameExact(typeFullName) - - /** Given a type full name and member name, will persist the given types to the member. - * @param typeFullName - * the type full name. - * @param memberName - * the member name. - * @param types - * the types to associate. - */ - protected def persistMemberWithTypeDecl(typeFullName: String, memberName: String, types: Set[String]): Unit = - typeDeclIterator(typeFullName).member.nameExact(memberName).headOption.foreach { m => - storeNodeTypeInfo(m, types.toSeq) - } - - /** Given an identifier that has been determined to be a field, an attempt is made to get the corresponding member. - * This implementation follows more the way dynamic languages define method/type relations. - * @param i - * the identifier. 
- * @return - * the corresponding member, if found - */ - protected def getLocalMember(i: Identifier): Option[Member] = - typeDeclIterator(i.method.typeDecl.fullName.headOption.getOrElse(i.method.fullName)).member - .nameExact(i.name) - .headOption - - private def storeNodeTypeInfo(storedNode: StoredNode, types: Seq[String]): Unit = { - lazy val existingTypes = storedNode.getKnownTypes - - if (types.nonEmpty && types.toSet != existingTypes) { - storedNode match { - case m: Member => - // To avoid overwriting member updates, we store them elsewhere until the end - newTypesForMembers.updateWith(m) { - case Some(ts) => Some(ts ++ types) - case None => Some(types.toSet) - } - case i: Identifier => storeIdentifierTypeInfo(i, types) - case l: Local => storeLocalTypeInfo(l, types) - case c: Call if !c.name.startsWith("") => storeCallTypeInfo(c, types) - case _: Call => - case n => - state.changesWereMade.compareAndSet(false, true) - setTypes(n, types) - } - } - } - - protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = - if (types.nonEmpty) { - state.changesWereMade.compareAndSet(false, true) - builder.setNodeProperty( - c, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - (c.dynamicTypeHintFullName ++ types).distinct - ) - } - - /** Allows one to modify the types assigned to identifiers. - */ - protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = - storeDefaultTypeInfo(i, types) - - /** Allows one to modify the types assigned to nodes otherwise. - */ - protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = - if (types.toSet != n.getKnownTypes) { - state.changesWereMade.compareAndSet(false, true) - setTypes(n, (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct) - } - - /** If there is only 1 type hint then this is set to the `typeFullName` property and `dynamicTypeHintFullName` is - * cleared. If not then `dynamicTypeHintFullName` is set to the types. 
- */ - protected def setTypes(n: StoredNode, types: Seq[String]): Unit = - if (types.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) - else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) - - /** Allows one to modify the types assigned to locals. - */ - protected def storeLocalTypeInfo(l: Local, types: Seq[String]): Unit = { - storeDefaultTypeInfo(l, if (state.config.enabledDummyTypes) types else types.filterNot(XTypeRecovery.isDummyType)) - } - - /** Allows an implementation to perform an operation once type persistence is complete. - */ - protected def postSetTypeInformation(): Unit = {} - - private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r - - // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of - // the symbol table then perhaps this would work out better - implicit class AllNodeTypesFromNodeExt(x: StoredNode) { - def allTypes: Iterator[String] = (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - Seq.empty - )).iterator - - def getKnownTypes: Set[String] = { - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - } - - implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]) { - def allTypes: Iterator[String] = x.flatMap(_.allTypes) - - def getKnownTypes: Set[String] = - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - -} + + /** Given an identifier that has been determined to be a field, an attempt is made to get the + * corresponding member. This implementation follows more the way dynamic languages define + * method/type relations. + * @param i + * the identifier. 
+ * @return + * the corresponding member, if found + */ + protected def getLocalMember(i: Identifier): Option[Member] = + typeDeclIterator(i.method.typeDecl.fullName.headOption.getOrElse(i.method.fullName)).member + .nameExact(i.name) + .headOption + + private def storeNodeTypeInfo(storedNode: StoredNode, types: Seq[String]): Unit = + lazy val existingTypes = storedNode.getKnownTypes + + if types.nonEmpty && types.toSet != existingTypes then + storedNode match + case m: Member => + // To avoid overwriting member updates, we store them elsewhere until the end + newTypesForMembers.updateWith(m) { + case Some(ts) => Some(ts ++ types) + case None => Some(types.toSet) + } + case i: Identifier => storeIdentifierTypeInfo(i, types) + case l: Local => storeLocalTypeInfo(l, types) + case c: Call if !c.name.startsWith("") => storeCallTypeInfo(c, types) + case _: Call => + case n => + state.changesWereMade.compareAndSet(false, true) + setTypes(n, types) + + protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = + if types.nonEmpty then + state.changesWereMade.compareAndSet(false, true) + builder.setNodeProperty( + c, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + (c.dynamicTypeHintFullName ++ types).distinct + ) + + /** Allows one to modify the types assigned to identifiers. + */ + protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + storeDefaultTypeInfo(i, types) + + /** Allows one to modify the types assigned to nodes otherwise. + */ + protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = + if types.toSet != n.getKnownTypes then + state.changesWereMade.compareAndSet(false, true) + setTypes( + n, + (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct + ) + + /** If there is only 1 type hint then this is set to the `typeFullName` property and + * `dynamicTypeHintFullName` is cleared. If not then `dynamicTypeHintFullName` is set to the + * types. 
+ */ + protected def setTypes(n: StoredNode, types: Seq[String]): Unit = + if types.size == 1 then builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) + else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) + + /** Allows one to modify the types assigned to locals. + */ + protected def storeLocalTypeInfo(l: Local, types: Seq[String]): Unit = + storeDefaultTypeInfo( + l, + if state.config.enabledDummyTypes then types + else types.filterNot(XTypeRecovery.isDummyType) + ) + + /** Allows an implementation to perform an operation once type persistence is complete. + */ + protected def postSetTypeInformation(): Unit = {} + + private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r + + // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of + // the symbol table then perhaps this would work out better + implicit class AllNodeTypesFromNodeExt(x: StoredNode): + def allTypes: Iterator[String] = + (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + Seq.empty + )).iterator + + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(unknownTypePattern.matches) + + implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]): + def allTypes: Iterator[String] = x.flatMap(_.allTypes) + + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(unknownTypePattern.matches) +end RecoverForXCompilationUnit diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala index c580f1fd..e78e82c9 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala @@ -14,168 +14,236 @@ import scala.io.Source /** Creates tags on 
typeDecl and call nodes based on a cdx document */ -class CdxPass(atom: Cpg) extends CpgPass(atom) { - - val language: String = atom.metaData.language.head - - // Number of tags needed - private val TAGS_COUNT: Int = 2 - - // Number of dots to use in the package namespace - // Example: org.apache.logging.* would be used for tagging purposes - private val PKG_NS_SIZE: Int = 3 - - // tags list as a seed - private val keywords: List[String] = Source.fromResource("tags-vocab.txt").getLines.toList - - // FIXME: Replace these with semantic fingerprints - private def JS_REQUEST_PATTERNS = - Array( - "(?s)(?i).*(req|ctx|context)\\.(originalUrl|path|protocol|route|secure|signedCookies|stale|subdomains|xhr|app|pipe|file|files|baseUrl|fresh|hostname|ip|url|ips|method|body|param|params|query|cookies|request).*" - ) - - private def JS_RESPONSE_PATTERNS = - Array( - "(?s)(?i).*(res|ctx|context)\\.(append|attachment|body|cookie|download|end|format|json|jsonp|links|location|redirect|render|send|sendFile|sendStatus|set|vary).*", - "(?s)(?i).*res\\.(set|writeHead|setHeader).*", - "(?s)(?i).*(db|dao|mongo|mongoclient).*", - "(?s)(?i).*(\\s|\\.)(list|create|upload|delete|execute|command|invoke|submit|send)" - ) - - private def PY_REQUEST_PATTERNS = Array(".*views.py:.*") - - private def containsRegex(str: String) = Pattern.quote(str) == str || str.contains("*") - - private val BOM_JSON_FILE = ".*(bom|cdx).json" - - private def toPyModuleForm(str: String) = { - if (str.nonEmpty) { - val tmpParts = str.split("\\.") - if (str.count(_ == '.') > 1) s"${tmpParts.take(2).mkString(Pattern.quote(File.separator))}.*" - else if (str.count(_ == '.') == 1) s"${tmpParts.head}.py:.*" - else str - } else if (str.nonEmpty) { - s"$str${Pattern.quote(File.separator)}.*" - } else { - str - } - } - - override def run(dstGraph: DiffGraphBuilder): Unit = { - atom.configFile.name(BOM_JSON_FILE).content.foreach { cdxData => - val cdxJson = parse(cdxData).getOrElse(Json.Null) - val cursor: HCursor = 
cdxJson.hcursor - val components = cursor.downField("components").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val donePkgs = mutable.Map[String, Boolean]() - if (language == Languages.JSSRC || language == Languages.JAVASCRIPT) { - JS_REQUEST_PATTERNS.foreach(p => atom.call.code(p).newTagNode("framework-input").store()(dstGraph)) - JS_RESPONSE_PATTERNS.foreach(p => atom.call.code(p).newTagNode("framework-output").store()(dstGraph)) - } - if (language == Languages.PYTHON || language == Languages.PYTHONSRC) { - PY_REQUEST_PATTERNS - .foreach(p => atom.method.fullName(p).parameter.newTagNode("framework-input").store()(dstGraph)) - } - components.foreach { comp => - val PURL_TYPE = "purl" - val compPurl = comp.hcursor.downField(PURL_TYPE).as[String].getOrElse("") - val compType = comp.hcursor.downField("type").as[String].getOrElse("") - val compDescription: String = comp.hcursor.downField("description").as[String].getOrElse("") - val descTags = keywords.filter(k => compDescription.toLowerCase().contains(" " + k)).take(TAGS_COUNT) - val properties = comp.hcursor.downField("properties").focus.flatMap(_.asArray).getOrElse(Vector.empty) - properties.foreach { ns => - val nsstr = ns.hcursor.downField("value").as[String].getOrElse("") - val nsname = ns.hcursor.downField("name").as[String].getOrElse("") - // Skip the SrcFile property - if (nsname != "SrcFile") { - nsstr - .split("(\n|,)") - .filterNot(_.startsWith("java.")) - .filterNot(_.startsWith("com.sun")) - .filterNot(_.contains("test")) - .filterNot(_.contains("mock")) - .foreach { (pkg: String) => - var bpkg = pkg.takeWhile(_ != '$') - if (language == Languages.JAVA || language == Languages.JAVASRC) { - bpkg = bpkg.split("\\.").take(PKG_NS_SIZE).mkString(".").concat(".*") - bpkg = bpkg.replace(File.separator, Pattern.quote(File.separator)) - } - if (language == Languages.JSSRC || language == Languages.JAVASCRIPT) { - bpkg = s".*${bpkg}.*" - bpkg = bpkg.replace(File.separator, Pattern.quote(File.separator)) 
+class CdxPass(atom: Cpg) extends CpgPass(atom): + + val language: String = atom.metaData.language.head + + // Number of tags needed + private val TAGS_COUNT: Int = 2 + + // Number of dots to use in the package namespace + // Example: org.apache.logging.* would be used for tagging purposes + private val PKG_NS_SIZE: Int = 3 + + // tags list as a seed + private val keywords: List[String] = Source.fromResource("tags-vocab.txt").getLines.toList + + // FIXME: Replace these with semantic fingerprints + private def JS_REQUEST_PATTERNS = + Array( + "(?s)(?i).*(req|ctx|context)\\.(originalUrl|path|protocol|route|secure|signedCookies|stale|subdomains|xhr|app|pipe|file|files|baseUrl|fresh|hostname|ip|url|ips|method|body|param|params|query|cookies|request).*" + ) + + private def JS_RESPONSE_PATTERNS = + Array( + "(?s)(?i).*(res|ctx|context)\\.(append|attachment|body|cookie|download|end|format|json|jsonp|links|location|redirect|render|send|sendFile|sendStatus|set|vary).*", + "(?s)(?i).*res\\.(set|writeHead|setHeader).*", + "(?s)(?i).*(db|dao|mongo|mongoclient).*", + "(?s)(?i).*(\\s|\\.)(list|create|upload|delete|execute|command|invoke|submit|send)" + ) + + private def PY_REQUEST_PATTERNS = Array(".*views.py:.*") + + private def containsRegex(str: String) = Pattern.quote(str) == str || str.contains("*") + + private val BOM_JSON_FILE = ".*(bom|cdx).json" + + private def toPyModuleForm(str: String) = + if str.nonEmpty then + val tmpParts = str.split("\\.") + if str.count(_ == '.') > 1 then + s"${tmpParts.take(2).mkString(Pattern.quote(File.separator))}.*" + else if str.count(_ == '.') == 1 then s"${tmpParts.head}.py:.*" + else str + else if str.nonEmpty then + s"$str${Pattern.quote(File.separator)}.*" + else + str + + override def run(dstGraph: DiffGraphBuilder): Unit = + atom.configFile.name(BOM_JSON_FILE).content.foreach { cdxData => + val cdxJson = parse(cdxData).getOrElse(Json.Null) + val cursor: HCursor = cdxJson.hcursor + val components = + 
cursor.downField("components").focus.flatMap(_.asArray).getOrElse(Vector.empty) + val donePkgs = mutable.Map[String, Boolean]() + if language == Languages.JSSRC || language == Languages.JAVASCRIPT then + JS_REQUEST_PATTERNS.foreach(p => + atom.call.code(p).newTagNode("framework-input").store()(dstGraph) + ) + JS_RESPONSE_PATTERNS.foreach(p => + atom.call.code(p).newTagNode("framework-output").store()(dstGraph) + ) + if language == Languages.PYTHON || language == Languages.PYTHONSRC then + PY_REQUEST_PATTERNS + .foreach(p => + atom.method.fullName(p).parameter.newTagNode("framework-input").store()( + dstGraph + ) + ) + components.foreach { comp => + val PURL_TYPE = "purl" + val compPurl = comp.hcursor.downField(PURL_TYPE).as[String].getOrElse("") + val compType = comp.hcursor.downField("type").as[String].getOrElse("") + val compDescription: String = + comp.hcursor.downField("description").as[String].getOrElse("") + val descTags = keywords.filter(k => + compDescription.toLowerCase().contains(" " + k) + ).take(TAGS_COUNT) + val properties = comp.hcursor.downField("properties").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + properties.foreach { ns => + val nsstr = ns.hcursor.downField("value").as[String].getOrElse("") + val nsname = ns.hcursor.downField("name").as[String].getOrElse("") + // Skip the SrcFile property + if nsname != "SrcFile" then + nsstr + .split("(\n|,)") + .filterNot(_.startsWith("java.")) + .filterNot(_.startsWith("com.sun")) + .filterNot(_.contains("test")) + .filterNot(_.contains("mock")) + .foreach { (pkg: String) => + var bpkg = pkg.takeWhile(_ != '$') + if language == Languages.JAVA || language == Languages.JAVASRC then + bpkg = bpkg.split("\\.").take(PKG_NS_SIZE).mkString(".").concat( + ".*" + ) + bpkg = + bpkg.replace(File.separator, Pattern.quote(File.separator)) + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + bpkg = s".*${bpkg}.*" + bpkg = + bpkg.replace(File.separator, Pattern.quote(File.separator)) + 
if language == Languages.PYTHON || language == Languages.PYTHONSRC + then bpkg = toPyModuleForm(bpkg) + if bpkg.nonEmpty && !donePkgs.contains(bpkg) then + donePkgs.put(bpkg, true) + if !containsRegex(bpkg) then + atom.call.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.identifier.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.parameter.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(compPurl).store()(dstGraph) + else + atom.call.typeFullName(bpkg).newTagNode(compPurl).store()( + dstGraph + ) + atom.identifier.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.parameter.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.fullName(bpkg).newTagNode(compPurl).store()( + dstGraph + ) + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + atom.call.code(bpkg).argument.newTagNode( + compPurl + ).store()(dstGraph) + atom.identifier.code(bpkg).newTagNode(compPurl).store()( + dstGraph + ) + atom.identifier.code(bpkg).inCall.newTagNode( + compPurl + ).store()(dstGraph) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then + atom.call.where( + _.methodFullName(bpkg) + ).argument.newTagNode(compPurl).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + end if + if compType != "library" then + if !containsRegex(bpkg) then + atom.call.typeFullNameExact(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.call.typeFullNameExact(bpkg).receiver.newTagNode( + s"$compType-value" + ).store()(dstGraph) + atom.method.parameter.typeFullNameExact( + bpkg + ).newTagNode(compType).store()(dstGraph) + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(compType).store()(dstGraph) + else + atom.call.typeFullName(bpkg).newTagNode( + compType + ).store()(dstGraph) + 
atom.call.typeFullName(bpkg).receiver.newTagNode( + s"$compType-value" + ).store()(dstGraph) + atom.method.parameter.typeFullName(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.method.fullName(bpkg).newTagNode(compType).store()( + dstGraph + ) + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + atom.call.code(bpkg).argument.newTagNode( + compType + ).store()(dstGraph) + atom.identifier.code(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.identifier.code(bpkg).inCall.newTagNode( + compType + ).store()(dstGraph) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then + atom.call.where( + _.methodFullName(bpkg) + ).argument.newTagNode(compType).store()(dstGraph) + end if + if compType == "framework" then + def frameworkAnnotatedMethod = atom.annotation + .fullName(bpkg) + .method + + frameworkAnnotatedMethod.parameter + .newTagNode(s"$compType-input") + .store()(dstGraph) + atom.ret + .where(_.method.annotation.fullName(bpkg)) + .newTagNode(s"$compType-output") + .store()(dstGraph) + descTags.foreach { t => + atom.call.typeFullName(bpkg).newTagNode(t).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode(t).store()( + dstGraph + ) + atom.method.parameter.typeFullName(bpkg).newTagNode( + t + ).store()(dstGraph) + if !containsRegex(bpkg) then + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(t).store()(dstGraph) + else + atom.method.fullName(bpkg).newTagNode(t).store()( + dstGraph + ) + } + end if + } + end if } - if (language == Languages.PYTHON || language == Languages.PYTHONSRC) bpkg = toPyModuleForm(bpkg) - if (bpkg.nonEmpty && !donePkgs.contains(bpkg)) { - donePkgs.put(bpkg, true) - if (!containsRegex(bpkg)) { - atom.call.typeFullNameExact(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.identifier.typeFullNameExact(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.method.parameter.typeFullNameExact(bpkg).newTagNode(compPurl).store()(dstGraph) - 
atom.method.fullName(s"${Pattern.quote(bpkg)}.*").newTagNode(compPurl).store()(dstGraph) - } else { - atom.call.typeFullName(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.method.parameter.typeFullName(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.method.fullName(bpkg).newTagNode(compPurl).store()(dstGraph) - if (language == Languages.JSSRC || language == Languages.JAVASCRIPT) { - atom.call.code(bpkg).argument.newTagNode(compPurl).store()(dstGraph) - atom.identifier.code(bpkg).newTagNode(compPurl).store()(dstGraph) - atom.identifier.code(bpkg).inCall.newTagNode(compPurl).store()(dstGraph) - } - if (language == Languages.PYTHON || language == Languages.PYTHONSRC) { - atom.call.where(_.methodFullName(bpkg)).argument.newTagNode(compPurl).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode(compPurl).store()(dstGraph) - } - } - if (compType != "library") { - if (!containsRegex(bpkg)) { - atom.call.typeFullNameExact(bpkg).newTagNode(compType).store()(dstGraph) - atom.call.typeFullNameExact(bpkg).receiver.newTagNode(s"$compType-value").store()(dstGraph) - atom.method.parameter.typeFullNameExact(bpkg).newTagNode(compType).store()(dstGraph) - atom.method.fullName(s"${Pattern.quote(bpkg)}.*").newTagNode(compType).store()(dstGraph) - } else { - atom.call.typeFullName(bpkg).newTagNode(compType).store()(dstGraph) - atom.call.typeFullName(bpkg).receiver.newTagNode(s"$compType-value").store()(dstGraph) - atom.method.parameter.typeFullName(bpkg).newTagNode(compType).store()(dstGraph) - atom.method.fullName(bpkg).newTagNode(compType).store()(dstGraph) - if (language == Languages.JSSRC || language == Languages.JAVASCRIPT) { - atom.call.code(bpkg).argument.newTagNode(compType).store()(dstGraph) - atom.identifier.code(bpkg).newTagNode(compType).store()(dstGraph) - atom.identifier.code(bpkg).inCall.newTagNode(compType).store()(dstGraph) - } - if (language == Languages.PYTHON || language 
== Languages.PYTHONSRC) { - atom.call.where(_.methodFullName(bpkg)).argument.newTagNode(compType).store()(dstGraph) - } - } - } - if (compType == "framework") { - def frameworkAnnotatedMethod = atom.annotation - .fullName(bpkg) - .method - - frameworkAnnotatedMethod.parameter - .newTagNode(s"$compType-input") - .store()(dstGraph) - atom.ret - .where(_.method.annotation.fullName(bpkg)) - .newTagNode(s"$compType-output") - .store()(dstGraph) - } - descTags.foreach { t => - atom.call.typeFullName(bpkg).newTagNode(t).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode(t).store()(dstGraph) - atom.method.parameter.typeFullName(bpkg).newTagNode(t).store()(dstGraph) - if (!containsRegex(bpkg)) { - atom.method.fullName(s"${Pattern.quote(bpkg)}.*").newTagNode(t).store()(dstGraph) - } else { - atom.method.fullName(bpkg).newTagNode(t).store()(dstGraph) - } - } - } - } - } + } } - } - } - } - -} +end CdxPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala index b272de54..6370214a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala @@ -12,98 +12,117 @@ import java.util.regex.Pattern /** Creates tags on any node */ -class ChennaiTagsPass(atom: Cpg) extends CpgPass(atom) { +class ChennaiTagsPass(atom: Cpg) extends CpgPass(atom): - val language: String = atom.metaData.language.head - private val FRAMEWORK_ROUTE = "framework-route" - private val FRAMEWORK_INPUT = "framework-input" - private val FRAMEWORK_OUTPUT = "framework-output" + val language: String = atom.metaData.language.head + private val FRAMEWORK_ROUTE = "framework-route" + private val FRAMEWORK_INPUT = "framework-input" + private val FRAMEWORK_OUTPUT = "framework-output" - private val 
PYTHON_ROUTES_CALL_REGEXES = - Array("django/(conf/)?urls.py:.(path|re_path|url).*", ".*(route|web\\.|add_resource).*") - private val PYTHON_ROUTES_DECORATORS_REGEXES = Array( - ".*(route|endpoint|_request|require_http_methods|require_GET|require_POST|require_safe|_required)\\(.*", - ".*def\\s(get|post|put)\\(.*" - ) - private val HTTP_METHODS_REGEX = ".*(request|session)\\.(args|get|post|put|form).*" - private def tagPythonRoutes(dstGraph: DiffGraphBuilder): Unit = { - PYTHON_ROUTES_CALL_REGEXES.foreach { r => - atom.call - .where(_.methodFullName(r)) - .argument - .isLiteral - .newTagNode(FRAMEWORK_ROUTE) - .store()(dstGraph) - } - PYTHON_ROUTES_DECORATORS_REGEXES.foreach { r => - def decoratedMethods = atom.methodRef - .where(_.inCall.code(r).argument) - ._refOut - .collectAll[Method] - decoratedMethods.call.assignment - .code(HTTP_METHODS_REGEX) - .argument - .isIdentifier - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - decoratedMethods - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - decoratedMethods.parameter - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - } - } - - override def run(dstGraph: DiffGraphBuilder): Unit = { - if (language == Languages.PYTHON || language == Languages.PYTHONSRC) tagPythonRoutes(dstGraph) - atom.configFile("chennai.json").content.foreach { cdxData => - val ctagsJson = parse(cdxData).getOrElse(Json.Null) - val cursor: HCursor = ctagsJson.hcursor - val tags = cursor.downField("tags").focus.flatMap(_.asArray).getOrElse(Vector.empty) - tags.foreach { comp => - val tagName = comp.hcursor.downField("name").as[String].getOrElse("") - val tagParams = comp.hcursor.downField("parameters").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val tagMethods = comp.hcursor.downField("methods").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val tagTypes = comp.hcursor.downField("types").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val tagFiles = 
comp.hcursor.downField("files").focus.flatMap(_.asArray).getOrElse(Vector.empty) - tagParams.foreach { paramName => - val pn = paramName.asString.getOrElse("") - if (pn.nonEmpty) { - atom.method.parameter.typeFullNameExact(pn).newTagNode(tagName).store()(dstGraph) - if (!pn.contains("[") && !pn.contains("*")) - atom.method.parameter.typeFullName(s".*${Pattern.quote(pn)}.*").newTagNode(tagName).store()(dstGraph) - } - } - tagMethods.foreach { methodName => - val mn = methodName.asString.getOrElse("") - if (mn.nonEmpty) { - atom.method.fullNameExact(mn).newTagNode(tagName).store()(dstGraph) - if (!mn.contains("[") && !mn.contains("*")) - atom.method.fullName(s".*${Pattern.quote(mn)}.*").newTagNode(tagName).store()(dstGraph) - } + private val PYTHON_ROUTES_CALL_REGEXES = + Array( + "django/(conf/)?urls.py:.(path|re_path|url).*", + ".*(route|web\\.|add_resource).*" + ) + private val PYTHON_ROUTES_DECORATORS_REGEXES = Array( + ".*(route|endpoint|_request|require_http_methods|require_GET|require_POST|require_safe|_required)\\(.*", + ".*def\\s(get|post|put)\\(.*" + ) + private val HTTP_METHODS_REGEX = ".*(request|session)\\.(args|get|post|put|form).*" + private def tagPythonRoutes(dstGraph: DiffGraphBuilder): Unit = + PYTHON_ROUTES_CALL_REGEXES.foreach { r => + atom.call + .where(_.methodFullName(r)) + .argument + .isLiteral + .newTagNode(FRAMEWORK_ROUTE) + .store()(dstGraph) } - tagTypes.foreach { typeName => - val tn = typeName.asString.getOrElse("") - if (tn.nonEmpty) { - atom.method.parameter.typeFullNameExact(tn).newTagNode(tagName).store()(dstGraph) - if (!tn.contains("[") && !tn.contains("*")) - atom.method.parameter.typeFullName(s".*${Pattern.quote(tn)}.*").newTagNode(tagName).store()(dstGraph) - atom.call.typeFullNameExact(tn).newTagNode(tagName).store()(dstGraph) - if (!tn.contains("[") && !tn.contains("*")) - atom.call.typeFullName(s".*${Pattern.quote(tn)}.*").newTagNode(tagName).store()(dstGraph) - } + PYTHON_ROUTES_DECORATORS_REGEXES.foreach { r => + def 
decoratedMethods = atom.methodRef + .where(_.inCall.code(r).argument) + ._refOut + .collectAll[Method] + decoratedMethods.call.assignment + .code(HTTP_METHODS_REGEX) + .argument + .isIdentifier + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) + decoratedMethods + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) + decoratedMethods.parameter + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) } - tagFiles.foreach { fileName => - val fn = fileName.asString.getOrElse("") - if (fn.nonEmpty) { - atom.file.nameExact(fn).newTagNode(tagName).store()(dstGraph) - if (!fn.contains("[") && !fn.contains("*")) - atom.file.name(s".*${Pattern.quote(fn)}.*").newTagNode(tagName).store()(dstGraph) - } - } - } - } - } + end tagPythonRoutes -} + override def run(dstGraph: DiffGraphBuilder): Unit = + if language == Languages.PYTHON || language == Languages.PYTHONSRC then + tagPythonRoutes(dstGraph) + atom.configFile("chennai.json").content.foreach { cdxData => + val ctagsJson = parse(cdxData).getOrElse(Json.Null) + val cursor: HCursor = ctagsJson.hcursor + val tags = cursor.downField("tags").focus.flatMap(_.asArray).getOrElse(Vector.empty) + tags.foreach { comp => + val tagName = comp.hcursor.downField("name").as[String].getOrElse("") + val tagParams = comp.hcursor.downField("parameters").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + val tagMethods = comp.hcursor.downField("methods").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + val tagTypes = + comp.hcursor.downField("types").focus.flatMap(_.asArray).getOrElse(Vector.empty) + val tagFiles = + comp.hcursor.downField("files").focus.flatMap(_.asArray).getOrElse(Vector.empty) + tagParams.foreach { paramName => + val pn = paramName.asString.getOrElse("") + if pn.nonEmpty then + atom.method.parameter.typeFullNameExact(pn).newTagNode(tagName).store()( + dstGraph + ) + if !pn.contains("[") && !pn.contains("*") then + atom.method.parameter.typeFullName( + s".*${Pattern.quote(pn)}.*" + 
).newTagNode(tagName).store()(dstGraph) + } + tagMethods.foreach { methodName => + val mn = methodName.asString.getOrElse("") + if mn.nonEmpty then + atom.method.fullNameExact(mn).newTagNode(tagName).store()(dstGraph) + if !mn.contains("[") && !mn.contains("*") then + atom.method.fullName(s".*${Pattern.quote(mn)}.*").newTagNode( + tagName + ).store()(dstGraph) + } + tagTypes.foreach { typeName => + val tn = typeName.asString.getOrElse("") + if tn.nonEmpty then + atom.method.parameter.typeFullNameExact(tn).newTagNode(tagName).store()( + dstGraph + ) + if !tn.contains("[") && !tn.contains("*") then + atom.method.parameter.typeFullName( + s".*${Pattern.quote(tn)}.*" + ).newTagNode(tagName).store()(dstGraph) + atom.call.typeFullNameExact(tn).newTagNode(tagName).store()(dstGraph) + if !tn.contains("[") && !tn.contains("*") then + atom.call.typeFullName(s".*${Pattern.quote(tn)}.*").newTagNode( + tagName + ).store()(dstGraph) + } + tagFiles.foreach { fileName => + val fn = fileName.asString.getOrElse("") + if fn.nonEmpty then + atom.file.nameExact(fn).newTagNode(tagName).store()(dstGraph) + if !fn.contains("[") && !fn.contains("*") then + atom.file.name(s".*${Pattern.quote(fn)}.*").newTagNode(tagName).store()( + dstGraph + ) + } + } + } + end run +end ChennaiTagsPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala index 8a48e629..6088fb4d 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala @@ -6,22 +6,18 @@ import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes, PropertyN import io.shiftleft.passes.CpgPass import io.appthreat.x2cpg.utils.LinkingUtil -class AliasLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { 
+class AliasLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = { - // Create ALIAS_OF edges from TYPE_DECL nodes to TYPE - linkToMultiple( - cpg, - srcLabels = List(NodeTypes.TYPE_DECL), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.ALIAS_OF, - dstNodeMap = typeFullNameToNode(cpg, _), - getDstFullNames = (srcNode: TypeDecl) => { - srcNode.aliasTypeFullName - }, - dstFullNameKey = PropertyNames.ALIAS_TYPE_FULL_NAME, - dstGraph - ) - } - -} + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create ALIAS_OF edges from TYPE_DECL nodes to TYPE + linkToMultiple( + cpg, + srcLabels = List(NodeTypes.TYPE_DECL), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.ALIAS_OF, + dstNodeMap = typeFullNameToNode(cpg, _), + getDstFullNames = (srcNode: TypeDecl) => + srcNode.aliasTypeFullName, + dstFullNameKey = PropertyNames.ALIAS_TYPE_FULL_NAME, + dstGraph + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala index 4ad6ee46..a9a2ced5 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala @@ -8,25 +8,21 @@ import io.appthreat.x2cpg.utils.LinkingUtil /** Create INHERITS_FROM edges from `TYPE_DECL` nodes to `TYPE` nodes. 
*/ -class TypeHierarchyPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { +class TypeHierarchyPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = { - linkToMultiple( - cpg, - srcLabels = List(NodeTypes.TYPE_DECL), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.INHERITS_FROM, - dstNodeMap = typeFullNameToNode(cpg, _), - getDstFullNames = (srcNode: TypeDecl) => { - if (srcNode.inheritsFromTypeFullName != null) { - srcNode.inheritsFromTypeFullName - } else { - Seq() - } - }, - dstFullNameKey = PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, - dstGraph - ) - } - -} + override def run(dstGraph: DiffGraphBuilder): Unit = + linkToMultiple( + cpg, + srcLabels = List(NodeTypes.TYPE_DECL), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.INHERITS_FROM, + dstNodeMap = typeFullNameToNode(cpg, _), + getDstFullNames = (srcNode: TypeDecl) => + if srcNode.inheritsFromTypeFullName != null then + srcNode.inheritsFromTypeFullName + else + Seq() + , + dstFullNameKey = PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, + dstGraph + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala index 0fbc92ea..06f5e268 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala @@ -3,32 +3,28 @@ package io.appthreat.x2cpg.utils import io.appthreat.x2cpg.Ast import io.shiftleft.codepropertygraph.generated.PropertyNames -object AstPropertiesUtil { +object AstPropertiesUtil: - implicit class RootProperties(val ast: Ast) extends AnyVal { + implicit class RootProperties(val ast: Ast) extends AnyVal: - private def rootProperty(propertyName: String): Option[String] = { - ast.root.flatMap(_.properties.get(propertyName).map(_.toString)) - } + private def 
rootProperty(propertyName: String): Option[String] = + ast.root.flatMap(_.properties.get(propertyName).map(_.toString)) - def rootType: Option[String] = rootProperty(PropertyNames.TYPE_FULL_NAME) + def rootType: Option[String] = rootProperty(PropertyNames.TYPE_FULL_NAME) - def rootCode: Option[String] = rootProperty(PropertyNames.CODE) + def rootCode: Option[String] = rootProperty(PropertyNames.CODE) - def rootName: Option[String] = rootProperty(PropertyNames.NAME) + def rootName: Option[String] = rootProperty(PropertyNames.NAME) - def rootCodeOrEmpty: String = rootCode.getOrElse("") + def rootCodeOrEmpty: String = rootCode.getOrElse("") - } + implicit class RootPropertiesOnSeq(val asts: Seq[Ast]) extends AnyVal: - implicit class RootPropertiesOnSeq(val asts: Seq[Ast]) extends AnyVal { + def rootType: Option[String] = asts.headOption.flatMap(_.rootType) - def rootType: Option[String] = asts.headOption.flatMap(_.rootType) + def rootCode: Option[String] = asts.headOption.flatMap(_.rootCode) - def rootCode: Option[String] = asts.headOption.flatMap(_.rootCode) + def rootName: Option[String] = asts.headOption.flatMap(_.rootName) - def rootName: Option[String] = asts.headOption.flatMap(_.rootName) - - def rootCodeOrEmpty: String = asts.rootCode.getOrElse("") - } -} + def rootCodeOrEmpty: String = asts.rootCode.getOrElse("") +end AstPropertiesUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala index 0265c29b..e9b9c638 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala @@ -4,41 +4,36 @@ import org.slf4j.LoggerFactory import java.nio.file.Paths -object Environment { +object Environment: - object OperatingSystemType extends Enumeration { - type OperatingSystemType = Value + object OperatingSystemType extends 
Enumeration: + type OperatingSystemType = Value - val Windows, Linux, Mac, Unknown = Value - } + val Windows, Linux, Mac, Unknown = Value - object ArchitectureType extends Enumeration { - type ArchitectureType = Value + object ArchitectureType extends Enumeration: + type ArchitectureType = Value - val X86, ARM = Value - } + val X86, ARM = Value - lazy val operatingSystem: OperatingSystemType.OperatingSystemType = - if (scala.util.Properties.isMac) OperatingSystemType.Mac - else if (scala.util.Properties.isLinux) OperatingSystemType.Linux - else if (scala.util.Properties.isWin) OperatingSystemType.Windows - else OperatingSystemType.Unknown + lazy val operatingSystem: OperatingSystemType.OperatingSystemType = + if scala.util.Properties.isMac then OperatingSystemType.Mac + else if scala.util.Properties.isLinux then OperatingSystemType.Linux + else if scala.util.Properties.isWin then OperatingSystemType.Windows + else OperatingSystemType.Unknown - lazy val architecture: ArchitectureType.ArchitectureType = - if (scala.util.Properties.propOrNone("os.arch").contains("aarch64")) ArchitectureType.ARM - // We do not distinguish between x86 and x64. E.g, a 64 bit Windows will always lie about - // this and will report x86 anyway for backwards compatibility with 32 bit software. - else ArchitectureType.X86 + lazy val architecture: ArchitectureType.ArchitectureType = + if scala.util.Properties.propOrNone("os.arch").contains("aarch64") then ArchitectureType.ARM + // We do not distinguish between x86 and x64. E.g, a 64 bit Windows will always lie about + // this and will report x86 anyway for backwards compatibility with 32 bit software. 
+ else ArchitectureType.X86 - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - def pathExists(path: String): Boolean = { - if (!Paths.get(path).toFile.exists()) { - logger.debug(s"Input path '$path' does not exist!") - false - } else { - true - } - } - -} + def pathExists(path: String): Boolean = + if !Paths.get(path).toFile.exists() then + logger.debug(s"Input path '$path' does not exist!") + false + else + true +end Environment diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala index b710eb9e..bb0264ff 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala @@ -4,48 +4,52 @@ import java.util.concurrent.ConcurrentLinkedQueue import org.apache.commons.lang.StringUtils import scala.sys.process.{Process, ProcessLogger} import scala.util.{Failure, Success, Try} -import scala.jdk.CollectionConverters._ - -object ExternalCommand { - - private val IS_WIN: Boolean = - scala.util.Properties.isWin - - private val shellPrefix: Seq[String] = - if (IS_WIN) "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil - - def run(command: String, cwd: String, separateStdErr: Boolean = false): Try[Seq[String]] = { - val stdOutOutput = new ConcurrentLinkedQueue[String] - val stdErrOutput = if (separateStdErr) new ConcurrentLinkedQueue[String] else stdOutOutput - val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) - - Process(shellPrefix :+ command, new java.io.File(cwd)).!(processLogger) match { - case 0 => - Success(stdOutOutput.asScala.toSeq) - case _ => - Failure(new RuntimeException(stdErrOutput.asScala.mkString(System.lineSeparator()))) - } - } - - private val COMMAND_AND: String = " && " - - def toOSCommand(command: String): String = if 
(IS_WIN) command + ".cmd" else command - - def runMultiple(command: String, inDir: String = ".", extraEnv: Map[String, String] = Map.empty): Try[String] = { - val dir = new java.io.File(inDir) - val stdOutOutput = new ConcurrentLinkedQueue[String] - val stdErrOutput = new ConcurrentLinkedQueue[String] - val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) - val commands = command.split(COMMAND_AND).toSeq - commands.map { cmd => - val cmdWithQuotesAroundDir = StringUtils.replace(cmd, inDir, s"'$inDir'") - Try(Process(cmdWithQuotesAroundDir, dir, extraEnv.toList: _*).!(processLogger)).getOrElse(1) - }.sum match { - case 0 => - Success(stdOutOutput.asScala.mkString(System.lineSeparator())) - case _ => - val allOutput = stdOutOutput.asScala ++ stdErrOutput.asScala - Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) - } - } -} +import scala.jdk.CollectionConverters.* + +object ExternalCommand: + + private val IS_WIN: Boolean = + scala.util.Properties.isWin + + private val shellPrefix: Seq[String] = + if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil + + def run(command: String, cwd: String, separateStdErr: Boolean = false): Try[Seq[String]] = + val stdOutOutput = new ConcurrentLinkedQueue[String] + val stdErrOutput = + if separateStdErr then new ConcurrentLinkedQueue[String] else stdOutOutput + val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) + + Process(shellPrefix :+ command, new java.io.File(cwd)).!(processLogger) match + case 0 => + Success(stdOutOutput.asScala.toSeq) + case _ => + Failure(new RuntimeException(stdErrOutput.asScala.mkString(System.lineSeparator()))) + + private val COMMAND_AND: String = " && " + + def toOSCommand(command: String): String = if IS_WIN then command + ".cmd" else command + + def runMultiple( + command: String, + inDir: String = ".", + extraEnv: Map[String, String] = Map.empty + ): Try[String] = + val dir = new java.io.File(inDir) + val stdOutOutput = new 
ConcurrentLinkedQueue[String] + val stdErrOutput = new ConcurrentLinkedQueue[String] + val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) + val commands = command.split(COMMAND_AND).toSeq + commands.map { cmd => + val cmdWithQuotesAroundDir = StringUtils.replace(cmd, inDir, s"'$inDir'") + Try(Process(cmdWithQuotesAroundDir, dir, extraEnv.toList*).!(processLogger)).getOrElse( + 1 + ) + }.sum match + case 0 => + Success(stdOutOutput.asScala.mkString(System.lineSeparator())) + case _ => + val allOutput = stdOutOutput.asScala ++ stdErrOutput.asScala + Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) + end runMultiple +end ExternalCommand diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala index 408d3fcb..f9738d73 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala @@ -4,30 +4,28 @@ import java.nio.file.{Files, Path} import java.security.{DigestInputStream, MessageDigest} import scala.util.Using -object HashUtil { +object HashUtil: - def sha256(file: Path): String = - sha256(Seq(file)) + def sha256(file: Path): String = + sha256(Seq(file)) - def sha256(file: String): String = - sha256(Seq(Path.of(file))) + def sha256(file: String): String = + sha256(Seq(Path.of(file))) - def sha256(files: Seq[Path]): String = { - val md = MessageDigest.getInstance("SHA-256") - val buffer = new Array[Byte](4096) - files - .filterNot(p => isDirectory(p.toRealPath())) - .foreach { path => - Using.resource(new DigestInputStream(Files.newInputStream(path), md)) { dis => - while (dis.available() > 0) { - dis.read(buffer) - } - } - } - md.digest().map(b => String.format("%02x", Byte.box(b))).mkString - } + def sha256(files: Seq[Path]): String = + val md = MessageDigest.getInstance("SHA-256") + val buffer 
= new Array[Byte](4096) + files + .filterNot(p => isDirectory(p.toRealPath())) + .foreach { path => + Using.resource(new DigestInputStream(Files.newInputStream(path), md)) { dis => + while dis.available() > 0 do + dis.read(buffer) + } + } + md.digest().map(b => String.format("%02x", Byte.box(b))).mkString - private def isDirectory(path: Path): Boolean = - if (path == null || !Files.exists(path)) false - else Files.isDirectory(path) -} + private def isDirectory(path: Path): Boolean = + if path == null || !Files.exists(path) then false + else Files.isDirectory(path) +end HashUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala index 4ba88281..3c463556 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala @@ -14,162 +14,176 @@ import org.slf4j.LoggerFactory import overflowdb.NodeDb import overflowdb.PropertyKey import overflowdb.NodeRef -import overflowdb.traversal._ -import overflowdb.traversal.ChainedImplicitsTemp._ +import overflowdb.traversal.* +import overflowdb.traversal.ChainedImplicitsTemp.* import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -trait LinkingUtil { +trait LinkingUtil: - import overflowdb.BatchedUpdate.DiffGraphBuilder + import overflowdb.BatchedUpdate.DiffGraphBuilder - val logger: Logger = LoggerFactory.getLogger(classOf[LinkingUtil]) + val logger: Logger = LoggerFactory.getLogger(classOf[LinkingUtil]) - def typeDeclFullNameToNode(cpg: Cpg, x: String): Option[TypeDecl] = - nodesWithFullName(cpg, x).collectFirst { case x: TypeDecl => x } + def typeDeclFullNameToNode(cpg: Cpg, x: String): Option[TypeDecl] = + nodesWithFullName(cpg, x).collectFirst { case x: TypeDecl => x } - def typeFullNameToNode(cpg: Cpg, x: String): 
Option[Type] = - nodesWithFullName(cpg, x).collectFirst { case x: Type => x } + def typeFullNameToNode(cpg: Cpg, x: String): Option[Type] = + nodesWithFullName(cpg, x).collectFirst { case x: Type => x } - def methodFullNameToNode(cpg: Cpg, x: String): Option[Method] = - nodesWithFullName(cpg, x).collectFirst { case x: Method => x } + def methodFullNameToNode(cpg: Cpg, x: String): Option[Method] = + nodesWithFullName(cpg, x).collectFirst { case x: Method => x } - def namespaceBlockFullNameToNode(cpg: Cpg, x: String): Option[NamespaceBlock] = - nodesWithFullName(cpg, x).collectFirst { case x: NamespaceBlock => x } + def namespaceBlockFullNameToNode(cpg: Cpg, x: String): Option[NamespaceBlock] = + nodesWithFullName(cpg, x).collectFirst { case x: NamespaceBlock => x } - def nodesWithFullName(cpg: Cpg, x: String): mutable.Seq[NodeRef[_ <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + def nodesWithFullName(cpg: Cpg, x: String): mutable.Seq[NodeRef[? <: NodeDb]] = + cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala - /** For all nodes `n` with a label in `srcLabels`, determine the value of `n.\$dstFullNameKey`, use that to lookup the - * destination node in `dstNodeMap`, and create an edge of type `edgeType` between `n` and the destination node. 
- */ - def linkToSingle( - cpg: Cpg, - srcLabels: List[String], - dstNodeLabel: String, - edgeType: String, - dstNodeMap: String => Option[StoredNode], - dstFullNameKey: String, - dstGraph: DiffGraphBuilder, - dstNotExistsHandler: Option[(StoredNode, String) => Unit] - ): Unit = { - var loggedDeprecationWarning = false - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels: _*).foreach { srcNode => - // If the source node does not have any outgoing edges of this type - // This check is just required for backward compatibility - if (srcNode.outE(edgeType).isEmpty) { - val key = new PropertyKey[String](dstFullNameKey) - srcNode - .propertyOption(key) - .filter { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - srcNode.propertyDefaultValue(dstFullNameKey) != dereferenceDstFullName - } - .ifPresent { dstFullName => - // for `UNKNOWN` this is not always set, so we're using an Option here - val srcStoredNode = srcNode.asInstanceOf[StoredNode] - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match { - case Some(dstNode) => - dstGraph.addEdge(srcStoredNode, dstNode, edgeType) - case None if dstNodeMap(dstFullName).isDefined => - dstGraph.addEdge(srcStoredNode, dstNodeMap(dstFullName).get, edgeType) - case None if dstNotExistsHandler.isDefined => - dstNotExistsHandler.get(srcStoredNode, dereferenceDstFullName) - case _ => - logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) - } - } - } else { - srcNode.out(edgeType).property(Properties.FULL_NAME).nextOption() match { - case Some(dstFullName) => - dstGraph.setNodeProperty( - srcNode.asInstanceOf[StoredNode], - dstFullNameKey, - dereference.dereferenceTypeFullName(dstFullName) - ) - case None => logger.debug(s"Missing outgoing edge of type $edgeType from node $srcNode") + /** For all nodes `n` with a label in `srcLabels`, determine the value of 
`n.\$dstFullNameKey`, + * use that to lookup the destination node in `dstNodeMap`, and create an edge of type + * `edgeType` between `n` and the destination node. + */ + def linkToSingle( + cpg: Cpg, + srcLabels: List[String], + dstNodeLabel: String, + edgeType: String, + dstNodeMap: String => Option[StoredNode], + dstFullNameKey: String, + dstGraph: DiffGraphBuilder, + dstNotExistsHandler: Option[(StoredNode, String) => Unit] + ): Unit = + var loggedDeprecationWarning = false + val dereference = Dereference(cpg) + cpg.graph.nodes(srcLabels*).foreach { srcNode => + // If the source node does not have any outgoing edges of this type + // This check is just required for backward compatibility + if srcNode.outE(edgeType).isEmpty then + val key = new PropertyKey[String](dstFullNameKey) + srcNode + .propertyOption(key) + .filter { dstFullName => + val dereferenceDstFullName = + dereference.dereferenceTypeFullName(dstFullName) + srcNode.propertyDefaultValue(dstFullNameKey) != dereferenceDstFullName + } + .ifPresent { dstFullName => + // for `UNKNOWN` this is not always set, so we're using an Option here + val srcStoredNode = srcNode.asInstanceOf[StoredNode] + val dereferenceDstFullName = + dereference.dereferenceTypeFullName(dstFullName) + dstNodeMap(dereferenceDstFullName) match + case Some(dstNode) => + dstGraph.addEdge(srcStoredNode, dstNode, edgeType) + case None if dstNodeMap(dstFullName).isDefined => + dstGraph.addEdge( + srcStoredNode, + dstNodeMap(dstFullName).get, + edgeType + ) + case None if dstNotExistsHandler.isDefined => + dstNotExistsHandler.get(srcStoredNode, dereferenceDstFullName) + case _ => + logFailedDstLookup( + edgeType, + srcNode.label, + srcNode.id.toString, + dstNodeLabel, + dereferenceDstFullName + ) + } + else + srcNode.out(edgeType).property(Properties.FULL_NAME).nextOption() match + case Some(dstFullName) => + dstGraph.setNodeProperty( + srcNode.asInstanceOf[StoredNode], + dstFullNameKey, + dereference.dereferenceTypeFullName(dstFullName) + 
) + case None => + logger.debug(s"Missing outgoing edge of type $edgeType from node $srcNode") + if !loggedDeprecationWarning then + logger.debug( + s"Using deprecated CPG format with already existing $edgeType edge between" + + s" a source node of type $srcLabels and a $dstNodeLabel node." + ) + loggedDeprecationWarning = true } - if (!loggedDeprecationWarning) { - logger.debug( - s"Using deprecated CPG format with already existing $edgeType edge between" + - s" a source node of type $srcLabels and a $dstNodeLabel node." - ) - loggedDeprecationWarning = true - } - } - } - } + end linkToSingle - def linkToMultiple[SRC_NODE_TYPE <: StoredNode]( - cpg: Cpg, - srcLabels: List[String], - dstNodeLabel: String, - edgeType: String, - dstNodeMap: String => Option[StoredNode], - getDstFullNames: SRC_NODE_TYPE => Iterable[String], - dstFullNameKey: String, - dstGraph: DiffGraphBuilder - ): Unit = { - var loggedDeprecationWarning = false - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels: _*).asScala.cast[SRC_NODE_TYPE].foreach { srcNode => - if (!srcNode.outE(edgeType).hasNext) { - getDstFullNames(srcNode).foreach { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match { - case Some(dstNode) => - dstGraph.addEdge(srcNode, dstNode, edgeType) - case None if dstNodeMap(dstFullName).isDefined => - dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) - case None => - logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) - } - } - } else { - val dstFullNames = srcNode.out(edgeType).property(Properties.FULL_NAME).l - dstGraph.setNodeProperty(srcNode, dstFullNameKey, dstFullNames.map(dereference.dereferenceTypeFullName)) - if (!loggedDeprecationWarning) { - logger.debug( - s"Using deprecated CPG format with already existing $edgeType edge between" + - s" a source node of type $srcLabels and a $dstNodeLabel node." 
- ) - loggedDeprecationWarning = true + def linkToMultiple[SRC_NODE_TYPE <: StoredNode]( + cpg: Cpg, + srcLabels: List[String], + dstNodeLabel: String, + edgeType: String, + dstNodeMap: String => Option[StoredNode], + getDstFullNames: SRC_NODE_TYPE => Iterable[String], + dstFullNameKey: String, + dstGraph: DiffGraphBuilder + ): Unit = + var loggedDeprecationWarning = false + val dereference = Dereference(cpg) + cpg.graph.nodes(srcLabels*).asScala.cast[SRC_NODE_TYPE].foreach { srcNode => + if !srcNode.outE(edgeType).hasNext then + getDstFullNames(srcNode).foreach { dstFullName => + val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) + dstNodeMap(dereferenceDstFullName) match + case Some(dstNode) => + dstGraph.addEdge(srcNode, dstNode, edgeType) + case None if dstNodeMap(dstFullName).isDefined => + dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) + case None => + logFailedDstLookup( + edgeType, + srcNode.label, + srcNode.id.toString, + dstNodeLabel, + dereferenceDstFullName + ) + } + else + val dstFullNames = srcNode.out(edgeType).property(Properties.FULL_NAME).l + dstGraph.setNodeProperty( + srcNode, + dstFullNameKey, + dstFullNames.map(dereference.dereferenceTypeFullName) + ) + if !loggedDeprecationWarning then + logger.debug( + s"Using deprecated CPG format with already existing $edgeType edge between" + + s" a source node of type $srcLabels and a $dstNodeLabel node." + ) + loggedDeprecationWarning = true } - } - } - } - - @inline - protected def logFailedDstLookup( - edgeType: String, - srcNodeType: String, - srcNodeId: String, - dstNodeType: String, - dstFullName: String - ): Unit = { - logger.debug( - "Could not create edge. Destination lookup failed. 
" + - s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcNodeId=$srcNodeId, " + - s"dstNodeType=$dstNodeType, dstFullName=$dstFullName" - ) - } + end linkToMultiple - @inline - protected def logFailedSrcLookup( - edgeType: String, - srcNodeType: String, - srcFullName: String, - dstNodeType: String, - dstNodeId: String - ): Unit = { - logger.debug( - "Could not create edge. Source lookup failed. " + - s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcFullName=$srcFullName, " + - s"dstNodeType=$dstNodeType, dstNodeId=$dstNodeId" - ) - } + @inline + protected def logFailedDstLookup( + edgeType: String, + srcNodeType: String, + srcNodeId: String, + dstNodeType: String, + dstFullName: String + ): Unit = + logger.debug( + "Could not create edge. Destination lookup failed. " + + s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcNodeId=$srcNodeId, " + + s"dstNodeType=$dstNodeType, dstFullName=$dstFullName" + ) -} + @inline + protected def logFailedSrcLookup( + edgeType: String, + srcNodeType: String, + srcFullName: String, + dstNodeType: String, + dstNodeId: String + ): Unit = + logger.debug( + "Could not create edge. Source lookup failed. " + + s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcFullName=$srcFullName, " + + s"dstNodeType=$dstNodeType, dstNodeId=$dstNodeId" + ) +end LinkingUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala index 19d21078..582c9f70 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala @@ -1,20 +1,15 @@ package io.appthreat.x2cpg.utils -object ListUtils { - extension [T](list: List[T]) { - - /** Return each element in the list up to and including the first element which satisfies the given predicate, or - * return the empty list if no matching element is found, e.g. 
- * - * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 3) => List(1, 2, 3) - * - * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 5) => Nil - */ - def takeUntil(predicate: T => Boolean): List[T] = { - list.indexWhere(predicate) match { - case index if index >= 0 => list.take(index + 1) - case _ => Nil - } - } - } -} +object ListUtils: + extension [T](list: List[T]) + /** Return each element in the list up to and including the first element which satisfies + * the given predicate, or return the empty list if no matching element is found, e.g. + * + * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 3) => List(1, 2, 3) + * + * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 5) => Nil + */ + def takeUntil(predicate: T => Boolean): List[T] = + list.indexWhere(predicate) match + case index if index >= 0 => list.take(index + 1) + case _ => Nil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala index 036693c1..e57f0b6f 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala @@ -2,179 +2,182 @@ package io.appthreat.x2cpg.utils import io.shiftleft.codepropertygraph.generated.nodes.Call.PropertyDefaults import io.shiftleft.codepropertygraph.generated.nodes.{ - NewAnnotationLiteral, - NewBinding, - NewCall, - NewClosureBinding, - NewDependency, - NewFieldIdentifier, - NewIdentifier, - NewLocal, - NewMethodParameterIn, - NewMethodReturn, - NewModifier + NewAnnotationLiteral, + NewBinding, + NewCall, + NewClosureBinding, + NewDependency, + NewFieldIdentifier, + NewIdentifier, + NewLocal, + NewMethodParameterIn, + NewMethodReturn, + NewModifier } import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies} -/** NodeBuilders helps with node creation and is intended to be used when functions from `x2cpg.AstCreatorBase` are not - * 
appropriate; for example, in cases in which the node's line and column are _not_ set from the base ASTNode type of a - * specific frontend. +/** NodeBuilders helps with node creation and is intended to be used when functions from + * `x2cpg.AstCreatorBase` are not appropriate; for example, in cases in which the node's line and + * column are _not_ set from the base ASTNode type of a specific frontend. */ -object NodeBuilders { - - private def composeCallSignature(returnType: String, argumentTypes: Iterable[String]): String = { - s"$returnType(${argumentTypes.mkString(",")})" - } - - private def composeMethodFullName(typeDeclFullName: Option[String], name: String, signature: String) = { - val typeDeclPrefix = typeDeclFullName.map(maybeName => s"$maybeName.").getOrElse("") - s"$typeDeclPrefix$name:$signature" - } - - def newAnnotationLiteralNode(name: String): NewAnnotationLiteral = - NewAnnotationLiteral() - .name(name) - .code(name) - - def newBindingNode(name: String, signature: String, methodFullName: String): NewBinding = { - NewBinding() - .name(name) - .methodFullName(methodFullName) - .signature(signature) - } - - def newLocalNode(name: String, typeFullName: String, closureBindingId: Option[String] = None): NewLocal = { - NewLocal() - .code(name) - .name(name) - .typeFullName(typeFullName) - .closureBindingId(closureBindingId) - } - - def newClosureBindingNode( - closureBindingId: String, - originalName: String, - evaluationStrategy: String - ): NewClosureBinding = { - NewClosureBinding() - .closureBindingId(closureBindingId) - .closureOriginalName(originalName) - .evaluationStrategy(evaluationStrategy) - } - - def newCallNode( - methodName: String, - typeDeclFullName: Option[String], - returnTypeFullName: String, - dispatchType: String, - argumentTypes: Iterable[String] = Nil, - code: String = PropertyDefaults.Code, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): NewCall = { - val signature = 
composeCallSignature(returnTypeFullName, argumentTypes) - val methodFullName = composeMethodFullName(typeDeclFullName, methodName, signature) - NewCall() - .name(methodName) - .methodFullName(methodFullName) - .signature(signature) - .typeFullName(returnTypeFullName) - .dispatchType(dispatchType) - .code(code) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - } - - def newDependencyNode(name: String, groupId: String, version: String): NewDependency = - NewDependency() - .name(name) - .dependencyGroupId(groupId) - .version(version) - - def newFieldIdentifierNode( - name: String, - line: Option[Integer] = None, - column: Option[Integer] = None - ): NewFieldIdentifier = { - NewFieldIdentifier() - .canonicalName(name) - .code(name) - .lineNumber(line) - .columnNumber(column) - } - - def newModifierNode(modifierType: String): NewModifier = NewModifier().modifierType(modifierType) - - def newIdentifierNode(name: String, typeFullName: String, dynamicTypeHints: Seq[String] = Seq()): NewIdentifier = { - newIdentifierNode(name, typeFullName, dynamicTypeHints, None) - } - - def newIdentifierNode( - name: String, - typeFullName: String, - dynamicTypeHints: Seq[String], - line: Option[Integer] - ): NewIdentifier = { - NewIdentifier() - .code(name) - .name(name) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line) - } - - def newOperatorCallNode( - name: String, - code: String, - typeFullName: Option[String] = None, - line: Option[Integer] = None, - column: Option[Integer] = None - ): NewCall = { - NewCall() - .name(name) - .methodFullName(name) - .code(code) - .signature("") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .typeFullName(typeFullName.getOrElse("ANY")) - .lineNumber(line) - .columnNumber(column) - } - - def newThisParameterNode( - name: String = "this", - code: String = "this", - typeFullName: String, - dynamicTypeHintFullName: Seq[String] = Seq.empty, - line: Option[Integer] = None, - column: Option[Integer] = 
None, - evaluationStrategy: String = EvaluationStrategies.BY_SHARING - ): NewMethodParameterIn = { - NewMethodParameterIn() - .name(name) - .code(code) - .lineNumber(line) - .columnNumber(column) - .dynamicTypeHintFullName(dynamicTypeHintFullName) - .evaluationStrategy(evaluationStrategy) - .typeFullName(typeFullName) - .index(0) - .order(0) - } - - /** Create a method return node - */ - def newMethodReturnNode( - typeFullName: String, - dynamicTypeHintFullName: Option[String] = None, - line: Option[Integer], - column: Option[Integer] - ): NewMethodReturn = - NewMethodReturn() - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHintFullName) - .code("RET") - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .lineNumber(line) - .columnNumber(column) -} +object NodeBuilders: + + private def composeCallSignature(returnType: String, argumentTypes: Iterable[String]): String = + s"$returnType(${argumentTypes.mkString(",")})" + + private def composeMethodFullName( + typeDeclFullName: Option[String], + name: String, + signature: String + ) = + val typeDeclPrefix = typeDeclFullName.map(maybeName => s"$maybeName.").getOrElse("") + s"$typeDeclPrefix$name:$signature" + + def newAnnotationLiteralNode(name: String): NewAnnotationLiteral = + NewAnnotationLiteral() + .name(name) + .code(name) + + def newBindingNode(name: String, signature: String, methodFullName: String): NewBinding = + NewBinding() + .name(name) + .methodFullName(methodFullName) + .signature(signature) + + def newLocalNode( + name: String, + typeFullName: String, + closureBindingId: Option[String] = None + ): NewLocal = + NewLocal() + .code(name) + .name(name) + .typeFullName(typeFullName) + .closureBindingId(closureBindingId) + + def newClosureBindingNode( + closureBindingId: String, + originalName: String, + evaluationStrategy: String + ): NewClosureBinding = + NewClosureBinding() + .closureBindingId(closureBindingId) + .closureOriginalName(originalName) + 
.evaluationStrategy(evaluationStrategy) + + def newCallNode( + methodName: String, + typeDeclFullName: Option[String], + returnTypeFullName: String, + dispatchType: String, + argumentTypes: Iterable[String] = Nil, + code: String = PropertyDefaults.Code, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): NewCall = + val signature = composeCallSignature(returnTypeFullName, argumentTypes) + val methodFullName = composeMethodFullName(typeDeclFullName, methodName, signature) + NewCall() + .name(methodName) + .methodFullName(methodFullName) + .signature(signature) + .typeFullName(returnTypeFullName) + .dispatchType(dispatchType) + .code(code) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + end newCallNode + + def newDependencyNode(name: String, groupId: String, version: String): NewDependency = + NewDependency() + .name(name) + .dependencyGroupId(groupId) + .version(version) + + def newFieldIdentifierNode( + name: String, + line: Option[Integer] = None, + column: Option[Integer] = None + ): NewFieldIdentifier = + NewFieldIdentifier() + .canonicalName(name) + .code(name) + .lineNumber(line) + .columnNumber(column) + + def newModifierNode(modifierType: String): NewModifier = + NewModifier().modifierType(modifierType) + + def newIdentifierNode( + name: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewIdentifier = + newIdentifierNode(name, typeFullName, dynamicTypeHints, None) + + def newIdentifierNode( + name: String, + typeFullName: String, + dynamicTypeHints: Seq[String], + line: Option[Integer] + ): NewIdentifier = + NewIdentifier() + .code(name) + .name(name) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line) + + def newOperatorCallNode( + name: String, + code: String, + typeFullName: Option[String] = None, + line: Option[Integer] = None, + column: Option[Integer] = None + ): NewCall = + NewCall() + .name(name) + .methodFullName(name) + .code(code) + 
.signature("") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(typeFullName.getOrElse("ANY")) + .lineNumber(line) + .columnNumber(column) + + def newThisParameterNode( + name: String = "this", + code: String = "this", + typeFullName: String, + dynamicTypeHintFullName: Seq[String] = Seq.empty, + line: Option[Integer] = None, + column: Option[Integer] = None, + evaluationStrategy: String = EvaluationStrategies.BY_SHARING + ): NewMethodParameterIn = + NewMethodParameterIn() + .name(name) + .code(code) + .lineNumber(line) + .columnNumber(column) + .dynamicTypeHintFullName(dynamicTypeHintFullName) + .evaluationStrategy(evaluationStrategy) + .typeFullName(typeFullName) + .index(0) + .order(0) + + /** Create a method return node + */ + def newMethodReturnNode( + typeFullName: String, + dynamicTypeHintFullName: Option[String] = None, + line: Option[Integer], + column: Option[Integer] + ): NewMethodReturn = + NewMethodReturn() + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHintFullName) + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .lineNumber(line) + .columnNumber(column) +end NodeBuilders diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala index 7b1659bf..6546f5c0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala @@ -4,89 +4,82 @@ import org.slf4j.LoggerFactory import scala.collection.concurrent.TrieMap -object Report { - - private val logger = LoggerFactory.getLogger(Report.getClass) - - private type FileName = String - - private type Reports = TrieMap[FileName, ReportEntry] - - private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long) { - def toSeq: Seq[String] = { - val lines = loc.toString - val dur = if (duration == 0) "-" else 
TimeUtils.pretty(duration) - val wasParsed = if (parsed) "yes" else "no" - val gotCpg = if (cpgGen) "yes" else "no" - Seq(lines, wasParsed, gotCpg, dur) - } - } - -} - -class Report { - - import Report._ - - private val reports: Reports = TrieMap.empty - - private def formatTable(table: Seq[Seq[String]]): String = { - if (table.isEmpty) "" - else { - // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) - val colWidths = - table.transpose.map(_.map(cell => if (cell == null) 0 else cell.length).max + 2) - // Format each row - val rows = table.map( - _.zip(colWidths) - .map { case (item, size) => s" %-${size - 1}s".format(item) } - .mkString("|", "|", "|") - ) - // Formatted separator row, used to separate the header and draw table borders - val separator = colWidths.map("-" * _).mkString("+", "+", "+") - // Put the table together and return - val header = rows.head - val content = rows.tail.take(rows.tail.size - 1) - val footer = rows.tail.last - (separator +: header +: separator +: content :+ separator :+ footer :+ separator) - .mkString("\n") - } - } - - def print(): Unit = { - val rows = reports.toSeq - .sortBy(_._1) - .zipWithIndex - .view - .map { case ((file, sum), index) => - s"${index + 1}" +: file +: sum.toSeq - } - .toSeq - val numOfReports = reports.size - val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) - val footer = Seq( - Seq( - "Total", - "", - s"${reports.map(_._2.loc).sum}", - s"${reports.count(_._2.parsed)}/$numOfReports", - s"${reports.count(_._2.cpgGen)}/$numOfReports", - "" - ) - ) - val table = header ++ rows ++ footer - logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") - } - - def addReportInfo( - fileName: FileName, - loc: Int, - parsed: Boolean = false, - cpgGen: Boolean = false, - duration: Long = 0 - ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) - - def updateReport(fileName: FileName, cpg: Boolean, 
duration: Long): Unit = - reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) - -} +object Report: + + private val logger = LoggerFactory.getLogger(Report.getClass) + + private type FileName = String + + private type Reports = TrieMap[FileName, ReportEntry] + + private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): + def toSeq: Seq[String] = + val lines = loc.toString + val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) + val wasParsed = if parsed then "yes" else "no" + val gotCpg = if cpgGen then "yes" else "no" + Seq(lines, wasParsed, gotCpg, dur) + +class Report: + + import Report.* + + private val reports: Reports = TrieMap.empty + + private def formatTable(table: Seq[Seq[String]]): String = + if table.isEmpty then "" + else + // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) + val colWidths = + table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) + // Format each row + val rows = table.map( + _.zip(colWidths) + .map { case (item, size) => s" %-${size - 1}s".format(item) } + .mkString("|", "|", "|") + ) + // Formatted separator row, used to separate the header and draw table borders + val separator = colWidths.map("-" * _).mkString("+", "+", "+") + // Put the table together and return + val header = rows.head + val content = rows.tail.take(rows.tail.size - 1) + val footer = rows.tail.last + (separator +: header +: separator +: content :+ separator :+ footer :+ separator) + .mkString("\n") + + def print(): Unit = + val rows = reports.toSeq + .sortBy(_._1) + .zipWithIndex + .view + .map { case ((file, sum), index) => + s"${index + 1}" +: file +: sum.toSeq + } + .toSeq + val numOfReports = reports.size + val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) + val footer = Seq( + Seq( + "Total", + "", + s"${reports.map(_._2.loc).sum}", + 
s"${reports.count(_._2.parsed)}/$numOfReports", + s"${reports.count(_._2.cpgGen)}/$numOfReports", + "" + ) + ) + val table = header ++ rows ++ footer + logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") + end print + + def addReportInfo( + fileName: FileName, + loc: Int, + parsed: Boolean = false, + cpgGen: Boolean = false, + duration: Long = 0 + ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) + + def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = + reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) +end Report diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala index ccaab017..2551d443 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala @@ -1,7 +1,5 @@ package io.appthreat.x2cpg.utils -implicit class StringUtils(str: String) { - def isAllUpperCase: Boolean = { - str.forall(c => c.isUpper || !c.isLetter) - } -} +implicit class StringUtils(str: String): + def isAllUpperCase: Boolean = + str.forall(c => c.isUpper || !c.isLetter) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala index ed9f33df..86922831 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala @@ -7,59 +7,53 @@ import scala.concurrent.duration.* import scala.language.postfixOps import scala.util.Try -object TimeUtils { - - /** Measures elapsed time for executing a block in nanoseconds */ - def time[R](block: => R): (R, Long) = { - val t0 = System.nanoTime() - val result = block - val t1 = 
System.nanoTime() - val elapsed = t1 - t0 - (result, elapsed) - } - - /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ - def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) - - def runWithTimeout[T](timeoutMs: Long)(f: => T): Try[T] = { - Try(Await.result(Future(f), timeoutMs milliseconds)) - } - - private def pretty(duration: Duration): String = - duration match { - case d: FiniteDuration => - val nanos = d.toNanos - val unit = chooseUnit(nanos) - val value = nanos.toDouble / NANOSECONDS.convert(1, unit) - - s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) - - case Duration.MinusInf => s"-∞ (minus infinity)" - case Duration.Inf => s"∞ (infinity)" - case _ => "undefined" - } - - private def chooseUnit(nanos: Long): TimeUnit = { - val d = nanos.nanos - - if (d.toDays > 0) DAYS - else if (d.toHours > 0) HOURS - else if (d.toMinutes > 0) MINUTES - else if (d.toSeconds > 0) SECONDS - else if (d.toMillis > 0) MILLISECONDS - else if (d.toMicros > 0) MICROSECONDS - else NANOSECONDS - } - - private def abbreviate(unit: TimeUnit): String = - unit match { - case NANOSECONDS => "ns" - case MICROSECONDS => "μs" - case MILLISECONDS => "ms" - case SECONDS => "s" - case MINUTES => "min" - case HOURS => "h" - case DAYS => "d" - } - -} +object TimeUtils: + + /** Measures elapsed time for executing a block in nanoseconds */ + def time[R](block: => R): (R, Long) = + val t0 = System.nanoTime() + val result = block + val t1 = System.nanoTime() + val elapsed = t1 - t0 + (result, elapsed) + + /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ + def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) + + def runWithTimeout[T](timeoutMs: Long)(f: => T): Try[T] = + Try(Await.result(Future(f), timeoutMs milliseconds)) + + private def pretty(duration: Duration): String = + duration match + case d: FiniteDuration => + val nanos = d.toNanos + val unit = chooseUnit(nanos) 
+ val value = nanos.toDouble / NANOSECONDS.convert(1, unit) + + s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) + + case Duration.MinusInf => s"-∞ (minus infinity)" + case Duration.Inf => s"∞ (infinity)" + case _ => "undefined" + + private def chooseUnit(nanos: Long): TimeUnit = + val d = nanos.nanos + + if d.toDays > 0 then DAYS + else if d.toHours > 0 then HOURS + else if d.toMinutes > 0 then MINUTES + else if d.toSeconds > 0 then SECONDS + else if d.toMillis > 0 then MILLISECONDS + else if d.toMicros > 0 then MICROSECONDS + else NANOSECONDS + + private def abbreviate(unit: TimeUnit): String = + unit match + case NANOSECONDS => "ns" + case MICROSECONDS => "μs" + case MILLISECONDS => "ms" + case SECONDS => "s" + case MINUTES => "min" + case HOURS => "h" + case DAYS => "d" +end TimeUtils diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala index 049f7f8a..ac8ada38 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala @@ -8,122 +8,119 @@ import org.slf4j.LoggerFactory import java.nio.file.Path import scala.util.{Failure, Success} -object GradleConfigKeys extends Enumeration { - type GradleConfigKey = Value - val ProjectName, ConfigurationName = Value -} +object GradleConfigKeys extends Enumeration: + type GradleConfigKey = Value + val ProjectName, ConfigurationName = Value case class DependencyResolverParams( forMaven: Map[String, String] = Map(), forGradle: Map[GradleConfigKey, String] = Map() ) -object DependencyResolver { - private val logger = LoggerFactory.getLogger(getClass) - private val defaultGradleProjectName = "app" - private val defaultGradleConfigurationName = "compileClasspath" - private val MaxSearchDepth: Int = 4 
- - def getCoordinates( - projectDir: Path, - params: DependencyResolverParams = new DependencyResolverParams - ): Option[collection.Seq[String]] = { - val coordinates = findSupportedBuildFiles(projectDir).flatMap { buildFile => - if (isMavenBuildFile(buildFile)) - // TODO: implement - None - else if (isGradleBuildFile(buildFile)) - getCoordinatesForGradleProject(buildFile.getParent, defaultGradleConfigurationName) - else { - logger.debug(s"Found unsupported build file $buildFile") - Nil - } - }.flatten - - Option.when(coordinates.nonEmpty)(coordinates) - } - - private def getCoordinatesForGradleProject( - projectDir: Path, - configuration: String - ): Option[collection.Seq[String]] = { - val lines = ExternalCommand.run(s"gradle dependencies --configuration $configuration", projectDir.toString) match { - case Success(lines) => lines - case Failure(exception) => - logger.debug( - s"Could not retrieve dependencies for Gradle project at path `$projectDir`\n" + - exception.getMessage - ) - Seq() - } - - val coordinates = MavenCoordinates.fromGradleOutput(lines) - logger.debug("Got {} Maven coordinates", coordinates.size) - Some(coordinates) - } - - def getDependencies( - projectDir: Path, - params: DependencyResolverParams = new DependencyResolverParams - ): Option[collection.Seq[String]] = { - val dependencies = findSupportedBuildFiles(projectDir).flatMap { buildFile => - if (isMavenBuildFile(buildFile)) { - MavenDependencies.get(buildFile.getParent) - } else if (isGradleBuildFile(buildFile)) { - getDepsForGradleProject(params, buildFile.getParent) - } else { - logger.debug(s"Found unsupported build file $buildFile") - Nil - } - }.flatten - - Option.when(dependencies.nonEmpty)(dependencies) - } - - private def getDepsForGradleProject( - params: DependencyResolverParams, - projectDir: Path - ): Option[collection.Seq[String]] = { - logger.debug("resolving Gradle dependencies at {}", projectDir) - val gradleProjectName = 
params.forGradle.getOrElse(GradleConfigKeys.ProjectName, defaultGradleProjectName) - val gradleConfiguration = - params.forGradle.getOrElse(GradleConfigKeys.ConfigurationName, defaultGradleConfigurationName) - GradleDependencies.get(projectDir, gradleProjectName, gradleConfiguration) match { - case Some(deps) => Some(deps) - case None => - logger.debug(s"Could not download Gradle dependencies for project at path `$projectDir`") - None - } - } - - private def isGradleBuildFile(file: File): Boolean = { - val pathString = file.pathAsString - pathString.endsWith(".gradle") || pathString.endsWith(".gradle.kts") - } - - private def isMavenBuildFile(file: File): Boolean = { - file.pathAsString.endsWith("pom.xml") - } - - private def findSupportedBuildFiles(currentDir: File, depth: Int = 0): List[Path] = { - if (depth >= MaxSearchDepth) { - logger.debug("findSupportedBuildFiles reached max depth before finding build files") - Nil - } else { - val (childDirectories, childFiles) = currentDir.children.partition(_.isDirectory) - // Only fetch dependencies once for projects with both a build.gradle and a pom.xml file - val childFileList = childFiles.toList - childFileList - .find(isGradleBuildFile) - .orElse(childFileList.find(isMavenBuildFile)) match { - case Some(buildFile) => buildFile.path :: Nil - - case None if childDirectories.isEmpty => Nil - - case None => - childDirectories.flatMap { dir => - findSupportedBuildFiles(dir, depth + 1) - }.toList - } - } - } -} +object DependencyResolver: + private val logger = LoggerFactory.getLogger(getClass) + private val defaultGradleProjectName = "app" + private val defaultGradleConfigurationName = "compileClasspath" + private val MaxSearchDepth: Int = 4 + + def getCoordinates( + projectDir: Path, + params: DependencyResolverParams = new DependencyResolverParams + ): Option[collection.Seq[String]] = + val coordinates = findSupportedBuildFiles(projectDir).flatMap { buildFile => + if isMavenBuildFile(buildFile) then + // TODO: 
implement + None + else if isGradleBuildFile(buildFile) then + getCoordinatesForGradleProject(buildFile.getParent, defaultGradleConfigurationName) + else + logger.debug(s"Found unsupported build file $buildFile") + Nil + }.flatten + + Option.when(coordinates.nonEmpty)(coordinates) + + private def getCoordinatesForGradleProject( + projectDir: Path, + configuration: String + ): Option[collection.Seq[String]] = + val lines = ExternalCommand.run( + s"gradle dependencies --configuration $configuration", + projectDir.toString + ) match + case Success(lines) => lines + case Failure(exception) => + logger.debug( + s"Could not retrieve dependencies for Gradle project at path `$projectDir`\n" + + exception.getMessage + ) + Seq() + + val coordinates = MavenCoordinates.fromGradleOutput(lines) + logger.debug("Got {} Maven coordinates", coordinates.size) + Some(coordinates) + end getCoordinatesForGradleProject + + def getDependencies( + projectDir: Path, + params: DependencyResolverParams = new DependencyResolverParams + ): Option[collection.Seq[String]] = + val dependencies = findSupportedBuildFiles(projectDir).flatMap { buildFile => + if isMavenBuildFile(buildFile) then + MavenDependencies.get(buildFile.getParent) + else if isGradleBuildFile(buildFile) then + getDepsForGradleProject(params, buildFile.getParent) + else + logger.debug(s"Found unsupported build file $buildFile") + Nil + }.flatten + + Option.when(dependencies.nonEmpty)(dependencies) + + private def getDepsForGradleProject( + params: DependencyResolverParams, + projectDir: Path + ): Option[collection.Seq[String]] = + logger.debug("resolving Gradle dependencies at {}", projectDir) + val gradleProjectName = + params.forGradle.getOrElse(GradleConfigKeys.ProjectName, defaultGradleProjectName) + val gradleConfiguration = + params.forGradle.getOrElse( + GradleConfigKeys.ConfigurationName, + defaultGradleConfigurationName + ) + GradleDependencies.get(projectDir, gradleProjectName, gradleConfiguration) match + case 
Some(deps) => Some(deps) + case None => + logger.debug( + s"Could not download Gradle dependencies for project at path `$projectDir`" + ) + None + end getDepsForGradleProject + + private def isGradleBuildFile(file: File): Boolean = + val pathString = file.pathAsString + pathString.endsWith(".gradle") || pathString.endsWith(".gradle.kts") + + private def isMavenBuildFile(file: File): Boolean = + file.pathAsString.endsWith("pom.xml") + + private def findSupportedBuildFiles(currentDir: File, depth: Int = 0): List[Path] = + if depth >= MaxSearchDepth then + logger.debug("findSupportedBuildFiles reached max depth before finding build files") + Nil + else + val (childDirectories, childFiles) = currentDir.children.partition(_.isDirectory) + // Only fetch dependencies once for projects with both a build.gradle and a pom.xml file + val childFileList = childFiles.toList + childFileList + .find(isGradleBuildFile) + .orElse(childFileList.find(isMavenBuildFile)) match + case Some(buildFile) => buildFile.path :: Nil + + case None if childDirectories.isEmpty => Nil + + case None => + childDirectories.flatMap { dir => + findSupportedBuildFiles(dir, depth + 1) + }.toList +end DependencyResolver diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala index 0af3619f..f6cf1062 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala @@ -1,6 +1,6 @@ package io.appthreat.x2cpg.utils.dependency -import better.files._ +import better.files.* import org.gradle.tooling.{GradleConnector, ProjectConnection} import org.gradle.tooling.model.GradleProject import org.gradle.tooling.model.build.BuildEnvironment @@ -8,49 +8,49 @@ import org.slf4j.LoggerFactory import 
java.io.ByteArrayOutputStream import java.nio.file.{Files, Path} -import java.io.{File => JFile} +import java.io.{File as JFile} import java.util.stream.Collectors -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.{Failure, Random, Success, Try, Using} -case class GradleProjectInfo(gradleVersion: String, tasks: Seq[String], hasAndroidSubproject: Boolean = false) { - def gradleVersionMajorMinor(): (Int, Int) = { - def isValidPart(part: String) = part.forall(Character.isDigit) - val parts = gradleVersion.split('.') - if (parts.length == 1 && isValidPart(parts(0))) { - (parts(0).toInt, 0) - } else if (parts.length >= 2 && isValidPart(parts(0)) && isValidPart(parts(1))) { - (parts(0).toInt, parts(1).toInt) - } else { - (-1, -1) - } - } -} +case class GradleProjectInfo( + gradleVersion: String, + tasks: Seq[String], + hasAndroidSubproject: Boolean = false +): + def gradleVersionMajorMinor(): (Int, Int) = + def isValidPart(part: String) = part.forall(Character.isDigit) + val parts = gradleVersion.split('.') + if parts.length == 1 && isValidPart(parts(0)) then + (parts(0).toInt, 0) + else if parts.length >= 2 && isValidPart(parts(0)) && isValidPart(parts(1)) then + (parts(0).toInt, parts(1).toInt) + else + (-1, -1) -object Constants { - val aarFileExtension = "aar" - val gradleAndroidPropertyPrefix = "android." - val gradlePropertiesTaskName = "properties" - val jarInsideAarFileName = "classes.jar" -} +object Constants: + val aarFileExtension = "aar" + val gradleAndroidPropertyPrefix = "android." 
+ val gradlePropertiesTaskName = "properties" + val jarInsideAarFileName = "classes.jar" case class GradleDepsInitScript(contents: String, taskName: String, destinationDir: Path) -object GradleDependencies { - private val logger = LoggerFactory.getLogger(getClass) - private val initScriptPrefix = "x2cpg.init.gradle" - private val taskNamePrefix = "x2cpgCopyDeps" - private val tempDirPrefix = "x2cpgDependencies" +object GradleDependencies: + private val logger = LoggerFactory.getLogger(getClass) + private val initScriptPrefix = "x2cpg.init.gradle" + private val taskNamePrefix = "x2cpgCopyDeps" + private val tempDirPrefix = "x2cpgDependencies" - // works with Gradle 5.1+ because the script makes use of `task.register`: - // https://docs.gradle.org/current/userguide/task_configuration_avoidance.html - private def gradle5OrLaterAndroidInitScript( - taskName: String, - destination: String, - gradleProjectName: String, - gradleConfigurationName: String - ): String = { - s""" + // works with Gradle 5.1+ because the script makes use of `task.register`: + // https://docs.gradle.org/current/userguide/task_configuration_avoidance.html + private def gradle5OrLaterAndroidInitScript( + taskName: String, + destination: String, + gradleProjectName: String, + gradleConfigurationName: String + ): String = + s""" |allprojects { | afterEvaluate { project -> | def taskName = "$taskName" @@ -88,12 +88,11 @@ object GradleDependencies { | } |} |""".stripMargin - } - // this init script _should_ work with Gradle 4-8, but has not been tested thoroughly - // TODO: add test cases for older Gradle versions - private def gradle5OrLaterInitScript(taskName: String, destination: String): String = { - s""" + // this init script _should_ work with Gradle 4-8, but has not been tested thoroughly + // TODO: add test cases for older Gradle versions + private def gradle5OrLaterInitScript(taskName: String, destination: String): String = + s""" |allprojects { | apply plugin: 'java' | task $taskName(type: 
Copy) { @@ -102,209 +101,234 @@ object GradleDependencies { | } |} |""".stripMargin - } - private def makeInitScript( - destinationDir: Path, - forAndroid: Boolean, - gradleProjectName: String, - gradleConfigurationName: String - ): GradleDepsInitScript = { - val taskName = taskNamePrefix + "_" + (Random.alphanumeric take 8).toList.mkString - val content = - if (forAndroid) { - gradle5OrLaterAndroidInitScript(taskName, destinationDir.toString, gradleProjectName, gradleConfigurationName) - } else { - gradle5OrLaterInitScript(taskName, destinationDir.toString) - } - GradleDepsInitScript(content, taskName, destinationDir) - } + private def makeInitScript( + destinationDir: Path, + forAndroid: Boolean, + gradleProjectName: String, + gradleConfigurationName: String + ): GradleDepsInitScript = + val taskName = taskNamePrefix + "_" + (Random.alphanumeric take 8).toList.mkString + val content = + if forAndroid then + gradle5OrLaterAndroidInitScript( + taskName, + destinationDir.toString, + gradleProjectName, + gradleConfigurationName + ) + else + gradle5OrLaterInitScript(taskName, destinationDir.toString) + GradleDepsInitScript(content, taskName, destinationDir) - private[dependency] def makeConnection(projectDir: JFile): ProjectConnection = { - GradleConnector.newConnector().forProjectDirectory(projectDir).connect() - } + private[dependency] def makeConnection(projectDir: JFile): ProjectConnection = + GradleConnector.newConnector().forProjectDirectory(projectDir).connect() - private def getGradleProjectInfo(projectDir: Path, projectName: String): Option[GradleProjectInfo] = { - Try(makeConnection(projectDir.toFile)) match { - case Success(gradleConnection) => - Using.resource(gradleConnection) { connection => - try { - val buildEnv = connection.getModel[BuildEnvironment](classOf[BuildEnvironment]) - val project = connection.getModel[GradleProject](classOf[GradleProject]) - val hasAndroidPrefixGradleProperty = - runGradleTask(connection, Constants.gradlePropertiesTaskName) 
match { - case Some(out) => - out.split('\n').exists(_.startsWith(Constants.gradleAndroidPropertyPrefix)) - case None => false - } - val info = GradleProjectInfo( - buildEnv.getGradle.getGradleVersion, - project.getTasks.asScala.map(_.getName).toSeq, - hasAndroidPrefixGradleProperty - ) - if (hasAndroidPrefixGradleProperty) { - val validProjectNames = List(project.getName) ++ project.getChildren.getAll.asScala.map(_.getName) - logger.debug(s"Found Gradle projects: ${validProjectNames.mkString(",")}") - if (!validProjectNames.contains(projectName)) { - val validProjectNamesStr = validProjectNames.mkString(",") + private def getGradleProjectInfo( + projectDir: Path, + projectName: String + ): Option[GradleProjectInfo] = + Try(makeConnection(projectDir.toFile)) match + case Success(gradleConnection) => + Using.resource(gradleConnection) { connection => + try + val buildEnv = + connection.getModel[BuildEnvironment](classOf[BuildEnvironment]) + val project = connection.getModel[GradleProject](classOf[GradleProject]) + val hasAndroidPrefixGradleProperty = + runGradleTask(connection, Constants.gradlePropertiesTaskName) match + case Some(out) => + out.split('\n').exists( + _.startsWith(Constants.gradleAndroidPropertyPrefix) + ) + case None => false + val info = GradleProjectInfo( + buildEnv.getGradle.getGradleVersion, + project.getTasks.asScala.map(_.getName).toSeq, + hasAndroidPrefixGradleProperty + ) + if hasAndroidPrefixGradleProperty then + val validProjectNames = List( + project.getName + ) ++ project.getChildren.getAll.asScala.map(_.getName) + logger.debug( + s"Found Gradle projects: ${validProjectNames.mkString(",")}" + ) + if !validProjectNames.contains(projectName) then + val validProjectNamesStr = validProjectNames.mkString(",") + logger.debug( + s"The provided Gradle project name `$projectName` is is not part of the valid project names: `$validProjectNamesStr`" + ) + None + else + Some(info) + else + Some(info) + catch + case t: Throwable => + logger.debug( + 
s"Caught exception while trying use Gradle connection: ${t.getMessage}" + ) + logger.debug(s"Full exception: ", t) + None + } + case Failure(t) => logger.debug( - s"The provided Gradle project name `$projectName` is is not part of the valid project names: `$validProjectNamesStr`" + s"Caught exception while trying fetch Gradle project information: ${t.getMessage}" ) + logger.debug(s"Full exception: ", t) None - } else { - Some(info) - } - } else { - Some(info) - } - } catch { - case t: Throwable => - logger.debug(s"Caught exception while trying use Gradle connection: ${t.getMessage}") - logger.debug(s"Full exception: ", t) - None - } - } - case Failure(t) => - logger.debug(s"Caught exception while trying fetch Gradle project information: ${t.getMessage}") - logger.debug(s"Full exception: ", t) - None - } - } - private def runGradleTask(connection: ProjectConnection, taskName: String): Option[String] = { - Using.resource(new ByteArrayOutputStream()) { out => - Try( - connection - .newBuild() - .forTasks(taskName) - .setStandardOutput(out) - .run() - ) match { - case Success(_) => Some(out.toString) - case Failure(ex) => - logger.debug(s"Caught exception while executing Gradle task named `$taskName`:", ex.getMessage) - logger.debug(s"Full exception: ", ex) - None - } - } - } + private def runGradleTask(connection: ProjectConnection, taskName: String): Option[String] = + Using.resource(new ByteArrayOutputStream()) { out => + Try( + connection + .newBuild() + .forTasks(taskName) + .setStandardOutput(out) + .run() + ) match + case Success(_) => Some(out.toString) + case Failure(ex) => + logger.debug( + s"Caught exception while executing Gradle task named `$taskName`:", + ex.getMessage + ) + logger.debug(s"Full exception: ", ex) + None + } - private def runGradleTask( - connection: ProjectConnection, - initScript: GradleDepsInitScript, - initScriptPath: String - ): Option[collection.Seq[String]] = { - Using.resources(new ByteArrayOutputStream, new ByteArrayOutputStream) { 
case (stdoutStream, stderrStream) => - logger.debug(s"Executing gradle task '${initScript.taskName}'...") - Try( - connection - .newBuild() - .forTasks(initScript.taskName) - .withArguments("--init-script", initScriptPath) - .setStandardOutput(stdoutStream) - .setStandardError(stderrStream) - .run() - ) match { - case Success(_) => - val result = - Files - .list(initScript.destinationDir) - .collect(Collectors.toList[Path]) - .asScala - .map(_.toAbsolutePath.toString) - logger.debug(s"Resolved `${result.size}` dependency files.") - Some(result) - case Failure(ex) => - logger.debug(s"Caught exception while executing Gradle task: ${ex.getMessage}") - logger.debug(s"Gradle task execution stdout: \n$stdoutStream") - logger.debug(s"Gradle task execution stderr: \n$stderrStream") - None - } - } - } + private def runGradleTask( + connection: ProjectConnection, + initScript: GradleDepsInitScript, + initScriptPath: String + ): Option[collection.Seq[String]] = + Using.resources(new ByteArrayOutputStream, new ByteArrayOutputStream) { + case (stdoutStream, stderrStream) => + logger.debug(s"Executing gradle task '${initScript.taskName}'...") + Try( + connection + .newBuild() + .forTasks(initScript.taskName) + .withArguments("--init-script", initScriptPath) + .setStandardOutput(stdoutStream) + .setStandardError(stderrStream) + .run() + ) match + case Success(_) => + val result = + Files + .list(initScript.destinationDir) + .collect(Collectors.toList[Path]) + .asScala + .map(_.toAbsolutePath.toString) + logger.debug(s"Resolved `${result.size}` dependency files.") + Some(result) + case Failure(ex) => + logger.debug( + s"Caught exception while executing Gradle task: ${ex.getMessage}" + ) + logger.debug(s"Gradle task execution stdout: \n$stdoutStream") + logger.debug(s"Gradle task execution stderr: \n$stderrStream") + None + end match + } - private def extractClassesJarFromAar(aar: File): Option[Path] = { - val newPath = aar.path.toString.replaceFirst(Constants.aarFileExtension + 
"$", "jar") - val aarUnzipDirSuffix = ".unzipped" - val outDir = File(aar.path.toString + aarUnzipDirSuffix) - aar.unzipTo(outDir, _.getName == Constants.jarInsideAarFileName) - val outFile = File(newPath) - val classesJarEntries = - outDir.listRecursively - .filter(_.path.getFileName.toString == Constants.jarInsideAarFileName) - .toList - if (classesJarEntries.size != 1) { - logger.debug(s"Found aar file without `classes.jar` inside at path ${aar.path}") - outDir.delete() - None - } else { - val classesJar = classesJarEntries.head - logger.trace(s"Copying `classes.jar` for aar at `${aar.path.toString}` into `$newPath`") - classesJar.copyTo(outFile) - outDir.delete() - aar.delete() - Some(outFile.path) - } - } + private def extractClassesJarFromAar(aar: File): Option[Path] = + val newPath = aar.path.toString.replaceFirst(Constants.aarFileExtension + "$", "jar") + val aarUnzipDirSuffix = ".unzipped" + val outDir = File(aar.path.toString + aarUnzipDirSuffix) + aar.unzipTo(outDir, _.getName == Constants.jarInsideAarFileName) + val outFile = File(newPath) + val classesJarEntries = + outDir.listRecursively + .filter(_.path.getFileName.toString == Constants.jarInsideAarFileName) + .toList + if classesJarEntries.size != 1 then + logger.debug(s"Found aar file without `classes.jar` inside at path ${aar.path}") + outDir.delete() + None + else + val classesJar = classesJarEntries.head + logger.trace(s"Copying `classes.jar` for aar at `${aar.path.toString}` into `$newPath`") + classesJar.copyTo(outFile) + outDir.delete() + aar.delete() + Some(outFile.path) + end extractClassesJarFromAar - // fetch the gradle project information first, then invoke a newly-defined gradle task to copy the necessary jars into - // a destination directory. 
- private[dependency] def get( - projectDir: Path, - projectName: String, - configurationName: String - ): Option[collection.Seq[String]] = { - logger.debug(s"Fetching Gradle project information at path `$projectDir` with project name `$projectName`.") - getGradleProjectInfo(projectDir, projectName) match { - case Some(projectInfo) if projectInfo.gradleVersionMajorMinor()._1 < 5 => - logger.debug(s"Unsupported Gradle version `${projectInfo.gradleVersion}`") - None - case Some(projectInfo) => - Try(File.newTemporaryDirectory(tempDirPrefix).deleteOnExit()) match { - case Success(destinationDir) => - Try(File.newTemporaryFile(initScriptPrefix).deleteOnExit()) match { - case Success(initScriptFile) => - val initScript = - makeInitScript(destinationDir.path, projectInfo.hasAndroidSubproject, projectName, configurationName) - initScriptFile.write(initScript.contents) + // fetch the gradle project information first, then invoke a newly-defined gradle task to copy the necessary jars into + // a destination directory. + private[dependency] def get( + projectDir: Path, + projectName: String, + configurationName: String + ): Option[collection.Seq[String]] = + logger.debug( + s"Fetching Gradle project information at path `$projectDir` with project name `$projectName`." 
+ ) + getGradleProjectInfo(projectDir, projectName) match + case Some(projectInfo) if projectInfo.gradleVersionMajorMinor()._1 < 5 => + logger.debug(s"Unsupported Gradle version `${projectInfo.gradleVersion}`") + None + case Some(projectInfo) => + Try(File.newTemporaryDirectory(tempDirPrefix).deleteOnExit()) match + case Success(destinationDir) => + Try(File.newTemporaryFile(initScriptPrefix).deleteOnExit()) match + case Success(initScriptFile) => + val initScript = + makeInitScript( + destinationDir.path, + projectInfo.hasAndroidSubproject, + projectName, + configurationName + ) + initScriptFile.write(initScript.contents) - logger.debug( - s"Downloading dependencies for configuration `$configurationName` of project `$projectName` at `$projectDir` into `$destinationDir`..." - ) - Try(makeConnection(projectDir.toFile)) match { - case Success(connection) => - Using.resource(connection) { c => - runGradleTask(c, initScript, initScriptFile.pathAsString) match { - case Some(deps) => - Some(deps.map { d => - if (!d.endsWith(Constants.aarFileExtension)) d - else - extractClassesJarFromAar(File(d)) match { - case Some(path) => path.toString - case None => d - } - }) - case None => None - } - } - case Failure(ex) => - logger.debug(s"Caught exception while trying to establish a Gradle connection: ${ex.getMessage}") - logger.debug(s"Full exception: ", ex) - None - } - case Failure(ex) => - logger.debug(s"Could not create temporary file for Gradle init script: ${ex.getMessage}") - logger.debug(s"Full exception: ", ex) + logger.debug( + s"Downloading dependencies for configuration `$configurationName` of project `$projectName` at `$projectDir` into `$destinationDir`..." 
+ ) + Try(makeConnection(projectDir.toFile)) match + case Success(connection) => + Using.resource(connection) { c => + runGradleTask( + c, + initScript, + initScriptFile.pathAsString + ) match + case Some(deps) => + Some(deps.map { d => + if !d.endsWith(Constants.aarFileExtension) + then d + else + extractClassesJarFromAar(File(d)) match + case Some(path) => path.toString + case None => d + }) + case None => None + } + case Failure(ex) => + logger.debug( + s"Caught exception while trying to establish a Gradle connection: ${ex.getMessage}" + ) + logger.debug(s"Full exception: ", ex) + None + end match + case Failure(ex) => + logger.debug( + s"Could not create temporary file for Gradle init script: ${ex.getMessage}" + ) + logger.debug(s"Full exception: ", ex) + None + case Failure(ex) => + logger.debug( + s"Could not create temporary directory for saving dependency files: ${ex.getMessage}" + ) + logger.debug("Full exception: ", ex) + None + case None => + logger.debug("Could not fetch Gradle project information") None - } - case Failure(ex) => - logger.debug(s"Could not create temporary directory for saving dependency files: ${ex.getMessage}") - logger.debug("Full exception: ", ex) - None - } - case None => - logger.debug("Could not fetch Gradle project information") - None - } - } -} + end match + end get +end GradleDependencies diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala index 411520f7..39f44dbe 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala @@ -6,9 +6,9 @@ import org.slf4j.LoggerFactory import java.nio.file.Path import scala.util.{Failure, Success} -object MavenCoordinates { - private[dependency] def fromGradleOutput(lines: 
Seq[String]): Seq[String] = { - /* +object MavenCoordinates: + private[dependency] def fromGradleOutput(lines: Seq[String]): Seq[String] = + /* on the following regex, for the following input: ``` | | +--- org.springframework.boot:spring-boot-starter-logging:3.0.5 @@ -39,20 +39,19 @@ object MavenCoordinates { org.slf4j:jul-to-slf4j:2.0.7 ^g1 ------------------^^g2 ^ ``` - */ - val pattern = """^[| ]*[+\\]\s*[-]*\s*([^:]+:[^:]+:)([^\s]+)(\s+->\s+)?([^\s]+)?""".r - lines - .flatMap { l => - pattern.findFirstMatchIn(l) match { - case Some(m) => - if (Option(m.group(4)).isEmpty) - Some(m.group(1) + m.group(2)) - else - Some(m.group(1) + m.group(4)) - case _ => None - } - } - .distinct - .sorted - } -} + */ + val pattern = """^[| ]*[+\\]\s*[-]*\s*([^:]+:[^:]+:)([^\s]+)(\s+->\s+)?([^\s]+)?""".r + lines + .flatMap { l => + pattern.findFirstMatchIn(l) match + case Some(m) => + if Option(m.group(4)).isEmpty then + Some(m.group(1) + m.group(2)) + else + Some(m.group(1) + m.group(4)) + case _ => None + } + .distinct + .sorted + end fromGradleOutput +end MavenCoordinates diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala index 297bdedc..2decb863 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala @@ -6,39 +6,38 @@ import org.slf4j.LoggerFactory import java.nio.file.Path import scala.util.{Failure, Success} -object MavenDependencies { - private val logger = LoggerFactory.getLogger(getClass) +object MavenDependencies: + private val logger = LoggerFactory.getLogger(getClass) - private[dependency] def get(projectDir: Path): Option[collection.Seq[String]] = { - // we can't use -Dmdep.outputFile because that keeps overwriting its own output for each 
sub-project it's running for - val lines = ExternalCommand.run( - s"mvn -B dependency:build-classpath -DincludeScope=compile -Dorg.slf4j.simpleLogger.defaultLogLevel=info -Dorg.slf4j.simpleLogger.logFile=System.out", - projectDir.toString - ) match { - case Success(lines) => lines - case Failure(exception) => - logger.debug( - s"Retrieval of compile class path via maven return with error.\n" + - "The compile class path may be missing or partial.\n" + - "Results will suffer from poor type information.\n\n" + - exception.getMessage - ) - // exception message is the program output - and we still want to look for potential partial results - exception.getMessage.linesIterator.toSeq - } + private[dependency] def get(projectDir: Path): Option[collection.Seq[String]] = + // we can't use -Dmdep.outputFile because that keeps overwriting its own output for each sub-project it's running for + val lines = ExternalCommand.run( + s"mvn -B dependency:build-classpath -DincludeScope=compile -Dorg.slf4j.simpleLogger.defaultLogLevel=info -Dorg.slf4j.simpleLogger.logFile=System.out", + projectDir.toString + ) match + case Success(lines) => lines + case Failure(exception) => + logger.debug( + s"Retrieval of compile class path via maven return with error.\n" + + "The compile class path may be missing or partial.\n" + + "Results will suffer from poor type information.\n\n" + + exception.getMessage + ) + // exception message is the program output - and we still want to look for potential partial results + exception.getMessage.linesIterator.toSeq - var classPathNext = false - val deps = lines - .flatMap { line => - val isClassPathNow = classPathNext - classPathNext = line.endsWith("Dependencies classpath:") + var classPathNext = false + val deps = lines + .flatMap { line => + val isClassPathNow = classPathNext + classPathNext = line.endsWith("Dependencies classpath:") - if (isClassPathNow) line.split(':') else Array.empty[String] - } - .distinct - .toList + if isClassPathNow then 
line.split(':') else Array.empty[String] + } + .distinct + .toList - logger.debug("got {} Maven dependencies", deps.size) - Some(deps) - } -} + logger.debug("got {} Maven dependencies", deps.size) + Some(deps) + end get +end MavenDependencies diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala b/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala index b2d116e4..7c32104f 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala @@ -23,204 +23,222 @@ import scala.collection.mutable import scala.jdk.CollectionConverters.IteratorHasAsScala import scala.util.Using -object ChenExport { - - case class Config( - cpgFileName: String = "cpg.bin", - outDir: String = "out", - repr: Representation.Value = Representation.Cpg14, - format: Format.Value = Format.Dot - ) - - /** Choose from either a subset of the graph, or the entire graph (all). - */ - object Representation extends Enumeration { - val Ast, Cfg, Ddg, Cdg, Pdg, Cpg14, Cpg, All = Value - - lazy val byNameLowercase: Map[String, Value] = - values.map { value => - value.toString.toLowerCase -> value - }.toMap - - def withNameIgnoreCase(s: String): Value = - byNameLowercase.getOrElse(s, throw new NoSuchElementException(s"No value found for '$s'")) - - } - object Format extends Enumeration { - val Dot, Neo4jCsv, Graphml, Graphson = Value - - lazy val byNameLowercase: Map[String, Value] = - values.map { value => - value.toString.toLowerCase -> value - }.toMap - - def withNameIgnoreCase(s: String): Value = - byNameLowercase.getOrElse(s, throw new NoSuchElementException(s"No value found for '$s'")) - } - - def main(args: Array[String]): Unit = { - parseConfig(args).foreach { config => - val outDir = config.outDir - exitIfInvalid(outDir, config.cpgFileName) - mkdir(File(outDir)) - - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => - exportCpg(cpg, config.repr, config.format, 
Paths.get(outDir).toAbsolutePath) - } - } - } - - private def parseConfig(args: Array[String]): Option[Config] = { - new scopt.OptionParser[Config]("chen-export") { - head("Dump intermediate graph representations (or entire graph) of code in a given export format") - help("help") - arg[String]("cpg") - .text("input CPG file name - defaults to `cpg.bin`") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - opt[String]('o', "out") - .text("output directory - will be created and must not yet exist") - .action((x, c) => c.copy(outDir = x)) - opt[String]("repr") - .text( - s"representation to extract: [${Representation.values.toSeq.map(_.toString.toLowerCase).sorted.mkString("|")}] - defaults to `${Representation.Cpg14}`" - ) - .action((x, c) => c.copy(repr = Representation.withNameIgnoreCase(x))) - opt[String]("format") - .action((x, c) => c.copy(format = Format.withNameIgnoreCase(x))) - .text( - s"export format, one of [${Format.values.toSeq.map(_.toString.toLowerCase).sorted.mkString("|")}] - defaults to `${Format.Dot}`" - ) - }.parse(args, Config()) - } - - def exportCpg(cpg: Cpg, representation: Representation.Value, format: Format.Value, outDir: Path): Unit = { - implicit val semantics: Semantics = DefaultSemantics() - if (semantics.elements.isEmpty) { - System.err.println("Warning: semantics are empty.") - } - - CpgBasedTool.addDataFlowOverlayIfNonExistent(cpg) - val context = new LayerCreatorContext(cpg) - - format match { - case Format.Dot if representation == Representation.All || representation == Representation.Cpg => - exportWithOdbFormat(cpg, representation, outDir, DotExporter) - case Format.Dot => - exportDot(representation, outDir, context) - case Format.Neo4jCsv => - exportWithOdbFormat(cpg, representation, outDir, Neo4jCsvExporter) - case Format.Graphml => - exportWithOdbFormat(cpg, representation, outDir, GraphMLExporter) - case Format.Graphson => - exportWithOdbFormat(cpg, representation, outDir, GraphSONExporter) - case other => - throw 
new NotImplementedError(s"repr=$representation not yet supported for format=$format") - } - } - - private def exportDot(repr: Representation.Value, outDir: Path, context: LayerCreatorContext): Unit = { - val outDirStr = outDir.toString - import Representation._ - repr match { - case Ast => new DumpAst(AstDumpOptions(outDirStr)).create(context) - case Cfg => new DumpCfg(CfgDumpOptions(outDirStr)).create(context) - case Ddg => new DumpDdg(DdgDumpOptions(outDirStr)).create(context) - case Cdg => new DumpCdg(CdgDumpOptions(outDirStr)).create(context) - case Pdg => new DumpPdg(PdgDumpOptions(outDirStr)).create(context) - case Cpg14 => new DumpCpg14(Cpg14DumpOptions(outDirStr)).create(context) - case other => throw new NotImplementedError(s"repr=$repr not yet supported for this format") - } - } - - private def exportWithOdbFormat( - cpg: Cpg, - repr: Representation.Value, - outDir: Path, - exporter: overflowdb.formats.Exporter - ): Unit = { - val ExportResult(nodeCount, edgeCount, _, additionalInfo) = repr match { - case Representation.All => - exporter.runExport(cpg.graph, outDir) - case Representation.Cpg => - val windowsFilenameDeduplicationHelper = mutable.Set.empty[String] - splitByMethod(cpg).iterator - .map { case subGraph @ MethodSubGraph(methodName, methodFilename, nodes) => - val relativeFilename = sanitizedFileName( - methodName, - methodFilename, - exporter.defaultFileExtension, - windowsFilenameDeduplicationHelper +object ChenExport: + + case class Config( + cpgFileName: String = "cpg.bin", + outDir: String = "out", + repr: Representation.Value = Representation.Cpg14, + format: Format.Value = Format.Dot + ) + + /** Choose from either a subset of the graph, or the entire graph (all). 
+ */ + object Representation extends Enumeration: + val Ast, Cfg, Ddg, Cdg, Pdg, Cpg14, Cpg, All = Value + + lazy val byNameLowercase: Map[String, Value] = + values.map { value => + value.toString.toLowerCase -> value + }.toMap + + def withNameIgnoreCase(s: String): Value = + byNameLowercase.getOrElse( + s.toLowerCase, + throw new NoSuchElementException(s"No value found for '$s'") + ) + + object Format extends Enumeration: + val Dot, Neo4jCsv, Graphml, Graphson = Value + + lazy val byNameLowercase: Map[String, Value] = + values.map { value => + value.toString.toLowerCase -> value + }.toMap + + def withNameIgnoreCase(s: String): Value = + byNameLowercase.getOrElse( + s.toLowerCase, + throw new NoSuchElementException(s"No value found for '$s'") ) - val outFileName = outDir.resolve(relativeFilename) - exporter.runExport(nodes, subGraph.edges, outFileName) - } - .reduce(plus) - case other => - throw new NotImplementedError(s"repr=$repr not yet supported for this format") - } - - println(s"exported $nodeCount nodes, $edgeCount edges into $outDir") - additionalInfo.foreach(println) - } - - /** for each method in the cpg: recursively traverse all AST edges to get the subgraph of nodes within this method add - * the method and this subgraph to the export add all edges between all of these nodes to the export - */ - private def splitByMethod(cpg: Cpg): IterableOnce[MethodSubGraph] = { - cpg.method.map { method => - MethodSubGraph(methodName = method.name, methodFilename = method.filename, nodes = method.ast.toSet) - } - } - - /** @param windowsFilenameDeduplicationHelper - * utility map to ensure we don't override output files for identical method names - */ - private def sanitizedFileName( - methodName: String, - methodFilename: String, - fileExtension: String, - windowsFilenameDeduplicationHelper: mutable.Set[String] - ): String = { - val sanitizedMethodName = methodName.replaceAll("[^a-zA-Z0-9-_\\.]", "_") - val sanitizedFilename = - if (scala.util.Properties.isWin) { - // windows has some 
quirks in it's file system, e.g. we need to ensure paths aren't too long - so we're using a - // different strategy to sanitize windows file names: first occurrence of a given method uses the method name - // any methods with the same name afterwards get a `_` suffix - if (windowsFilenameDeduplicationHelper.contains(sanitizedMethodName)) { - sanitizedFileName(s"${methodName}_", methodFilename, fileExtension, windowsFilenameDeduplicationHelper) - } else { - windowsFilenameDeduplicationHelper.add(sanitizedMethodName) - sanitizedMethodName + + def main(args: Array[String]): Unit = + parseConfig(args).foreach { config => + val outDir = config.outDir + exitIfInvalid(outDir, config.cpgFileName) + mkdir(File(outDir)) + + Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + exportCpg(cpg, config.repr, config.format, Paths.get(outDir).toAbsolutePath) + } } - } else { // non-windows - // handle leading `/` to ensure we're not writing outside of the output directory - val sanitizedPath = - if (methodFilename.startsWith("/")) s"_root_/$methodFilename" - else methodFilename - s"$sanitizedPath/$sanitizedMethodName" - } - - s"$sanitizedFilename.$fileExtension" - } - - private def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = { - ExportResult( - nodeCount = resultA.nodeCount + resultB.nodeCount, - edgeCount = resultA.edgeCount + resultB.edgeCount, - files = resultA.files ++ resultB.files, - additionalInfo = resultA.additionalInfo - ) - } - - case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[Node]) { - def edges: Set[Edge] = { - for { - node <- nodes - edge <- node.bothE.asScala - if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) - } yield edge - } - } -} + + private def parseConfig(args: Array[String]): Option[Config] = + new scopt.OptionParser[Config]("chen-export"): + head( + "Dump intermediate graph representations (or entire graph) of code in a given export format" + ) + help("help") + 
arg[String]("cpg") + .text("input CPG file name - defaults to `cpg.bin`") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + opt[String]('o', "out") + .text("output directory - will be created and must not yet exist") + .action((x, c) => c.copy(outDir = x)) + opt[String]("repr") + .text( + s"representation to extract: [${Representation.values.toSeq.map( + _.toString.toLowerCase + ).sorted.mkString("|")}] - defaults to `${Representation.Cpg14}`" + ) + .action((x, c) => c.copy(repr = Representation.withNameIgnoreCase(x))) + opt[String]("format") + .action((x, c) => c.copy(format = Format.withNameIgnoreCase(x))) + .text( + s"export format, one of [${Format.values.toSeq.map(_.toString.toLowerCase).sorted.mkString("|")}] - defaults to `${Format.Dot}`" + ) + .parse(args, Config()) + + def exportCpg( + cpg: Cpg, + representation: Representation.Value, + format: Format.Value, + outDir: Path + ): Unit = + implicit val semantics: Semantics = DefaultSemantics() + if semantics.elements.isEmpty then + System.err.println("Warning: semantics are empty.") + + CpgBasedTool.addDataFlowOverlayIfNonExistent(cpg) + val context = new LayerCreatorContext(cpg) + + format match + case Format.Dot + if representation == Representation.All || representation == Representation.Cpg => + exportWithOdbFormat(cpg, representation, outDir, DotExporter) + case Format.Dot => + exportDot(representation, outDir, context) + case Format.Neo4jCsv => + exportWithOdbFormat(cpg, representation, outDir, Neo4jCsvExporter) + case Format.Graphml => + exportWithOdbFormat(cpg, representation, outDir, GraphMLExporter) + case Format.Graphson => + exportWithOdbFormat(cpg, representation, outDir, GraphSONExporter) + case other => + throw new NotImplementedError( + s"repr=$representation not yet supported for format=$format" + ) + end exportCpg + + private def exportDot( + repr: Representation.Value, + outDir: Path, + context: LayerCreatorContext + ): Unit = + val outDirStr = outDir.toString + import 
Representation.* + repr match + case Ast => new DumpAst(AstDumpOptions(outDirStr)).create(context) + case Cfg => new DumpCfg(CfgDumpOptions(outDirStr)).create(context) + case Ddg => new DumpDdg(DdgDumpOptions(outDirStr)).create(context) + case Cdg => new DumpCdg(CdgDumpOptions(outDirStr)).create(context) + case Pdg => new DumpPdg(PdgDumpOptions(outDirStr)).create(context) + case Cpg14 => new DumpCpg14(Cpg14DumpOptions(outDirStr)).create(context) + case other => + throw new NotImplementedError(s"repr=$repr not yet supported for this format") + + private def exportWithOdbFormat( + cpg: Cpg, + repr: Representation.Value, + outDir: Path, + exporter: overflowdb.formats.Exporter + ): Unit = + val ExportResult(nodeCount, edgeCount, _, additionalInfo) = repr match + case Representation.All => + exporter.runExport(cpg.graph, outDir) + case Representation.Cpg => + val windowsFilenameDeduplicationHelper = mutable.Set.empty[String] + splitByMethod(cpg).iterator + .map { case subGraph @ MethodSubGraph(methodName, methodFilename, nodes) => + val relativeFilename = sanitizedFileName( + methodName, + methodFilename, + exporter.defaultFileExtension, + windowsFilenameDeduplicationHelper + ) + val outFileName = outDir.resolve(relativeFilename) + exporter.runExport(nodes, subGraph.edges, outFileName) + } + .reduce(plus) + case other => + throw new NotImplementedError(s"repr=$repr not yet supported for this format") + + println(s"exported $nodeCount nodes, $edgeCount edges into $outDir") + additionalInfo.foreach(println) + end exportWithOdbFormat + + /** for each method in the cpg: recursively traverse all AST edges to get the subgraph of nodes + * within this method add the method and this subgraph to the export add all edges between all + * of these nodes to the export + */ + private def splitByMethod(cpg: Cpg): IterableOnce[MethodSubGraph] = + cpg.method.map { method => + MethodSubGraph( + methodName = method.name, + methodFilename = method.filename, + nodes = method.ast.toSet + ) + 
} + + /** @param windowsFilenameDeduplicationHelper + * utility map to ensure we don't override output files for identical method names + */ + private def sanitizedFileName( + methodName: String, + methodFilename: String, + fileExtension: String, + windowsFilenameDeduplicationHelper: mutable.Set[String] + ): String = + val sanitizedMethodName = methodName.replaceAll("[^a-zA-Z0-9-_\\.]", "_") + val sanitizedFilename = + if scala.util.Properties.isWin then + // windows has some quirks in it's file system, e.g. we need to ensure paths aren't too long - so we're using a + // different strategy to sanitize windows file names: first occurrence of a given method uses the method name + // any methods with the same name afterwards get a `_` suffix + if windowsFilenameDeduplicationHelper.contains(sanitizedMethodName) then + sanitizedFileName( + s"${methodName}_", + methodFilename, + fileExtension, + windowsFilenameDeduplicationHelper + ) + else + windowsFilenameDeduplicationHelper.add(sanitizedMethodName) + sanitizedMethodName + else // non-windows + // handle leading `/` to ensure we're not writing outside of the output directory + val sanitizedPath = + if methodFilename.startsWith("/") then s"_root_/$methodFilename" + else methodFilename + s"$sanitizedPath/$sanitizedMethodName" + + s"$sanitizedFilename.$fileExtension" + end sanitizedFileName + + private def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = + ExportResult( + nodeCount = resultA.nodeCount + resultB.nodeCount, + edgeCount = resultA.edgeCount + resultB.edgeCount, + files = resultA.files ++ resultB.files, + additionalInfo = resultA.additionalInfo + ) + + case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[Node]): + def edges: Set[Edge] = + for + node <- nodes + edge <- node.bothE.asScala + if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) + yield edge +end ChenExport diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala 
b/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala index 98a3a24a..2f758fb7 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala @@ -19,91 +19,89 @@ case class FlowConfig( depth: Int = 1 ) -object ChenFlow { - def main(args: Array[String]) = { - parseConfig(args).foreach { config => - def debugOut(msg: String): Unit = { - if (config.verbose) { - print(msg) +object ChenFlow: + def main(args: Array[String]) = + parseConfig(args).foreach { config => + def debugOut(msg: String): Unit = + if config.verbose then + print(msg) + + debugOut("Loading graph... ") + val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName) + debugOut("[DONE]\n") + + implicit val resolver: ICallResolver = NoResolve + val sources = params(cpg, config.srcRegex, config.srcParam) + val sinks = params(cpg, config.dstRegex, config.dstParam) + + debugOut(s"Number of sources: ${sources.size}\n") + debugOut(s"Number of sinks: ${sinks.size}\n") + + implicit val semantics: Semantics = DefaultSemantics() + val engineConfig = EngineConfig(config.depth) + debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n") + implicit val context: EngineContext = EngineContext(semantics, engineConfig) + + debugOut("Determining flows...") + sinks.foreach { s => + List(s).reachableByFlows(sources.iterator).p.foreach(println) + } + debugOut("[DONE]") + + debugOut("Closing graph... ") + cpg.close() + debugOut("[DONE]\n") } - } - - debugOut("Loading graph... 
") - val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName) - debugOut("[DONE]\n") - - implicit val resolver: ICallResolver = NoResolve - val sources = params(cpg, config.srcRegex, config.srcParam) - val sinks = params(cpg, config.dstRegex, config.dstParam) - - debugOut(s"Number of sources: ${sources.size}\n") - debugOut(s"Number of sinks: ${sinks.size}\n") - - implicit val semantics: Semantics = DefaultSemantics() - val engineConfig = EngineConfig(config.depth) - debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n") - implicit val context: EngineContext = EngineContext(semantics, engineConfig) - - debugOut("Determining flows...") - sinks.foreach { s => - List(s).reachableByFlows(sources.iterator).p.foreach(println) - } - debugOut("[DONE]") - - debugOut("Closing graph... ") - cpg.close() - debugOut("[DONE]\n") - } - } - - private def parseConfig(args: Array[String]): Option[FlowConfig] = { - new scopt.OptionParser[FlowConfig]("chen-flow") { - head("Find flows") - help("help") - - arg[String]("src") - .text("source regex") - .action((x, c) => c.copy(srcRegex = x)) - - arg[String]("dst") - .text("destination regex") - .action((x, c) => c.copy(dstRegex = x)) - - arg[String]("cpg") - .text("CPG file name ('cpg.bin' by default)") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - - opt[Int]("src-param") - .text("Source parameter") - .optional() - .action((x, c) => c.copy(dstParam = Some(x))) - - opt[Int]("dst-param") - .text("Destination parameter") - .optional() - .action((x, c) => c.copy(dstParam = Some(x))) - - opt[Int]("depth") - .text("Analysis depth (number of calls to expand)") - .optional() - .action((x, c) => c.copy(depth = x)) - - opt[Unit]("verbose") - .text("Print debug information") - .optional() - .action((_, c) => c.copy(verbose = true)) - } - }.parse(args, FlowConfig()) - - private def params(cpg: Cpg, methodNameRegex: String, paramIndex: Option[Int]): List[MethodParameterIn] = { - cpg - .method(methodNameRegex) - .parameter - .filter { p 
=> - paramIndex.isEmpty || paramIndex.contains(p.order) - } - .l - } - -} + + private def parseConfig(args: Array[String]): Option[FlowConfig] = { + new scopt.OptionParser[FlowConfig]("chen-flow"): + head("Find flows") + help("help") + + arg[String]("src") + .text("source regex") + .action((x, c) => c.copy(srcRegex = x)) + + arg[String]("dst") + .text("destination regex") + .action((x, c) => c.copy(dstRegex = x)) + + arg[String]("cpg") + .text("CPG file name ('cpg.bin' by default)") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + + opt[Int]("src-param") + .text("Source parameter") + .optional() + .action((x, c) => c.copy(srcParam = Some(x))) + + opt[Int]("dst-param") + .text("Destination parameter") + .optional() + .action((x, c) => c.copy(dstParam = Some(x))) + + opt[Int]("depth") + .text("Analysis depth (number of calls to expand)") + .optional() + .action((x, c) => c.copy(depth = x)) + + opt[Unit]("verbose") + .text("Print debug information") + .optional() + .action((_, c) => c.copy(verbose = true)) + }.parse(args, FlowConfig()) + + private def params( + cpg: Cpg, + methodNameRegex: String, + paramIndex: Option[Int] + ): List[MethodParameterIn] = + cpg + .method(methodNameRegex) + .parameter + .filter { p => + paramIndex.isEmpty || paramIndex.contains(p.order) + } + .l +end ChenFlow diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala b/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala index 7c85ee72..d51f3934 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala @@ -7,187 +7,178 @@ import CpgBasedTool.newCpgCreatedString import io.shiftleft.codepropertygraph.generated.Languages import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* import scala.util.{Failure, Success, Try} -object ChenParse { - // Special string used to separate joern-parse opts from frontend-specific opts 
- val ArgsDelimitor = "--frontend-args" - val DefaultCpgOutFile = "cpg.bin" - var generator: CpgGenerator = _ - - def main(args: Array[String]): Unit = { - run(args) match { - case Success(msg) => - println(msg) - case Failure(err) => - err.printStackTrace() - System.exit(1) - } - } - - val optionParser = new scopt.OptionParser[ParserConfig]("chen-parse") { - arg[String]("input") - .optional() - .text("source file or directory containing source files") - .action((x, c) => c.copy(inputPath = x)) - - opt[String]('o', "output") - .text("output filename") - .action((x, c) => c.copy(outputCpgFile = x)) - - opt[String]("language") - .text("source language") - .action((x, c) => c.copy(language = x)) - - opt[Unit]("list-languages") - .text("list available language options") - .action((_, c) => c.copy(listLanguages = true)) - - opt[String]("namespaces") - .text("namespaces to include: comma separated string") - .action((x, c) => c.copy(namespaces = x.split(",").map(_.trim).toSeq)) - - note("Overlay application stage") - - opt[Unit]("nooverlays") - .text("do not apply default overlays") - .action((_, c) => c.copy(enhance = false)) - opt[Unit]("overlaysonly") - .text("Only apply default overlays") - .action((_, c) => c.copy(enhanceOnly = true)) - - opt[Int]("max-num-def") - .text("Maximum number of definitions in per-method data flow calculation") - .action((x, c) => c.copy(maxNumDef = x)) - - note("Misc") - help("help").text("display this help message") - - note(s"Args specified after the $ArgsDelimitor separator will be passed to the front-end verbatim") - } - - private def run(args: Array[String]): Try[String] = { - val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args) - val installConfig = new InstallConfig() - - parseConfig(parserArgs).flatMap { config => - if (config.listLanguages) - Try(buildLanguageList()) - else - run(config, frontendArgs, installConfig) - } - } - - def run( - config: ParserConfig, - frontendArgs: List[String] = List.empty, - installConfig: 
InstallConfig = InstallConfig() - ): Try[String] = { - for { - _ <- checkInputPath(config) - language <- getLanguage(config) - _ <- generateCpg(installConfig, frontendArgs, config, language) - _ <- applyDefaultOverlays(config) - } yield newCpgCreatedString(config.outputCpgFile) - } - - private def checkInputPath(config: ParserConfig): Try[Unit] = { - Try { - if (config.inputPath == "") { - println(optionParser.usage) - throw new AssertionError(s"Input path required") - } else if (!File(config.inputPath).exists) - throw new AssertionError(s"Input path does not exist at `${config.inputPath}`, exiting.") - else () - } - } - - private def buildLanguageList(): String = { - val s = new mutable.StringBuilder() - s ++= "Available languages (case insensitive):\n" - s ++= Languages.ALL.asScala.map(lang => s"- ${lang.toLowerCase}").mkString("\n") - s.toString() - } - - private def getLanguage(config: ParserConfig): Try[String] = { - Try { - if (config.language.nonEmpty) { - config.language - } else { - guessLanguage(config.inputPath) - .getOrElse( - throw new AssertionError( - s"Could not guess language from input path ${config.inputPath}. Please specify a language using the --language option." 
- ) - ) - } - } - } - - private def generateCpg( - installConfig: InstallConfig, - frontendArgs: Seq[String], - config: ParserConfig, - language: String - ): Try[String] = { - if (config.enhanceOnly) { - Success("No generation required") - } else { - println(s"Parsing code at: ${config.inputPath} - language: `$language`") - println("[+] Running language frontend") - Try { - cpgGeneratorForLanguage( - language.toUpperCase, - FrontendConfig(), - installConfig.rootPath.path, - frontendArgs.toList - ).get - }.flatMap { newGenerator => - generator = newGenerator - generator - .generate(config.inputPath, outputPath = config.outputCpgFile) - .recover { case exception => - throw new RuntimeException( - s"Could not generate CPG with language = $language and input = ${config.inputPath}", - exception - ) - } - } - } - } - - private def applyDefaultOverlays(config: ParserConfig): Try[String] = { - Try { - println("[+] Applying default overlays") - if (config.enhance) { - val cpg = DefaultOverlays.create(config.outputCpgFile, config.maxNumDef) - generator.applyPostProcessingPasses(cpg) - cpg.close() - } - "Code property graph generation successful" - } - } - - case class ParserConfig( - inputPath: String = "", - outputCpgFile: String = DefaultCpgOutFile, - namespaces: Seq[String] = Seq.empty, - enhance: Boolean = true, - enhanceOnly: Boolean = false, - language: String = "", - listLanguages: Boolean = false, - maxNumDef: Int = DefaultOverlays.defaultMaxNumberOfDefinitions - ) - - private def parseConfig(parserArgs: Seq[String]): Try[ParserConfig] = { - Try { - optionParser - .parse(parserArgs, ParserConfig()) - .getOrElse( - throw new RuntimeException(s"Error while not parsing command line options: `${parserArgs.mkString(",")}`") +object ChenParse: + // Special string used to separate joern-parse opts from frontend-specific opts + val ArgsDelimitor = "--frontend-args" + val DefaultCpgOutFile = "cpg.bin" + var generator: CpgGenerator = _ + + def main(args: Array[String]): Unit = 
+ run(args) match + case Success(msg) => + println(msg) + case Failure(err) => + err.printStackTrace() + System.exit(1) + + val optionParser = new scopt.OptionParser[ParserConfig]("chen-parse"): + arg[String]("input") + .optional() + .text("source file or directory containing source files") + .action((x, c) => c.copy(inputPath = x)) + + opt[String]('o', "output") + .text("output filename") + .action((x, c) => c.copy(outputCpgFile = x)) + + opt[String]("language") + .text("source language") + .action((x, c) => c.copy(language = x)) + + opt[Unit]("list-languages") + .text("list available language options") + .action((_, c) => c.copy(listLanguages = true)) + + opt[String]("namespaces") + .text("namespaces to include: comma separated string") + .action((x, c) => c.copy(namespaces = x.split(",").map(_.trim).toSeq)) + + note("Overlay application stage") + + opt[Unit]("nooverlays") + .text("do not apply default overlays") + .action((_, c) => c.copy(enhance = false)) + opt[Unit]("overlaysonly") + .text("Only apply default overlays") + .action((_, c) => c.copy(enhanceOnly = true)) + + opt[Int]("max-num-def") + .text("Maximum number of definitions in per-method data flow calculation") + .action((x, c) => c.copy(maxNumDef = x)) + + note("Misc") + help("help").text("display this help message") + + note( + s"Args specified after the $ArgsDelimitor separator will be passed to the front-end verbatim" ) - } - } -} + private def run(args: Array[String]): Try[String] = + val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args) + val installConfig = new InstallConfig() + + parseConfig(parserArgs).flatMap { config => + if config.listLanguages then + Try(buildLanguageList()) + else + run(config, frontendArgs, installConfig) + } + + def run( + config: ParserConfig, + frontendArgs: List[String] = List.empty, + installConfig: InstallConfig = InstallConfig() + ): Try[String] = + for + _ <- checkInputPath(config) + language <- getLanguage(config) + _ <- generateCpg(installConfig, 
frontendArgs, config, language) + _ <- applyDefaultOverlays(config) + yield newCpgCreatedString(config.outputCpgFile) + + private def checkInputPath(config: ParserConfig): Try[Unit] = + Try { + if config.inputPath == "" then + println(optionParser.usage) + throw new AssertionError(s"Input path required") + else if !File(config.inputPath).exists then + throw new AssertionError( + s"Input path does not exist at `${config.inputPath}`, exiting." + ) + else () + } + + private def buildLanguageList(): String = + val s = new mutable.StringBuilder() + s ++= "Available languages (case insensitive):\n" + s ++= Languages.ALL.asScala.map(lang => s"- ${lang.toLowerCase}").mkString("\n") + s.toString() + + private def getLanguage(config: ParserConfig): Try[String] = + Try { + if config.language.nonEmpty then + config.language + else + guessLanguage(config.inputPath) + .getOrElse( + throw new AssertionError( + s"Could not guess language from input path ${config.inputPath}. Please specify a language using the --language option." 
+ ) + ) + } + + private def generateCpg( + installConfig: InstallConfig, + frontendArgs: Seq[String], + config: ParserConfig, + language: String + ): Try[String] = + if config.enhanceOnly then + Success("No generation required") + else + println(s"Parsing code at: ${config.inputPath} - language: `$language`") + println("[+] Running language frontend") + Try { + cpgGeneratorForLanguage( + language.toUpperCase, + FrontendConfig(), + installConfig.rootPath.path, + frontendArgs.toList + ).get + }.flatMap { newGenerator => + generator = newGenerator + generator + .generate(config.inputPath, outputPath = config.outputCpgFile) + .recover { case exception => + throw new RuntimeException( + s"Could not generate CPG with language = $language and input = ${config.inputPath}", + exception + ) + } + } + + private def applyDefaultOverlays(config: ParserConfig): Try[String] = + Try { + println("[+] Applying default overlays") + if config.enhance then + val cpg = DefaultOverlays.create(config.outputCpgFile, config.maxNumDef) + generator.applyPostProcessingPasses(cpg) + cpg.close() + "Code property graph generation successful" + } + + case class ParserConfig( + inputPath: String = "", + outputCpgFile: String = DefaultCpgOutFile, + namespaces: Seq[String] = Seq.empty, + enhance: Boolean = true, + enhanceOnly: Boolean = false, + language: String = "", + listLanguages: Boolean = false, + maxNumDef: Int = DefaultOverlays.defaultMaxNumberOfDefinitions + ) + + private def parseConfig(parserArgs: Seq[String]): Try[ParserConfig] = + Try { + optionParser + .parse(parserArgs, ParserConfig()) + .getOrElse( + throw new RuntimeException( + s"Error while not parsing command line options: `${parserArgs.mkString(",")}`" + ) + ) + } +end ChenParse diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala b/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala index ba0db3a7..6f13ea7b 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala +++ 
b/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala @@ -14,175 +14,166 @@ import scala.jdk.CollectionConverters.* import scala.util.Using import scala.util.hashing.MurmurHash3 -class BagOfPropertiesForNodes extends EmbeddingGenerator[AstNode, (String, String)] { - override def structureToString(pair: (String, String)): String = pair._1 + ":" + pair._2 - override def extractObjects(cpg: Cpg): Iterator[AstNode] = cpg.graph.V.collect { case x: AstNode => x } - override def enumerateSubStructures(obj: AstNode): List[(String, String)] = { - val relevantFieldTypes = Set(PropertyNames.NAME, PropertyNames.FULL_NAME, PropertyNames.CODE) - val relevantFields = obj - .propertiesMap() - .entrySet() - .asScala - .toList - .filter { e => relevantFieldTypes.contains(e.getKey) } - .sortBy(_.getKey) - .map { e => - (e.getKey, e.getValue.toString) - } - List(("id", obj.id().toString)) ++ relevantFields ++ List(("label", obj.label)) - } - - override def objectToString(node: AstNode): String = node.id().toString - override def hash(label: String): String = label - - override def vectorToString(vector: Map[(String, String), Double]): String = { - val jsonObj = Json.fromFields(vector.keys.map { case (k, v) => (k, Json.fromString(v)) }) - jsonObj.toString - } - -} - -class BagOfAPISymbolsForMethods extends EmbeddingGenerator[Method, AstNode] { - override def extractObjects(cpg: Cpg): Iterator[Method] = cpg.method - override def enumerateSubStructures(method: Method): List[AstNode] = method.ast.l - override def structureToString(astNode: AstNode): String = astNode.code - override def objectToString(method: Method): String = method.fullName -} - -object EmbeddingGenerator { - type SparseVectorWithExplicitFeature[S] = Map[String, (Double, S)] - type SparseVector = Map[Int, Double] -} - -/** Creates an embedding from a code property graph by following three steps: (1) Objects are extracted from the graph, - * each of which is ultimately to be mapped to one vector (2) For each 
object, enumerate its sub structures. (3) Employ - * feature hashing to associate each sub structure with a dimension. See "Pattern-based Vulnerability Discovery - - * Chapter 3" +class BagOfPropertiesForNodes extends EmbeddingGenerator[AstNode, (String, String)]: + override def structureToString(pair: (String, String)): String = pair._1 + ":" + pair._2 + override def extractObjects(cpg: Cpg): Iterator[AstNode] = + cpg.graph.V.collect { case x: AstNode => x } + override def enumerateSubStructures(obj: AstNode): List[(String, String)] = + val relevantFieldTypes = + Set(PropertyNames.NAME, PropertyNames.FULL_NAME, PropertyNames.CODE) + val relevantFields = obj + .propertiesMap() + .entrySet() + .asScala + .toList + .filter { e => relevantFieldTypes.contains(e.getKey) } + .sortBy(_.getKey) + .map { e => + (e.getKey, e.getValue.toString) + } + List(("id", obj.id().toString)) ++ relevantFields ++ List(("label", obj.label)) + + override def objectToString(node: AstNode): String = node.id().toString + override def hash(label: String): String = label + + override def vectorToString(vector: Map[(String, String), Double]): String = + val jsonObj = Json.fromFields(vector.keys.map { case (k, v) => (k, Json.fromString(v)) }) + jsonObj.toString +end BagOfPropertiesForNodes + +class BagOfAPISymbolsForMethods extends EmbeddingGenerator[Method, AstNode]: + override def extractObjects(cpg: Cpg): Iterator[Method] = cpg.method + override def enumerateSubStructures(method: Method): List[AstNode] = method.ast.l + override def structureToString(astNode: AstNode): String = astNode.code + override def objectToString(method: Method): String = method.fullName + +object EmbeddingGenerator: + type SparseVectorWithExplicitFeature[S] = Map[String, (Double, S)] + type SparseVector = Map[Int, Double] + +/** Creates an embedding from a code property graph by following three steps: (1) Objects are + * extracted from the graph, each of which is ultimately to be mapped to one vector (2) For each + * 
object, enumerate its sub structures. (3) Employ feature hashing to associate each sub structure + * with a dimension. See "Pattern-based Vulnerability Discovery - Chapter 3" * * T: Object type S: Sub structure type */ -trait EmbeddingGenerator[T, S] { - import EmbeddingGenerator.* - - case class Embedding(data: () => Iterator[(T, SparseVectorWithExplicitFeature[S])]) { - lazy val dimToStructure: Map[String, S] = { - val m = mutable.HashMap[String, S]() - data().foreach { case (_, vector) => - vector.foreach { case (hash, (_, structure)) => - m.put(hash, structure) - } - } - m.toMap - } - - lazy val structureToDim: Map[S, String] = for ((k, v) <- dimToStructure) yield (v, k) +trait EmbeddingGenerator[T, S]: + import EmbeddingGenerator.* - def objects: Iterator[String] = data().map { case (obj, _) => objectToString(obj) } + case class Embedding(data: () => Iterator[(T, SparseVectorWithExplicitFeature[S])]): + lazy val dimToStructure: Map[String, S] = + val m = mutable.HashMap[String, S]() + data().foreach { case (_, vector) => + vector.foreach { case (hash, (_, structure)) => + m.put(hash, structure) + } + } + m.toMap - def vectors: Iterator[Map[S, Double]] = data().map { case (_, vector) => - vector.map { case (_, (v, structure)) => structure -> v } - } + lazy val structureToDim: Map[S, String] = for ((k, v) <- dimToStructure) yield (v, k) - } + def objects: Iterator[String] = data().map { case (obj, _) => objectToString(obj) } - /** Extract a sequence of (object, vector) pairs from a cpg. 
- */ - def embed(cpg: Cpg): Embedding = { - Embedding({ () => - extractObjects(cpg) - .map { obj => - val substructures = enumerateSubStructures(obj) - obj -> vectorize(substructures) + def vectors: Iterator[Map[S, Double]] = data().map { case (_, vector) => + vector.map { case (_, (v, structure)) => structure -> v } } - }) - } - - private def vectorize(substructures: List[S]): Map[String, (Double, S)] = { - substructures - .groupBy(x => structureToString(x)) - .view - .map { case (_, l) => - val v = l.size - hash(structureToString(l.head)) -> (v.toDouble, l.head) - } - .toMap - } - - def hash(label: String): String = MurmurHash3.stringHash(label).toString - - def structureToString(s: S): String - - /** A function that creates a sequence of objects from a CPG - */ - def extractObjects(cpg: Cpg): Iterator[T] - - /** A function that, for a given object, extracts its sub structures - */ - def enumerateSubStructures(obj: T): List[S] - - def objectToString(t: T): String - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - - def vectorToString(vector: Map[S, Double]): String = defaultToString(vector) - - def defaultToString[M](v: M): String = Serialization.write(v) - -} - -object JoernVectors { - - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - case class Config(cpgFileName: String = "cpg.bin", outDir: String = "out", dimToFeature: Boolean = false) - - def main(args: Array[String]) = { - parseConfig(args).foreach { config => - exitIfInvalid(config.outDir, config.cpgFileName) - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => - val generator = new BagOfPropertiesForNodes() - val embedding = generator.embed(cpg) - println("{") - println("\"objects\":") - traversalToJson(embedding.objects, generator.defaultToString) - if (config.dimToFeature) { - println(",\"dimToFeature\": ") - println(Serialization.write(embedding.dimToStructure)) + /** Extract a sequence of (object, vector) pairs from a cpg. 
+ */ + def embed(cpg: Cpg): Embedding = + Embedding({ () => + extractObjects(cpg) + .map { obj => + val substructures = enumerateSubStructures(obj) + obj -> vectorize(substructures) + } + }) + + private def vectorize(substructures: List[S]): Map[String, (Double, S)] = + substructures + .groupBy(x => structureToString(x)) + .view + .map { case (_, l) => + val v = l.size + hash(structureToString(l.head)) -> (v.toDouble, l.head) + } + .toMap + + def hash(label: String): String = MurmurHash3.stringHash(label).toString + + def structureToString(s: S): String + + /** A function that creates a sequence of objects from a CPG + */ + def extractObjects(cpg: Cpg): Iterator[T] + + /** A function that, for a given object, extracts its sub structures + */ + def enumerateSubStructures(obj: T): List[S] + + def objectToString(t: T): String + + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + + def vectorToString(vector: Map[S, Double]): String = defaultToString(vector) + + def defaultToString[M](v: M): String = Serialization.write(v) +end EmbeddingGenerator + +object JoernVectors: + + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + case class Config( + cpgFileName: String = "cpg.bin", + outDir: String = "out", + dimToFeature: Boolean = false + ) + + def main(args: Array[String]) = + parseConfig(args).foreach { config => + exitIfInvalid(config.outDir, config.cpgFileName) + Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + val generator = new BagOfPropertiesForNodes() + val embedding = generator.embed(cpg) + println("{") + println("\"objects\":") + traversalToJson(embedding.objects, generator.defaultToString) + if config.dimToFeature then + println(",\"dimToFeature\": ") + println(Serialization.write(embedding.dimToStructure)) + println(",\"vectors\":") + traversalToJson(embedding.vectors, generator.vectorToString) + println(",\"edges\":") + traversalToJson( + cpg.graph.edges().map { x => + Map("src" -> 
x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label()) + }, + generator.defaultToString + ) + println("}") + } } - println(",\"vectors\":") - traversalToJson(embedding.vectors, generator.vectorToString) - println(",\"edges\":") - traversalToJson( - cpg.graph.edges().map { x => - Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label()) - }, - generator.defaultToString - ) - println("}") - } - } - } - - private def parseConfig(args: Array[String]): Option[Config] = - new scopt.OptionParser[Config]("chen-vectors") { - head("Extract vector representations of code from CPG") - help("help") - arg[String]("cpg") - .text("input CPG file name - defaults to `cpg.bin`") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - opt[String]('o', "out") - .text("output directory - will be created and must not yet exist") - .action((x, c) => c.copy(outDir = x)) - opt[Unit]("features") - .text("Provide map from dimensions to features") - .action((_, c) => c.copy(dimToFeature = true)) - }.parse(args, Config()) - - private def traversalToJson[X](trav: Iterator[X], vectorToString: X => String): Unit = { - println("[") - trav.nextOption().foreach { vector => print(vectorToString(vector)) } - trav.foreach { vector => print(",\n" + vectorToString(vector)) } - println("]") - } - -} + + private def parseConfig(args: Array[String]): Option[Config] = + new scopt.OptionParser[Config]("chen-vectors"): + head("Extract vector representations of code from CPG") + help("help") + arg[String]("cpg") + .text("input CPG file name - defaults to `cpg.bin`") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + opt[String]('o', "out") + .text("output directory - will be created and must not yet exist") + .action((x, c) => c.copy(outDir = x)) + opt[Unit]("features") + .text("Provide map from dimensions to features") + .action((_, c) => c.copy(dimToFeature = true)) + .parse(args, Config()) + + private def traversalToJson[X](trav: Iterator[X], vectorToString: X => 
String): Unit = + println("[") + trav.nextOption().foreach { vector => print(vectorToString(vector)) } + trav.foreach { vector => print(",\n" + vectorToString(vector)) } + println("]") +end JoernVectors diff --git a/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala b/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala index c9e7b72e..7563937d 100644 --- a/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala +++ b/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala @@ -5,64 +5,56 @@ import io.appthreat.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlow import io.appthreat.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.cpgloading.CpgLoaderConfig -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.layers.LayerCreatorContext -object CpgBasedTool { - - /** Load code property graph from overflowDB - * - * @param filename - * name of the file that stores the CPG - */ - def loadFromOdb(filename: String): Cpg = { - val odbConfig = overflowdb.Config.withDefaults().withStorageLocation(filename) - val config = CpgLoaderConfig().withOverflowConfig(odbConfig).doNotCreateIndexesOnLoad - io.shiftleft.codepropertygraph.cpgloading.CpgLoader.loadFromOverflowDb(config) - } - - /** Add the data flow layer to the CPG if it does not exist yet. - */ - def addDataFlowOverlayIfNonExistent(cpg: Cpg)(implicit s: Semantics): Unit = { - if (!cpg.metaData.overlays.exists(_ == OssDataFlow.overlayName)) { - System.err.println("CPG does not have dataflow overlay. Calculating.") - val opts = new OssDataFlowOptions() - val context = new LayerCreatorContext(cpg) - new OssDataFlow(opts).run(context) - } - } - - /** Create an informational string for the user that informs of a successfully generated CPG. 
- */ - def newCpgCreatedString(path: String): String = { - val absolutePath = File(path).path.toAbsolutePath - s"Successfully wrote graph to: $absolutePath\n" + - s"To load the graph, type `joern $absolutePath`" - } - - val ARGS_DELIMITER = "--frontend-args" - - /** Splits arguments at the ARGS_DELIMITER into arguments for the tool and arguments for the language frontend. - */ - def splitArgs(args: Array[String]): (List[String], List[String]) = { - args.indexOf(ARGS_DELIMITER) match { - case -1 => (args.toList, Nil) - case splitIdx => - val (parseOpts, frontendOpts) = args.toList.splitAt(splitIdx) - (parseOpts, frontendOpts.tail) // Take the tail to ignore the delimiter - } - } - - def exitIfInvalid(outDir: String, cpgFileName: String): Unit = { - if (File(outDir).exists) - exitWithError(s"Output directory `$outDir` already exists.") - if (File(cpgFileName).notExists) - exitWithError(s"CPG at $cpgFileName does not exist.") - } - - def exitWithError(msg: String): Unit = { - System.err.println(s"error: $msg") - System.exit(1) - } - -} +object CpgBasedTool: + + /** Load code property graph from overflowDB + * + * @param filename + * name of the file that stores the CPG + */ + def loadFromOdb(filename: String): Cpg = + val odbConfig = overflowdb.Config.withDefaults().withStorageLocation(filename) + val config = CpgLoaderConfig().withOverflowConfig(odbConfig).doNotCreateIndexesOnLoad + io.shiftleft.codepropertygraph.cpgloading.CpgLoader.loadFromOverflowDb(config) + + /** Add the data flow layer to the CPG if it does not exist yet. + */ + def addDataFlowOverlayIfNonExistent(cpg: Cpg)(implicit s: Semantics): Unit = + if !cpg.metaData.overlays.exists(_ == OssDataFlow.overlayName) then + System.err.println("CPG does not have dataflow overlay. Calculating.") + val opts = new OssDataFlowOptions() + val context = new LayerCreatorContext(cpg) + new OssDataFlow(opts).run(context) + + /** Create an informational string for the user that informs of a successfully generated CPG. 
+ */ + def newCpgCreatedString(path: String): String = + val absolutePath = File(path).path.toAbsolutePath + s"Successfully wrote graph to: $absolutePath\n" + + s"To load the graph, type `joern $absolutePath`" + + val ARGS_DELIMITER = "--frontend-args" + + /** Splits arguments at the ARGS_DELIMITER into arguments for the tool and arguments for the + * language frontend. + */ + def splitArgs(args: Array[String]): (List[String], List[String]) = + args.indexOf(ARGS_DELIMITER) match + case -1 => (args.toList, Nil) + case splitIdx => + val (parseOpts, frontendOpts) = args.toList.splitAt(splitIdx) + (parseOpts, frontendOpts.tail) // Take the tail to ignore the delimiter + + def exitIfInvalid(outDir: String, cpgFileName: String): Unit = + if File(outDir).exists then + exitWithError(s"Output directory `$outDir` already exists.") + if File(cpgFileName).notExists then + exitWithError(s"CPG at $cpgFileName does not exist.") + + def exitWithError(msg: String): Unit = + System.err.println(s"error: $msg") + System.exit(1) +end CpgBasedTool diff --git a/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala b/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala index f87c0ff9..5aaab1e7 100644 --- a/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala +++ b/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala @@ -6,23 +6,24 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.semanticcpg.layers.* import io.shiftleft.semanticcpg.layers.LayerCreatorContext -object DefaultOverlays { +object DefaultOverlays: - val DEFAULT_CPG_IN_FILE = "cpg.bin" - val defaultMaxNumberOfDefinitions = 4000 + val DEFAULT_CPG_IN_FILE = "cpg.bin" + val defaultMaxNumberOfDefinitions = 4000 - /** Load the CPG at `storeFilename` and add enhancements, turning the CPG into an SCPG. 
- * - * @param storeFilename - * the filename of the cpg - */ - def create(storeFilename: String, maxNumberOfDefinitions: Int = defaultMaxNumberOfDefinitions): Cpg = { - val cpg = CpgBasedTool.loadFromOdb(storeFilename) - applyDefaultOverlays(cpg) - val context = new LayerCreatorContext(cpg) - val options = new OssDataFlowOptions(maxNumberOfDefinitions) - new OssDataFlow(options).run(context) - cpg - } - -} + /** Load the CPG at `storeFilename` and add enhancements, turning the CPG into an SCPG. + * + * @param storeFilename + * the filename of the cpg + */ + def create( + storeFilename: String, + maxNumberOfDefinitions: Int = defaultMaxNumberOfDefinitions + ): Cpg = + val cpg = CpgBasedTool.loadFromOdb(storeFilename) + applyDefaultOverlays(cpg) + val context = new LayerCreatorContext(cpg) + val options = new OssDataFlowOptions(maxNumberOfDefinitions) + new OssDataFlow(options).run(context) + cpg +end DefaultOverlays diff --git a/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala b/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala index 7ac69920..41668958 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala @@ -1,6 +1,6 @@ package io.appthreat.chencli.console -import better.files._ +import better.files.* import io.appthreat.console.workspacehandling.{ProjectFile, WorkspaceLoader} import io.appthreat.console.{Console, ConsoleConfig, InstallConfig} import io.appthreat.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} @@ -12,44 +12,39 @@ import java.nio.file.Path object ChenWorkspaceLoader {} -class ChenWorkspaceLoader extends WorkspaceLoader[ChenProject] { - override def createProject(projectFile: ProjectFile, path: Path): ChenProject = { - val project = new ChenProject(projectFile, path) - project.context = EngineContext() - project - } -} +class ChenWorkspaceLoader extends 
WorkspaceLoader[ChenProject]: + override def createProject(projectFile: ProjectFile, path: Path): ChenProject = + val project = new ChenProject(projectFile, path) + project.context = EngineContext() + project -class ChenConsole extends Console[ChenProject](new ChenWorkspaceLoader) { +class ChenConsole extends Console[ChenProject](new ChenWorkspaceLoader): - override val config: ConsoleConfig = ChenConsole.defaultConfig + override val config: ConsoleConfig = ChenConsole.defaultConfig - implicit var semantics: Semantics = context.semantics + implicit var semantics: Semantics = context.semantics - // this is set to be `opts.ossdataflow` on initialization of the shell - var ossDataFlowOptions: OssDataFlowOptions = new OssDataFlowOptions() + // this is set to be `opts.ossdataflow` on initialization of the shell + var ossDataFlowOptions: OssDataFlowOptions = new OssDataFlowOptions() - implicit def context: EngineContext = - workspace.getActiveProject - .map(x => x.asInstanceOf[ChenProject].context) - .getOrElse(EngineContext()) + implicit def context: EngineContext = + workspace.getActiveProject + .map(x => x.asInstanceOf[ChenProject].context) + .getOrElse(EngineContext()) - def loadCpg(inputPath: String): Option[Cpg] = { - report("Deprecated. Please use `importCpg` instead") - importCpg(inputPath) - } + def loadCpg(inputPath: String): Option[Cpg] = + report("Deprecated. Please use `importCpg` instead") + importCpg(inputPath) - override def applyDefaultOverlays(cpg: Cpg): Cpg = { - super.applyDefaultOverlays(cpg) - _runAnalyzer(new OssDataFlow(ossDataFlowOptions)) - } + override def applyDefaultOverlays(cpg: Cpg): Cpg = + super.applyDefaultOverlays(cpg) + _runAnalyzer(new OssDataFlow(ossDataFlowOptions)) +end ChenConsole -} +object ChenConsole: -object ChenConsole { - - def banner(): String = - s""" + def banner(): String = + s""" | _ _ _ _ _ __ |/ |_ _ ._ ._ _. 
o |_ / \\ / \\ / \\ / |_|_ |\\_ | | (/_ | | | | (_| | |_) \\_/ \\_/ \\_/ / | @@ -57,9 +52,7 @@ object ChenConsole { |Version: $version """.stripMargin - def version: String = - getClass.getPackage.getImplementationVersion - - def defaultConfig: ConsoleConfig = new ConsoleConfig() + def version: String = + getClass.getPackage.getImplementationVersion -} + def defaultConfig: ConsoleConfig = new ConsoleConfig() diff --git a/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala b/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala index 796a22b6..6b4cca81 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala @@ -2,25 +2,25 @@ package io.appthreat.chencli.console import io.appthreat.console.{Help, Run} -object Predefined { +object Predefined: - val shared: Seq[String] = - Seq( - "import _root_.io.appthreat.console._", - "import _root_.io.appthreat.chencli.console.ChenConsole._", - "import _root_.io.appthreat.chencli.console.Chen.context", - "import _root_.io.shiftleft.codepropertygraph.Cpg", - "import _root_.io.shiftleft.codepropertygraph.Cpg.docSearchPackages", - "import _root_.io.shiftleft.codepropertygraph.cpgloading._", - "import _root_.io.shiftleft.codepropertygraph.generated._", - "import _root_.io.shiftleft.codepropertygraph.generated.nodes._", - "import _root_.io.shiftleft.codepropertygraph.generated.edges._", - "import _root_.io.appthreat.dataflowengineoss.language._", - "import _root_.io.shiftleft.semanticcpg.language._", - "import overflowdb._", - "import overflowdb.traversal.{`package` => _, help => _, _}", - "import scala.jdk.CollectionConverters._", - """ + val shared: Seq[String] = + Seq( + "import _root_.io.appthreat.console._", + "import _root_.io.appthreat.chencli.console.ChenConsole._", + "import _root_.io.appthreat.chencli.console.Chen.context", + "import _root_.io.shiftleft.codepropertygraph.Cpg", + "import 
_root_.io.shiftleft.codepropertygraph.Cpg.docSearchPackages", + "import _root_.io.shiftleft.codepropertygraph.cpgloading._", + "import _root_.io.shiftleft.codepropertygraph.generated._", + "import _root_.io.shiftleft.codepropertygraph.generated.nodes._", + "import _root_.io.shiftleft.codepropertygraph.generated.edges._", + "import _root_.io.appthreat.dataflowengineoss.language._", + "import _root_.io.shiftleft.semanticcpg.language._", + "import overflowdb._", + "import overflowdb.traversal.{`package` => _, help => _, _}", + "import scala.jdk.CollectionConverters._", + """ |def reachables(sinkTag: String, sourceTag: String, sourceTags: Array[String])(implicit atom: Cpg): Unit = { | try { | def source=atom.tag.name(sourceTag).parameter @@ -35,14 +35,12 @@ object Predefined { |def reachables(implicit atom: Cpg): Unit = reachables("framework-output", "framework-input", Array("api")) | |""".stripMargin - ) + ) - val forInteractiveShell: Seq[String] = { - shared ++ - Seq("import _root_.io.appthreat.chencli.console.Chen._") ++ - Run.codeForRunCommand().linesIterator ++ - Help.codeForHelpCommand(classOf[ChenConsole]).linesIterator ++ - Seq("ossDataFlowOptions = opts.ossdataflow") - } - -} + val forInteractiveShell: Seq[String] = + shared ++ + Seq("import _root_.io.appthreat.chencli.console.Chen._") ++ + Run.codeForRunCommand().linesIterator ++ + Help.codeForHelpCommand(classOf[ChenConsole]).linesIterator ++ + Seq("ossDataFlowOptions = opts.ossdataflow") +end Predefined diff --git a/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala b/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala index 24d53d4d..ee980dd5 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala @@ -4,23 +4,20 @@ import io.appthreat.console.{BridgeBase, ChenProduct} import java.io.PrintStream -object ReplBridge extends BridgeBase { +object ReplBridge extends 
BridgeBase: - override val jProduct = ChenProduct + override val jProduct = ChenProduct - def main(args: Array[String]): Unit = { - run(parseConfig(args)) - } + def main(args: Array[String]): Unit = + run(parseConfig(args)) - /** Code that is executed when starting the shell - */ - override def predefLines = - Predefined.forInteractiveShell + /** Code that is executed when starting the shell + */ + override def predefLines = + Predefined.forInteractiveShell - override def greeting = ChenConsole.banner() + override def greeting = ChenConsole.banner() - override def promptStr: String = "chennai" + override def promptStr: String = "chennai" - override def onExitCode: String = "workspace.projects.foreach(_.close)" - -} + override def onExitCode: String = "workspace.projects.foreach(_.close)" diff --git a/pyproject.toml b/pyproject.toml index f8131e34..ce8d170d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "appthreat-chen" -version = "0.6.3" +version = "1.0.0" description = "Code Hierarchy Exploration Net (chen)" authors = ["Team AppThreat "] license = "Apache-2.0" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala index 9a142f97..9e5db2a9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala @@ -3,46 +3,37 @@ package io.shiftleft.semanticcpg import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.Properties import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate -object Overlays { +object Overlays: - def appendOverlayName(cpg: Cpg, overlayName: String): Unit = { - new CpgPass(cpg) { - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = { - cpg.metaData.headOption match { - case 
Some(metaData) => - val newValue = metaData.overlays :+ overlayName - diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) - case None => - System.err.println("Missing metaData block") - } - } - }.createAndApply() - } + def appendOverlayName(cpg: Cpg, overlayName: String): Unit = + new CpgPass(cpg): + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.metaData.headOption match + case Some(metaData) => + val newValue = metaData.overlays :+ overlayName + diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) + case None => + System.err.println("Missing metaData block") + .createAndApply() - def removeLastOverlayName(cpg: Cpg): Unit = { - new CpgPass(cpg) { - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = { - cpg.metaData.headOption match { - case Some(metaData) => - val newValue = metaData.overlays.dropRight(1) - diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) - case None => - System.err.println("Missing metaData block") - } - } - }.createAndApply() - } + def removeLastOverlayName(cpg: Cpg): Unit = + new CpgPass(cpg): + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.metaData.headOption match + case Some(metaData) => + val newValue = metaData.overlays.dropRight(1) + diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) + case None => + System.err.println("Missing metaData block") + .createAndApply() - def appliedOverlays(cpg: Cpg): Seq[String] = { - cpg.metaData.headOption match { - case Some(metaData) => Option(metaData.overlays).getOrElse(Nil) - case None => - System.err.println("Missing metaData block") - List() - } - } - -} + def appliedOverlays(cpg: Cpg): Seq[String] = + cpg.metaData.headOption match + case Some(metaData) => Option(metaData.overlays).getOrElse(Nil) + case None => + System.err.println("Missing metaData block") + List() +end Overlays diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala index c32e36ce..23df465b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala @@ -1,40 +1,31 @@ package io.shiftleft.semanticcpg.accesspath -sealed abstract class AccessElement(name: String) extends Comparable[AccessElement] { - override def toString: String = name - def kind: Int - override def hashCode(): Int = kind + name.hashCode - - override def compareTo(other: AccessElement): Int = { - this.kind.compareTo(other.kind) match { - case 0 => this.name.compareTo(other.toString) - case different => different - } - } -} - -case class ConstantAccess(constant: String) extends AccessElement(constant) { - override def kind: Int = 0x01010101 -} - -case object VariableAccess extends AccessElement("?") { - override def kind: Int = 0x02020202 -} - -case object VariablePointerShift extends AccessElement("") { - override def kind: Int = 0x03030303 -} +sealed abstract class AccessElement(name: String) extends Comparable[AccessElement]: + override def toString: String = name + def kind: Int + override def hashCode(): Int = kind + name.hashCode + + override def compareTo(other: AccessElement): Int = + this.kind.compareTo(other.kind) match + case 0 => this.name.compareTo(other.toString) + case different => different + +case class ConstantAccess(constant: String) extends AccessElement(constant): + override def kind: Int = 0x01010101 + +case object VariableAccess extends AccessElement("?"): + override def kind: Int = 0x02020202 + +case object VariablePointerShift extends AccessElement(""): + override def kind: Int = 0x03030303 // this will eventually get an optional extent (how many bytes wide is the memory load/store) -case object IndirectionAccess extends AccessElement("*") { - 
override def kind: Int = 0x04040404 -} +case object IndirectionAccess extends AccessElement("*"): + override def kind: Int = 0x04040404 -case object AddressOf extends AccessElement("&") { - override def kind: Int = 0x05050505 -} +case object AddressOf extends AccessElement("&"): + override def kind: Int = 0x05050505 // this will eventually obtain an optional byteOffset -case class PointerShift(logicalOffset: Int) extends AccessElement(s"<$logicalOffset>") { - override def kind: Int = 0x06060606 -} +case class PointerShift(logicalOffset: Int) extends AccessElement(s"<$logicalOffset>"): + override def kind: Int = 0x06060606 diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala index bb899b67..8b46f514 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala @@ -1,437 +1,452 @@ package io.shiftleft.semanticcpg.accesspath -object AccessPath { - - private val empty = new AccessPath(Elements(), List[Elements]()) - - def apply(): AccessPath = empty - - def apply(elements: Elements, exclusions: Seq[Elements]): AccessPath = { - if (elements.isEmpty && exclusions.isEmpty) { - empty - } else { - new AccessPath(elements, exclusions.toList) - } - } - - def isExtensionExcluded(exclusions: Seq[Elements], extension: Elements): Boolean = - exclusions.exists(e => extension.elements.startsWith(e.elements)) - - /** This class contains operations on `Elements` that are only used in the `AccessPath` class and are not part of the - * public API of `Elements` - */ - private implicit class ElementsDecorations(el: Elements) { - - def noOvertaint(start: Int = 0, untilExclusive: Int = el.elements.length): Boolean = { - var idx = start - while (idx < untilExclusive) { - el.elements(idx) match { - case VariablePointerShift | VariableAccess => return false 
- case _ => - } - idx += 1 - } - true - } - - /** In all sane situations, invertibleTailLength is 0 or 1: - * - we don't expect &, because you cannot take the address of pointer+i (can only take address of rvalue) - * - we don't expect & : The & should have collapsed against a preceding *. An example where this occurs is - * (&(ptr->field))[1], which becomes ptr: * field & <2> * This reads the _next_ field: It doesn't alias with - * ptr->field at all, but reads the next bytes in the struct, after field. - * - * Such code is very un-idiomatic. +object AccessPath: + + private val empty = new AccessPath(Elements(), List[Elements]()) + + def apply(): AccessPath = empty + + def apply(elements: Elements, exclusions: Seq[Elements]): AccessPath = + if elements.isEmpty && exclusions.isEmpty then + empty + else + new AccessPath(elements, exclusions.toList) + + def isExtensionExcluded(exclusions: Seq[Elements], extension: Elements): Boolean = + exclusions.exists(e => extension.elements.startsWith(e.elements)) + + /** This class contains operations on `Elements` that are only used in the `AccessPath` class + * and are not part of the public API of `Elements` */ - def invertibleTailLength: Int = { - var i = 0 - val nElements = el.elements.length - 1 - while (nElements - i > -1) { - el.elements(nElements - i) match { - case AddressOf | VariablePointerShift | _: PointerShift => i += 1 - case _ => return i + private implicit class ElementsDecorations(el: Elements): + + def noOvertaint(start: Int = 0, untilExclusive: Int = el.elements.length): Boolean = + var idx = start + while idx < untilExclusive do + el.elements(idx) match + case VariablePointerShift | VariableAccess => return false + case _ => + idx += 1 + true + + /** In all sane situations, invertibleTailLength is 0 or 1: + * - we don't expect &, because you cannot take the address of pointer+i (can only + * take address of rvalue) + * - we don't expect & : The & should have collapsed against a preceding *. 
An example + * where this occurs is (&(ptr->field))[1], which becomes ptr: * field & <2> * This + * reads the _next_ field: It doesn't alias with ptr->field at all, but reads the next + * bytes in the struct, after field. + * + * Such code is very un-idiomatic. + */ + def invertibleTailLength: Int = + var i = 0 + val nElements = el.elements.length - 1 + while nElements - i > -1 do + el.elements(nElements - i) match + case AddressOf | VariablePointerShift | _: PointerShift => i += 1 + case _ => return i + i + end ElementsDecorations +end AccessPath + +case class AccessPath(elements: Elements, exclusions: Seq[Elements]): + + import AccessPath.* + + def isEmpty: Boolean = this == AccessPath.empty + + private var cachedHash: Int = 0 + + override def hashCode(): Int = + if cachedHash == 0 then + val computedHash = elements.hashCode() + exclusions.hashCode() ^ 0x404f92ab + cachedHash = if computedHash == 0 then 1 else computedHash + cachedHash + else cachedHash + + // for handling of invertible elements, cf AccessPathAlgebra.md + // FIXME: may need to process invertible tail of `other` better + + def ++(other: Elements): Option[AccessPath] = + if isExtensionExcluded(other) then None + else Some(AccessPath(this.elements ++ other, this.truncateExclusions(other).exclusions)) + + // FIXME: may need to process invertible tail of `other` better + def ++(other: AccessPath): Option[AccessPath] = + (this ++ other.elements).map { appended => + other.exclusions.foldLeft(appended) { case (ap, ex) => ap.addExclusion(ex) } } - } - i - } - } - -} - -case class AccessPath(elements: Elements, exclusions: Seq[Elements]) { - - import AccessPath._ - - def isEmpty: Boolean = this == AccessPath.empty - - private var cachedHash: Int = 0 - - override def hashCode(): Int = { - if (cachedHash == 0) { - val computedHash = elements.hashCode() + exclusions.hashCode() ^ 0x404f92ab - cachedHash = if (computedHash == 0) 1 else computedHash - cachedHash - } else cachedHash - } - - // for handling of 
invertible elements, cf AccessPathAlgebra.md - // FIXME: may need to process invertible tail of `other` better - - def ++(other: Elements): Option[AccessPath] = { - if (isExtensionExcluded(other)) None - else Some(AccessPath(this.elements ++ other, this.truncateExclusions(other).exclusions)) - } - - // FIXME: may need to process invertible tail of `other` better - def ++(other: AccessPath): Option[AccessPath] = { - (this ++ other.elements).map { appended => - other.exclusions.foldLeft(appended) { case (ap, ex) => ap.addExclusion(ex) } - } - } - - def matchFull(other: AccessPath): FullMatchResult = { - val res = this.matchFull(other.elements) - if ( - res.extensionDiff.isEmpty && res.stepIntoPath.isDefined && other.isExtensionExcluded( - res.stepIntoPath.get.elements - ) - ) { - FullMatchResult(Some(this), None, Elements.empty) - } else res - } - - def matchFull(other: Elements): FullMatchResult = { - val (matchRes, matchDiff) = this.matchAndDiff(other) - matchRes match { - case MatchResult.NO_MATCH => - FullMatchResult(Some(this), None, Elements.empty) - case MatchResult.PREFIX_MATCH | MatchResult.EXACT_MATCH => - FullMatchResult(None, Some(AccessPath(matchDiff, this.exclusions)), Elements.empty) - case MatchResult.VARIABLE_PREFIX_MATCH | MatchResult.VARIABLE_EXACT_MATCH => - FullMatchResult(Some(this), Some(AccessPath(matchDiff, this.exclusions)), Elements.empty) - case MatchResult.EXTENDED_MATCH => - FullMatchResult( - Some(this.addExclusion(matchDiff)), - Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), - matchDiff - ) - case MatchResult.VARIABLE_EXTENDED_MATCH => - FullMatchResult( - Some(this), - Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), - matchDiff - ) - } - } - - def matchAndDiff(other: Elements): (MatchResult.MatchResult, Elements) = { - val thisTail = elements.invertibleTailLength - val otherTail = other.invertibleTailLength - val thisHead = elements.elements.length - thisTail - val otherHead = 
other.elements.length - otherTail - - val cmpUntil = scala.math.min(thisHead, otherHead) - var idx = 0 - var overTainted = false - while (idx < cmpUntil) { - (elements.elements(idx), other.elements(idx)) match { - case (VariableAccess, VariableAccess) | (_: ConstantAccess, VariableAccess) | - (VariableAccess, _: ConstantAccess) | (VariablePointerShift, VariablePointerShift) | - (_: PointerShift, VariablePointerShift) | (VariablePointerShift, _: PointerShift) => - overTainted = true - case (thisElem, otherElem) => - if (thisElem != otherElem) - return (MatchResult.NO_MATCH, Elements.empty) - } - idx += 1 - } - var done = false - - /** We now try to greedily match more elements. We know that one of the two paths will only contain invertible - * elements. The issue is the following: prefix <1> & x prefix & With greedy matching, we end up with a diff: - * x. If we just did the invert-append algorithm, we would end up with a less precise diff: * <1> & x == * - * & x. - */ - val minlen = scala.math.min(elements.elements.length, other.elements.length) - while (!done && idx < minlen) { - (elements.elements(idx), other.elements(idx)) match { - case (_: PointerShift, VariablePointerShift) | (VariablePointerShift, _: PointerShift) | - (VariablePointerShift, VariablePointerShift) => - overTainted = true - idx += 1 - case (thisElem, otherElem) => - if (thisElem == otherElem) { + + def matchFull(other: AccessPath): FullMatchResult = + val res = this.matchFull(other.elements) + if + res.extensionDiff.isEmpty && res.stepIntoPath.isDefined && other.isExtensionExcluded( + res.stepIntoPath.get.elements + ) + then + FullMatchResult(Some(this), None, Elements.empty) + else res + + def matchFull(other: Elements): FullMatchResult = + val (matchRes, matchDiff) = this.matchAndDiff(other) + matchRes match + case MatchResult.NO_MATCH => + FullMatchResult(Some(this), None, Elements.empty) + case MatchResult.PREFIX_MATCH | MatchResult.EXACT_MATCH => + FullMatchResult(None, 
Some(AccessPath(matchDiff, this.exclusions)), Elements.empty) + case MatchResult.VARIABLE_PREFIX_MATCH | MatchResult.VARIABLE_EXACT_MATCH => + FullMatchResult( + Some(this), + Some(AccessPath(matchDiff, this.exclusions)), + Elements.empty + ) + case MatchResult.EXTENDED_MATCH => + FullMatchResult( + Some(this.addExclusion(matchDiff)), + Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), + matchDiff + ) + case MatchResult.VARIABLE_EXTENDED_MATCH => + FullMatchResult( + Some(this), + Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), + matchDiff + ) + end match + end matchFull + + def matchAndDiff(other: Elements): (MatchResult.MatchResult, Elements) = + val thisTail = elements.invertibleTailLength + val otherTail = other.invertibleTailLength + val thisHead = elements.elements.length - thisTail + val otherHead = other.elements.length - otherTail + + val cmpUntil = scala.math.min(thisHead, otherHead) + var idx = 0 + var overTainted = false + while idx < cmpUntil do + (elements.elements(idx), other.elements(idx)) match + case (VariableAccess, VariableAccess) | (_: ConstantAccess, VariableAccess) | + (VariableAccess, _: ConstantAccess) | ( + VariablePointerShift, + VariablePointerShift + ) | + (_: PointerShift, VariablePointerShift) | ( + VariablePointerShift, + _: PointerShift + ) => + overTainted = true + case (thisElem, otherElem) => + if thisElem != otherElem then + return (MatchResult.NO_MATCH, Elements.empty) idx += 1 - } else { - done = true - } - } - } - if (thisHead >= otherHead) { - // prefix or exact - val diff = Elements.inverted(other.elements.drop(idx)) ++ Elements.unnormalized(elements.elements.drop(idx)) - - /** we don't need to overtaint if thisTail has variable PointerShift: They can still get excluded e.g. suppose we - * track "a" "b" and encounter "a" <4>. 
- */ - overTainted |= !other.noOvertaint(otherHead) - if (!overTainted & thisHead == otherHead) (MatchResult.EXACT_MATCH, diff) - else if (overTainted && thisHead == otherHead) (MatchResult.VARIABLE_EXACT_MATCH, diff) - else if (!overTainted && thisHead != otherHead) (MatchResult.PREFIX_MATCH, diff) - else if (overTainted && thisHead != otherHead) (MatchResult.VARIABLE_PREFIX_MATCH, diff) - else throw new RuntimeException() - } else { - // extended - val diff = Elements.inverted(elements.elements.drop(idx)) ++ Elements.unnormalized(other.elements.drop(idx)) - - /** we need to overtaint if any either otherTail or thisTail has variable PointerShift: e.g. suppose we track "a" - * <4> and encounter "a" "b" "c", or suppose that we track "a" and encounter "a" "b" - */ - overTainted |= !elements.noOvertaint(thisHead) | !other.noOvertaint(otherHead) - - if (overTainted) (MatchResult.VARIABLE_EXTENDED_MATCH, diff) - else if (isExtensionExcluded(diff)) (MatchResult.NO_MATCH, Elements.empty) - else (MatchResult.EXTENDED_MATCH, diff) - } - } - - private def truncateExclusions(compareExclusion: Elements): AccessPath = { - if (exclusions.isEmpty) return this - val size = compareExclusion.elements.length - val newExclusions = - exclusions - .filter(_.elements.startsWith(compareExclusion.elements)) - .map(exclusion => Elements.normalized(exclusion.elements.drop(size))) - .sorted - AccessPath(elements, newExclusions) - } - - private def addExclusion(newExclusion: Elements): AccessPath = { - if (newExclusion.noOvertaint()) { - val ex = - Elements.unnormalized(newExclusion.elements.dropRight(newExclusion.invertibleTailLength)) - if (isExtensionExcluded(ex)) return this - val unshadowed = exclusions.filter(!_.elements.startsWith(ex.elements)) - AccessPath(elements, (unshadowed :+ ex).sorted) - } else this - } - - def isExtensionExcluded(extension: Elements): Boolean = { - AccessPath.isExtensionExcluded(this.exclusions, extension) - } - -} + var done = false + + /** We now try to 
greedily match more elements. We know that one of the two paths will only + * contain invertible elements. The issue is the following: prefix <1> & x prefix & + * With greedy matching, we end up with a diff: x. If we just did the invert-append + * algorithm, we would end up with a less precise diff: * <1> & x == * & x. + */ + val minlen = scala.math.min(elements.elements.length, other.elements.length) + while !done && idx < minlen do + (elements.elements(idx), other.elements(idx)) match + case (_: PointerShift, VariablePointerShift) | ( + VariablePointerShift, + _: PointerShift + ) | + (VariablePointerShift, VariablePointerShift) => + overTainted = true + idx += 1 + case (thisElem, otherElem) => + if thisElem == otherElem then + idx += 1 + else + done = true + if thisHead >= otherHead then + // prefix or exact + val diff = Elements.inverted(other.elements.drop(idx)) ++ Elements.unnormalized( + elements.elements.drop(idx) + ) + + /** we don't need to overtaint if thisTail has variable PointerShift: They can still get + * excluded e.g. suppose we track "a" "b" and encounter "a" <4>. + */ + overTainted |= !other.noOvertaint(otherHead) + if !overTainted & thisHead == otherHead then (MatchResult.EXACT_MATCH, diff) + else if overTainted && thisHead == otherHead then + (MatchResult.VARIABLE_EXACT_MATCH, diff) + else if !overTainted && thisHead != otherHead then (MatchResult.PREFIX_MATCH, diff) + else if overTainted && thisHead != otherHead then + (MatchResult.VARIABLE_PREFIX_MATCH, diff) + else throw new RuntimeException() + else + // extended + val diff = Elements.inverted(elements.elements.drop(idx)) ++ Elements.unnormalized( + other.elements.drop(idx) + ) + + /** we need to overtaint if any either otherTail or thisTail has variable PointerShift: + * e.g. 
suppose we track "a" <4> and encounter "a" "b" "c", or suppose that we + * track "a" and encounter "a" "b" + */ + overTainted |= !elements.noOvertaint(thisHead) | !other.noOvertaint(otherHead) + + if overTainted then (MatchResult.VARIABLE_EXTENDED_MATCH, diff) + else if isExtensionExcluded(diff) then (MatchResult.NO_MATCH, Elements.empty) + else (MatchResult.EXTENDED_MATCH, diff) + end if + end matchAndDiff + + private def truncateExclusions(compareExclusion: Elements): AccessPath = + if exclusions.isEmpty then return this + val size = compareExclusion.elements.length + val newExclusions = + exclusions + .filter(_.elements.startsWith(compareExclusion.elements)) + .map(exclusion => Elements.normalized(exclusion.elements.drop(size))) + .sorted + AccessPath(elements, newExclusions) + + private def addExclusion(newExclusion: Elements): AccessPath = + if newExclusion.noOvertaint() then + val ex = + Elements.unnormalized( + newExclusion.elements.dropRight(newExclusion.invertibleTailLength) + ) + if isExtensionExcluded(ex) then return this + val unshadowed = exclusions.filter(!_.elements.startsWith(ex.elements)) + AccessPath(elements, (unshadowed :+ ex).sorted) + else this + + def isExtensionExcluded(extension: Elements): Boolean = + AccessPath.isExtensionExcluded(this.exclusions, extension) +end AccessPath sealed trait MatchResult -object MatchResult extends Enumeration { - type MatchResult = Value - val NO_MATCH, EXACT_MATCH, VARIABLE_EXACT_MATCH, PREFIX_MATCH, VARIABLE_PREFIX_MATCH, EXTENDED_MATCH, - VARIABLE_EXTENDED_MATCH = Value -} +object MatchResult extends Enumeration: + type MatchResult = Value + val NO_MATCH, EXACT_MATCH, VARIABLE_EXACT_MATCH, PREFIX_MATCH, VARIABLE_PREFIX_MATCH, + EXTENDED_MATCH, VARIABLE_EXTENDED_MATCH = Value /** Result of `matchFull` comparison * * @param stepOverPath - * the unaffected part of the access path. Some(this) for no match, None for perfect match; may have additional - * exclusions to this. 
+ * the unaffected part of the access path. Some(this) for no match, None for perfect match; may + * have additional exclusions to this. * @param stepIntoPath - * The affected part of the access path, mapped to be relative to this stepIntoPath.isDefined if and only if there is - * a match in paths, i.e. if the call can affect the tracked variable at all. Outside of overtainting, if - * stepIntoPath.isDefined && stepIntoPath.elements.nonEmpty then: path.elements == other.elements ++ - * path.matchFull(other).stepIntoPath.get.elements extensionDiff.isEmpty + * The affected part of the access path, mapped to be relative to this stepIntoPath.isDefined if + * and only if there is a match in paths, i.e. if the call can affect the tracked variable at + * all. Outside of overtainting, if stepIntoPath.isDefined && stepIntoPath.elements.nonEmpty + * then: path.elements == other.elements ++ path.matchFull(other).stepIntoPath.get.elements + * extensionDiff.isEmpty * * @param extensionDiff - * extensionDiff is non empty if and only if a proper subset is affected. Outside of over tainting, if extensionDiff - * is non empty then: path.elements ++ path.matchFull(other).extensionDiff == other.elements + * extensionDiff is non empty if and only if a proper subset is affected. 
Outside of over + * tainting, if extensionDiff is non empty then: path.elements ++ + * path.matchFull(other).extensionDiff == other.elements * path.matchFull(other).stepIntoPath.get.elements.isEmpty * * Invariants: * - Exclusions have no invertible tail - * - Only paths without overTaint can have exclusions TODO: Figure out sensible assertions to defend these invariants + * - Only paths without overTaint can have exclusions TODO: Figure out sensible assertions to + * defend these invariants */ case class FullMatchResult( stepOverPath: Option[AccessPath], stepIntoPath: Option[AccessPath], extensionDiff: Elements -) { - def hasMatch: Boolean = stepIntoPath.nonEmpty -} +): + def hasMatch: Boolean = stepIntoPath.nonEmpty +end FullMatchResult -/** For handling of invertible elements, cf AccessPathAlgebra.md. The general rule is that elements concatenate - * normally, except for: +/** For handling of invertible elements, cf AccessPathAlgebra.md. The general rule is that elements + * concatenate normally, except for: * - * Elements(&) ++ Elements(*) == Elements() Elements(*) ++ Elements(&) == Elements() Elements(<0>) == Elements() - * Elements() ++ Elements() == Elements() Elements() ++ Elements() == Elements() Elements() ++ - * Elements() == Elements() Elements() ++ Elements() == Elements() + * Elements(&) ++ Elements(*) == Elements() Elements(*) ++ Elements(&) == Elements() Elements(<0>) + * \== Elements() Elements() ++ Elements() == Elements() Elements() ++ Elements() + * \== Elements() Elements() ++ Elements() == Elements() Elements() ++ Elements() + * \== Elements() * - * From this, once can see that , * and & are invertible, is idempotent and <0> is a convoluted way of - * describing and empty sequence of tokens. Nevertheless, we mostly consider * as noninvertible (because it is, in real - * computers!) 
and as invertible (because it is in real computers, we just don't know the offset) + * From this, once can see that , * and & are invertible, is idempotent and <0> is a + * convoluted way of describing and empty sequence of tokens. Nevertheless, we mostly consider * as + * noninvertible (because it is, in real computers!) and as invertible (because it is in real + * computers, we just don't know the offset) * - * Elements get a private constructor. Users should use the no-argument Elements.apply() factory method to get an empty - * path, and the specific concat operators for building up pathes. The Elements.normalized(iter) factory method serves - * to build this in bulk. + * Elements get a private constructor. Users should use the no-argument Elements.apply() factory + * method to get an empty path, and the specific concat operators for building up pathes. The + * Elements.normalized(iter) factory method serves to build this in bulk. * * The unnormalized factory method is more of an escape hatch. * - * The elements field should never be mutated outside of this file: We compare and hash Elements by their contents, not - * by identity, and this breaks in case of mutation. + * The elements field should never be mutated outside of this file: We compare and hash Elements by + * their contents, not by identity, and this breaks in case of mutation. * - * The reason for using a mutable Array instead of an immutable Vector is that this is the lightest weight - * datastructure for the job. + * The reason for using a mutable Array instead of an immutable Vector is that this is the lightest + * weight datastructure for the job. * - * The reason for making this non-private is simply that it is truly annoying to write wrappers for all possible uses. + * The reason for making this non-private is simply that it is truly annoying to write wrappers for + * all possible uses. 
*/ // TODO: Figure out sensible assertions to defend invariant that the empty instance is // the only empty elements instance, i.e., assert that elems.isEmpty implies elems eq // Elements.empty -object Elements { - val empty = new Elements() - - def apply(): Elements = empty - - def normalized(elems: IterableOnce[AccessElement]): Elements = - destructiveNormalized(elems.iterator.toArray) - - def normalized(elems: AccessElement*): Elements = - destructiveNormalized(elems.toArray) - - def unnormalized(elems: IterableOnce[AccessElement]): Elements = - newIfNonEmpty(elems.iterator.toArray) - - def newIfNonEmpty(elems: Array[AccessElement]): Elements = { - if (!elems.isEmpty) new Elements(elems) - else empty - } - - def inverted(elems: Iterable[AccessElement]): Elements = { - val invertedElems: Array[AccessElement] = elems.toArray.reverse.map { - case AddressOf => IndirectionAccess - case IndirectionAccess => AddressOf - case PointerShift(idx) => PointerShift(-idx) - case VariablePointerShift => VariablePointerShift - case _ => throw new RuntimeException(s"Cannot invert ${Elements.unnormalized(elems)}") - } - newIfNonEmpty(invertedElems) - } - - def noOvertaint(elems: Iterable[AccessElement]): Boolean = - elems.forall(_ != VariableAccess) - - private def destructiveNormalized(elems: Array[AccessElement]): Elements = { - var idxRight = 0 - var idxLeft = -1 - while (idxRight < elems.length) { - val nextE = elems(idxRight) - nextE match { - case shift: PointerShift if shift.logicalOffset == 0 => - // nothing to do - case _ => - if (idxLeft == -1) { - idxLeft = 0 - elems(0) = nextE - } else { - val lastE = elems(idxLeft) - (lastE, nextE) match { - case (last: PointerShift, next: PointerShift) => - val newShift = last.logicalOffset + next.logicalOffset - if (newShift != 0) elems(idxLeft) = PointerShift(newShift) - else idxLeft -= 1 - case (VariablePointerShift, _: PointerShift) | (VariablePointerShift, VariablePointerShift) => - case (_: PointerShift, VariablePointerShift) 
=> - elems(idxLeft) = VariablePointerShift - case (AddressOf, IndirectionAccess) => - idxLeft -= 1 - case (IndirectionAccess, AddressOf) => - idxLeft -= 1 // WRONG but useful, cf comment for `Elements.:+` - case _ => - idxLeft += 1 - elems(idxLeft) = nextE - } - } - } - idxRight += 1 - } - newIfNonEmpty(elems.take(idxLeft + 1)) - } - -} - -final class Elements(val elements: Array[AccessElement] = Array[AccessElement]()) extends Comparable[Elements] { - - def isEmpty: Boolean = elements.isEmpty - - override def toString: String = s"Elements(${elements.mkString(",")})" - - override def equals(other: Any): Boolean = { - other match { - case otherElements: Elements => - Array.equals(elements.asInstanceOf[Array[AnyRef]], otherElements.elements.asInstanceOf[Array[AnyRef]]) - case _ => false - } - } - - override def hashCode(): Int = java.util.Arrays.hashCode(elements.asInstanceOf[Array[AnyRef]]) - - override def compareTo(other: Elements): Int = { - val until = scala.math.min(elements.length, other.elements.length) - var idx = 0 - while (idx < until) { - elements(idx).compareTo(other.elements(idx)) match { - case 0 => - case difference => return difference - } - idx += 1 - } - if (idx < elements.length) +1 - else if (idx < other.elements.length) -1 - else 0 - } - - def ++(otherElements: Elements): Elements = { - - if (elements.isEmpty) return otherElements - if (otherElements.isEmpty) return this - - var buf = None: Option[AccessElement] - val otherSize = otherElements.elements.length - var idx = 0 - val until = scala.math.min(elements.length, otherSize) - var done = false - - while (idx < until & !done) { - (elements(elements.length - idx - 1), otherElements.elements(idx)) match { - case (AddressOf, IndirectionAccess) => - idx += 1 - case (IndirectionAccess, AddressOf) => - // WRONG but useful, cf comment for `Elements.:+` - idx += 1 - case (VariablePointerShift, VariablePointerShift) | (_: PointerShift, VariablePointerShift) | - (VariablePointerShift, _: PointerShift) 
=> - done = true - buf = Some(VariablePointerShift) - idx += 1 - case (last: PointerShift, first: PointerShift) => - val newOffset = last.logicalOffset + first.logicalOffset - if (newOffset != 0) { - done = true - buf = Some(PointerShift(newOffset)) - } - idx += 1 - case _ => - done = true - } - } - val sz = elements.length + otherSize - 2 * idx + (if (buf.isDefined) 1 else 0) - val res = Array.fill(sz) { null }: Array[AccessElement] - elements.copyToArray(res, 0, elements.length - idx) - if (buf.isDefined) { - res(elements.length - idx) = buf.get - java.lang.System.arraycopy(otherElements.elements, idx, res, elements.length - idx + 1, otherSize - idx) - } else { - java.lang.System.arraycopy(otherElements.elements, idx, res, elements.length - idx, otherSize - idx) - } - Elements.newIfNonEmpty(res) - } - -} +object Elements: + val empty = new Elements() + + def apply(): Elements = empty + + def normalized(elems: IterableOnce[AccessElement]): Elements = + destructiveNormalized(elems.iterator.toArray) + + def normalized(elems: AccessElement*): Elements = + destructiveNormalized(elems.toArray) + + def unnormalized(elems: IterableOnce[AccessElement]): Elements = + newIfNonEmpty(elems.iterator.toArray) + + def newIfNonEmpty(elems: Array[AccessElement]): Elements = + if !elems.isEmpty then new Elements(elems) + else empty + + def inverted(elems: Iterable[AccessElement]): Elements = + val invertedElems: Array[AccessElement] = elems.toArray.reverse.map { + case AddressOf => IndirectionAccess + case IndirectionAccess => AddressOf + case PointerShift(idx) => PointerShift(-idx) + case VariablePointerShift => VariablePointerShift + case _ => throw new RuntimeException(s"Cannot invert ${Elements.unnormalized(elems)}") + } + newIfNonEmpty(invertedElems) + + def noOvertaint(elems: Iterable[AccessElement]): Boolean = + elems.forall(_ != VariableAccess) + + private def destructiveNormalized(elems: Array[AccessElement]): Elements = + var idxRight = 0 + var idxLeft = -1 + while 
idxRight < elems.length do + val nextE = elems(idxRight) + nextE match + case shift: PointerShift if shift.logicalOffset == 0 => + // nothing to do + case _ => + if idxLeft == -1 then + idxLeft = 0 + elems(0) = nextE + else + val lastE = elems(idxLeft) + (lastE, nextE) match + case (last: PointerShift, next: PointerShift) => + val newShift = last.logicalOffset + next.logicalOffset + if newShift != 0 then elems(idxLeft) = PointerShift(newShift) + else idxLeft -= 1 + case (VariablePointerShift, _: PointerShift) | ( + VariablePointerShift, + VariablePointerShift + ) => + case (_: PointerShift, VariablePointerShift) => + elems(idxLeft) = VariablePointerShift + case (AddressOf, IndirectionAccess) => + idxLeft -= 1 + case (IndirectionAccess, AddressOf) => + idxLeft -= 1 // WRONG but useful, cf comment for `Elements.:+` + case _ => + idxLeft += 1 + elems(idxLeft) = nextE + end match + idxRight += 1 + end while + newIfNonEmpty(elems.take(idxLeft + 1)) + end destructiveNormalized +end Elements + +final class Elements(val elements: Array[AccessElement] = Array[AccessElement]()) + extends Comparable[Elements]: + + def isEmpty: Boolean = elements.isEmpty + + override def toString: String = s"Elements(${elements.mkString(",")})" + + override def equals(other: Any): Boolean = + other match + case otherElements: Elements => + Array.equals( + elements.asInstanceOf[Array[AnyRef]], + otherElements.elements.asInstanceOf[Array[AnyRef]] + ) + case _ => false + + override def hashCode(): Int = java.util.Arrays.hashCode(elements.asInstanceOf[Array[AnyRef]]) + + override def compareTo(other: Elements): Int = + val until = scala.math.min(elements.length, other.elements.length) + var idx = 0 + while idx < until do + elements(idx).compareTo(other.elements(idx)) match + case 0 => + case difference => return difference + idx += 1 + if idx < elements.length then +1 + else if idx < other.elements.length then -1 + else 0 + + def ++(otherElements: Elements): Elements = + + if elements.isEmpty then 
return otherElements + if otherElements.isEmpty then return this + + var buf = None: Option[AccessElement] + val otherSize = otherElements.elements.length + var idx = 0 + val until = scala.math.min(elements.length, otherSize) + var done = false + + while idx < until & !done do + (elements(elements.length - idx - 1), otherElements.elements(idx)) match + case (AddressOf, IndirectionAccess) => + idx += 1 + case (IndirectionAccess, AddressOf) => + // WRONG but useful, cf comment for `Elements.:+` + idx += 1 + case (VariablePointerShift, VariablePointerShift) | ( + _: PointerShift, + VariablePointerShift + ) | + (VariablePointerShift, _: PointerShift) => + done = true + buf = Some(VariablePointerShift) + idx += 1 + case (last: PointerShift, first: PointerShift) => + val newOffset = last.logicalOffset + first.logicalOffset + if newOffset != 0 then + done = true + buf = Some(PointerShift(newOffset)) + idx += 1 + case _ => + done = true + end while + val sz = elements.length + otherSize - 2 * idx + (if buf.isDefined then 1 else 0) + val res = Array.fill(sz) { null }: Array[AccessElement] + elements.copyToArray(res, 0, elements.length - idx) + if buf.isDefined then + res(elements.length - idx) = buf.get + java.lang.System.arraycopy( + otherElements.elements, + idx, + res, + elements.length - idx + 1, + otherSize - idx + ) + else + java.lang.System.arraycopy( + otherElements.elements, + idx, + res, + elements.length - idx, + otherSize - idx + ) + Elements.newIfNonEmpty(res) + end ++ +end Elements diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala index f1af0f8d..6a66cd22 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala @@ -1,48 +1,34 @@ package io.shiftleft.semanticcpg.accesspath -import 
io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* trait TrackedBase case class TrackedNamedVariable(name: String) extends TrackedBase -case class TrackedReturnValue(call: CallRepr) extends TrackedBase { - override def toString: String = { - s"TrackedReturnValue(${call.code})" - } -} -case class TrackedLiteral(literal: Literal) extends TrackedBase { - override def toString: String = { - s"TrackedLiteral(${literal.code})" - } -} +case class TrackedReturnValue(call: CallRepr) extends TrackedBase: + override def toString: String = + s"TrackedReturnValue(${call.code})" +case class TrackedLiteral(literal: Literal) extends TrackedBase: + override def toString: String = + s"TrackedLiteral(${literal.code})" -sealed trait TrackedMethodOrTypeRef extends TrackedBase { - def code: String +sealed trait TrackedMethodOrTypeRef extends TrackedBase: + def code: String - override def toString: String = { - s"TrackedMethodOrTypeRef($code)" - } -} + override def toString: String = + s"TrackedMethodOrTypeRef($code)" -case class TrackedMethod(method: MethodRef) extends TrackedMethodOrTypeRef { - override def code: String = method.code -} -case class TrackedTypeRef(typeRef: TypeRef) extends TrackedMethodOrTypeRef { - override def code: String = typeRef.code -} +case class TrackedMethod(method: MethodRef) extends TrackedMethodOrTypeRef: + override def code: String = method.code +case class TrackedTypeRef(typeRef: TypeRef) extends TrackedMethodOrTypeRef: + override def code: String = typeRef.code -case class TrackedAlias(argIndex: Int) extends TrackedBase { - override def toString: String = { - s"TrackedAlias($argIndex)" - } -} +case class TrackedAlias(argIndex: Int) extends TrackedBase: + override def toString: String = + s"TrackedAlias($argIndex)" -object TrackedUnknown extends TrackedBase { - override def toString: String = { - "TrackedUnknown" - } -} -object TrackedFormalReturn extends TrackedBase { - override def toString: String 
= { - "TrackedFormalReturn" - } -} +object TrackedUnknown extends TrackedBase: + override def toString: String = + "TrackedUnknown" +object TrackedFormalReturn extends TrackedBase: + override def toString: String = + "TrackedFormalReturn" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala index b557344a..2ec9ebf9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala @@ -2,123 +2,121 @@ package io.shiftleft.semanticcpg.codedumper import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Local, Method, NewLocation} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.IOUtils import org.slf4j.{Logger, LoggerFactory} import java.nio.file.Paths import scala.util.{Failure, Success, Try} -object CodeDumper { +object CodeDumper: - private val logger: Logger = LoggerFactory.getLogger(getClass) + private val logger: Logger = LoggerFactory.getLogger(getClass) - def arrow(locationFullName: Option[String] = None): CharSequence = s"/* <=== ${locationFullName.getOrElse("")} */ " + def arrow(locationFullName: Option[String] = None): CharSequence = + s"/* <=== ${locationFullName.getOrElse("")} */ " - private val supportedLanguages = - Set( - Languages.C, - Languages.NEWC, - Languages.GHIDRA, - Languages.JAVA, - Languages.JAVASRC, - Languages.JSSRC, - Languages.PYTHON, - Languages.PYTHONSRC - ) + private val supportedLanguages = + Set( + Languages.C, + Languages.NEWC, + Languages.GHIDRA, + Languages.JAVA, + Languages.JAVASRC, + Languages.JSSRC, + Languages.PYTHON, + Languages.PYTHONSRC + ) - private def toAbsolutePath(path: String, rootPath: String): String = { - val absolutePath = Paths.get(path) match { - case 
p if p.isAbsolute => p - case _ if rootPath.endsWith(path) => Paths.get(rootPath) - case p => Paths.get(rootPath, p.toString) - } - absolutePath.normalize().toString - } + private def toAbsolutePath(path: String, rootPath: String): String = + val absolutePath = Paths.get(path) match + case p if p.isAbsolute => p + case _ if rootPath.endsWith(path) => Paths.get(rootPath) + case p => Paths.get(rootPath, p.toString) + absolutePath.normalize().toString - /** Dump string representation of code at given `location`. - */ - def dump( - location: NewLocation, - language: Option[String], - rootPath: Option[String], - highlight: Boolean, - withArrow: Boolean = true - ): String = { - (location.node, language) match { - case (None, _) => - logger.warn("Empty `location.node` encountered") - "" - case (_, None) => - logger.debug("dump not supported; language not set in CPG") - "" - case (_, Some(lang)) if !supportedLanguages.contains(lang) => - logger.debug(s"dump not supported for language '$lang'") - "" - case (Some(node), Some(lang)) => - val method: Option[Method] = node match { - case n: Method => Some(n) - case n: Expression => Some(n.method) - case n: Local => n.method.headOption - case _ => None - } - method - .collect { - case m: Method if m.lineNumber.isDefined && m.lineNumberEnd.isDefined => - val rawCode = if (lang == Languages.GHIDRA || lang == Languages.JAVA) { - val lines = m.code.split("\n") - lines.zipWithIndex - .map { case (line, lineNo) => - if (lineNo == 0 && withArrow) { - s"$line ${arrow(Option(m.fullName))}" - } else - line - } - .mkString("\n") - } else { - val filename = rootPath.map(toAbsolutePath(location.filename, _)).getOrElse(location.filename) - code(filename, m.lineNumber.get, m.lineNumberEnd.get, location.lineNumber, Option(m.fullName)) - } - if (highlight) { - SourceHighlighter.highlight(Source(rawCode, lang)) - } else { - Some(rawCode) - } - } - .flatten - .getOrElse("") - } - } + /** Dump string representation of code at given `location`. 
+ */ + def dump( + location: NewLocation, + language: Option[String], + rootPath: Option[String], + highlight: Boolean, + withArrow: Boolean = true + ): String = + (location.node, language) match + case (None, _) => + logger.warn("Empty `location.node` encountered") + "" + case (_, None) => + logger.debug("dump not supported; language not set in CPG") + "" + case (_, Some(lang)) if !supportedLanguages.contains(lang) => + logger.debug(s"dump not supported for language '$lang'") + "" + case (Some(node), Some(lang)) => + val method: Option[Method] = node match + case n: Method => Some(n) + case n: Expression => Some(n.method) + case n: Local => n.method.headOption + case _ => None + method + .collect { + case m: Method if m.lineNumber.isDefined && m.lineNumberEnd.isDefined => + val rawCode = if lang == Languages.GHIDRA || lang == Languages.JAVA then + val lines = m.code.split("\n") + lines.zipWithIndex + .map { case (line, lineNo) => + if lineNo == 0 && withArrow then + s"$line ${arrow(Option(m.fullName))}" + else + line + } + .mkString("\n") + else + val filename = rootPath.map( + toAbsolutePath(location.filename, _) + ).getOrElse(location.filename) + code( + filename, + m.lineNumber.get, + m.lineNumberEnd.get, + location.lineNumber, + Option(m.fullName) + ) + if highlight then + SourceHighlighter.highlight(Source(rawCode, lang)) + else + Some(rawCode) + } + .flatten + .getOrElse("") - /** For a given `filename`, `startLine`, and `endLine`, return the corresponding code by reading it from the file. If - * `lineToHighlight` is defined, then a line containing an arrow (as a source code comment) is included right before - * that line. 
- */ - def code( - filename: String, - startLine: Integer, - endLine: Integer, - lineToHighlight: Option[Integer] = None, - locationFullName: Option[String] = None - ): String = { - Try(IOUtils.readLinesInFile(Paths.get(filename))) match { - case Failure(exception) => - logger.warn(s"error reading from: '$filename'", exception) - "" - case Success(lines) => - lines - .slice(startLine - 1, endLine) - .zipWithIndex - .map { case (line, lineNo) => - if (lineToHighlight.isDefined && lineNo == lineToHighlight.get - startLine) { - s"$line ${arrow(locationFullName)}" - } else { - line - } - } - .mkString("\n") - } - - } - -} + /** For a given `filename`, `startLine`, and `endLine`, return the corresponding code by reading + * it from the file. If `lineToHighlight` is defined, then a line containing an arrow (as a + * source code comment) is included right before that line. + */ + def code( + filename: String, + startLine: Integer, + endLine: Integer, + lineToHighlight: Option[Integer] = None, + locationFullName: Option[String] = None + ): String = + Try(IOUtils.readLinesInFile(Paths.get(filename))) match + case Failure(exception) => + logger.warn(s"error reading from: '$filename'", exception) + "" + case Success(lines) => + lines + .slice(startLine - 1, endLine) + .zipWithIndex + .map { case (line, lineNo) => + if lineToHighlight.isDefined && lineNo == lineToHighlight.get - startLine + then + s"$line ${arrow(locationFullName)}" + else + line + } + .mkString("\n") +end CodeDumper diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala index d009b692..15bdc869 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala @@ -6,35 +6,38 @@ import org.slf4j.{Logger, LoggerFactory} import scala.sys.process.Process 
-/** language must be one of io.shiftleft.codepropertygraph.generated.Languages TODO: generate enums instead of Strings - * for the languages +/** language must be one of io.shiftleft.codepropertygraph.generated.Languages TODO: generate enums + * instead of Strings for the languages */ case class Source(code: String, language: String) -object SourceHighlighter { - private val logger: Logger = LoggerFactory.getLogger(SourceHighlighter.getClass) +object SourceHighlighter: + private val logger: Logger = LoggerFactory.getLogger(SourceHighlighter.getClass) - def highlight(source: Source): Option[String] = { - val langFlag = source.language match { - case Languages.C | Languages.NEWC | Languages.GHIDRA => "-sC" - case Languages.JAVA | Languages.JAVASRC => "-sJava" - case Languages.JSSRC | Languages.JAVASCRIPT => "-sJavascript" - case Languages.PYTHON | Languages.PYTHONSRC => "-sPython" - case other => throw new RuntimeException(s"Attempting to call highlighter on unsupported language: $other") - } + def highlight(source: Source): Option[String] = + val langFlag = source.language match + case Languages.C | Languages.NEWC | Languages.GHIDRA => "-sC" + case Languages.JAVA | Languages.JAVASRC => "-sJava" + case Languages.JSSRC | Languages.JAVASCRIPT => "-sJavascript" + case Languages.PYTHON | Languages.PYTHONSRC => "-sPython" + case other => throw new RuntimeException( + s"Attempting to call highlighter on unsupported language: $other" + ) - val tmpSrcFile = File.newTemporaryFile("dump") - tmpSrcFile.writeText(source.code) - try { - val highlightedCode = Process(Seq("source-highlight-esc.sh", tmpSrcFile.path.toString, langFlag)).!! - Some(highlightedCode) - } catch { - case exception: Exception => - logger.debug("syntax highlighting not working. 
Is `source-highlight` installed?", exception) - Some(source.code) - } finally { - tmpSrcFile.delete() - } - } - -} + val tmpSrcFile = File.newTemporaryFile("dump") + tmpSrcFile.writeText(source.code) + try + val highlightedCode = + Process(Seq("source-highlight-esc.sh", tmpSrcFile.path.toString, langFlag)).!! + Some(highlightedCode) + catch + case exception: Exception => + logger.debug( + "syntax highlighting not working. Is `source-highlight` installed?", + exception + ) + Some(source.code) + finally + tmpSrcFile.delete() + end highlight +end SourceHighlighter diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala index f68afc41..9aeef7c0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala @@ -3,21 +3,18 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, MethodParameterOut} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -class AstGenerator { +class AstGenerator: - private val edgeType = EdgeTypes.AST + private val edgeType = EdgeTypes.AST - def generate(astRoot: AstNode): Graph = { - def shouldBeDisplayed(v: AstNode): Boolean = !v.isInstanceOf[MethodParameterOut] - val vertices = astRoot.ast.filter(shouldBeDisplayed).l - val edges = vertices.flatMap(v => - v.astChildren.filter(shouldBeDisplayed).map { child => - Edge(v, child, edgeType = edgeType) - } - ) - Graph(vertices, edges) - } - -} + def generate(astRoot: AstNode): Graph = + def shouldBeDisplayed(v: AstNode): Boolean = !v.isInstanceOf[MethodParameterOut] + val vertices = astRoot.ast.filter(shouldBeDisplayed).l + val 
edges = vertices.flatMap(v => + v.astChildren.filter(shouldBeDisplayed).map { child => + Edge(v, child, edgeType = edgeType) + } + ) + Graph(vertices, edges) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala index d06bb645..63b366a9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala @@ -3,34 +3,36 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{Method, StoredNode} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable -class CallGraphGenerator { +class CallGraphGenerator: - def generate(cpg: Cpg): Graph = { - val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] - val vertices = cpg.method.l - val edges = for { - srcMethod <- vertices - _ = storeInSubgraph(srcMethod, subgraph) - child <- srcMethod.call - tgt <- child.callOut - } yield { - storeInSubgraph(tgt, subgraph) - Edge(srcMethod, tgt, label = child.dispatchType.stripSuffix("_DISPATCH")) - } - Graph(vertices, edges.distinct, subgraph.toMap) - } + def generate(cpg: Cpg): Graph = + val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] + val vertices = cpg.method.l + val edges = + for + srcMethod <- vertices + _ = storeInSubgraph(srcMethod, subgraph) + child <- srcMethod.call + tgt <- child.callOut + yield + storeInSubgraph(tgt, subgraph) + Edge(srcMethod, tgt, label = child.dispatchType.stripSuffix("_DISPATCH")) + Graph(vertices, edges.distinct, subgraph.toMap) - def storeInSubgraph(method: Method, subgraph: mutable.Map[String, Seq[StoredNode]]): Unit = { - 
method._typeDeclViaAstIn match { - case Some(typeDeclName) => - subgraph.put(typeDeclName.fullName, subgraph.getOrElse(typeDeclName.fullName, Seq()) ++ Seq(method)) - case None => - subgraph.put(method.astParentFullName, subgraph.getOrElse(method.astParentFullName, Seq()) ++ Seq(method)) - } - } - -} + def storeInSubgraph(method: Method, subgraph: mutable.Map[String, Seq[StoredNode]]): Unit = + method._typeDeclViaAstIn match + case Some(typeDeclName) => + subgraph.put( + typeDeclName.fullName, + subgraph.getOrElse(typeDeclName.fullName, Seq()) ++ Seq(method) + ) + case None => + subgraph.put( + method.astParentFullName, + subgraph.getOrElse(method.astParentFullName, Seq()) ++ Seq(method) + ) +end CallGraphGenerator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala index 2507bcf9..fe581877 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala @@ -4,13 +4,11 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.StoredNode import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.Edge -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* -class CdgGenerator extends CfgGenerator { +class CdgGenerator extends CfgGenerator: - override val edgeType: String = EdgeTypes.CDG + override val edgeType: String = EdgeTypes.CDG - override def expand(v: StoredNode): Iterator[Edge] = - v._cdgOut.map(node => Edge(v, node, edgeType = edgeType)) - -} + override def expand(v: StoredNode): Iterator[Edge] = + v._cdgOut.map(node => Edge(v, node, edgeType = edgeType)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala index 45999fbe..aaaae82d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala @@ -1,61 +1,62 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.Node -class CfgGenerator { - - val edgeType: String = EdgeTypes.CFG - - def generate(methodNode: Method): Graph = { - val vertices = methodNode.cfgNode.l ++ List(methodNode, methodNode.methodReturn) ++ methodNode.parameter.l - val verticesToDisplay = vertices.filter(cfgNodeShouldBeDisplayed) - - def edgesToDisplay(srcNode: StoredNode, visited: List[StoredNode] = List()): List[Edge] = { - if (visited.contains(srcNode)) { - List() - } else { - val children = expand(srcNode).filter(x => vertices.contains(x.dst)) - val (visible, invisible) = children.partition(x => cfgNodeShouldBeDisplayed(x.dst)) - visible.toList ++ invisible.toList.flatMap { n => - edgesToDisplay(n.dst, visited ++ List(srcNode)).map(y => Edge(srcNode, y.dst, edgeType = edgeType)) +class CfgGenerator: + + val edgeType: String = EdgeTypes.CFG + + def generate(methodNode: Method): Graph = + val vertices = methodNode.cfgNode.l ++ List( + methodNode, + methodNode.methodReturn + ) ++ methodNode.parameter.l + val verticesToDisplay = vertices.filter(cfgNodeShouldBeDisplayed) + + def edgesToDisplay(srcNode: StoredNode, visited: List[StoredNode] = List()): List[Edge] = + if visited.contains(srcNode) then + List() + else + val children = expand(srcNode).filter(x => vertices.contains(x.dst)) + val (visible, 
invisible) = children.partition(x => cfgNodeShouldBeDisplayed(x.dst)) + visible.toList ++ invisible.toList.flatMap { n => + edgesToDisplay(n.dst, visited ++ List(srcNode)).map(y => + Edge(srcNode, y.dst, edgeType = edgeType) + ) + } + + val edges = verticesToDisplay.flatMap { v => + edgesToDisplay(v) + }.distinct + + val allIdsReferencedByEdges = edges.flatMap { edge => + Set(edge.src.id, edge.dst.id) } - } - } - - val edges = verticesToDisplay.flatMap { v => - edgesToDisplay(v) - }.distinct - - val allIdsReferencedByEdges = edges.flatMap { edge => - Set(edge.src.id, edge.dst.id) - } - - Graph( - verticesToDisplay - .filter(node => allIdsReferencedByEdges.contains(node.id)), - edges - ) - } - - protected def expand(v: StoredNode): Iterator[Edge] = - v._cfgOut.map(node => Edge(v, node, edgeType = edgeType)) - - private def isConditionInControlStructure(v: Node): Boolean = v match { - case id: Identifier => id.astParent.isControlStructure - case _ => false - } - - private def cfgNodeShouldBeDisplayed(v: Node): Boolean = - isConditionInControlStructure(v) || - !(v.isInstanceOf[Literal] || - v.isInstanceOf[Identifier] || - v.isInstanceOf[Block] || - v.isInstanceOf[ControlStructure] || - v.isInstanceOf[JumpTarget] || - v.isInstanceOf[MethodParameterIn]) - -} + + Graph( + verticesToDisplay + .filter(node => allIdsReferencedByEdges.contains(node.id)), + edges + ) + end generate + + protected def expand(v: StoredNode): Iterator[Edge] = + v._cfgOut.map(node => Edge(v, node, edgeType = edgeType)) + + private def isConditionInControlStructure(v: Node): Boolean = v match + case id: Identifier => id.astParent.isControlStructure + case _ => false + + private def cfgNodeShouldBeDisplayed(v: Node): Boolean = + isConditionInControlStructure(v) || + !(v.isInstanceOf[Literal] || + v.isInstanceOf[Identifier] || + v.isInstanceOf[Block] || + v.isInstanceOf[ControlStructure] || + v.isInstanceOf[JumpTarget] || + v.isInstanceOf[MethodParameterIn]) +end CfgGenerator diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala index c0d7da4a..9aa44011 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala @@ -2,14 +2,11 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.nodes.AstNode -object DotAstGenerator { +object DotAstGenerator: - def dotAst[T <: AstNode](traversal: Iterator[T]): Iterator[String] = - traversal.map(dotAst) + def dotAst[T <: AstNode](traversal: Iterator[T]): Iterator[String] = + traversal.map(dotAst) - def dotAst(astRoot: AstNode): String = { - val ast = new AstGenerator().generate(astRoot) - DotSerializer.dotGraph(Option(astRoot), ast) - } - -} + def dotAst(astRoot: AstNode): String = + val ast = new AstGenerator().generate(astRoot) + DotSerializer.dotGraph(Option(astRoot), ast) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala index 21c87ea0..33932599 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala @@ -2,11 +2,8 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.Cpg -object DotCallGraphGenerator { +object DotCallGraphGenerator: - def dotCallGraph(cpg: Cpg): Iterator[String] = { - val callGraph = new CallGraphGenerator().generate(cpg) - Iterator(DotSerializer.dotGraph(None, callGraph)) - } - -} + def dotCallGraph(cpg: Cpg): Iterator[String] = + val callGraph = new CallGraphGenerator().generate(cpg) + Iterator(DotSerializer.dotGraph(None, callGraph)) diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala index ae604e75..833aedb3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala @@ -2,14 +2,11 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.nodes.Method -object DotCdgGenerator { +object DotCdgGenerator: - def dotCdg(traversal: Iterator[Method]): Iterator[String] = - traversal.map(dotCdg) + def dotCdg(traversal: Iterator[Method]): Iterator[String] = + traversal.map(dotCdg) - def dotCdg(method: Method): String = { - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), cdg) - } - -} + def dotCdg(method: Method): String = + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), cdg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala index 1eb83057..e9f15e87 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala @@ -2,14 +2,11 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.generated.nodes.Method -object DotCfgGenerator { +object DotCfgGenerator: - def dotCfg(traversal: Iterator[Method]): Iterator[String] = - traversal.map(dotCfg) + def dotCfg(traversal: Iterator[Method]): Iterator[String] = + traversal.map(dotCfg) - def dotCfg(method: Method): String = { - val cfg = new CfgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), cfg) - } - -} + def dotCfg(method: Method): String = + val cfg = new 
CfgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), cfg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala index bfa96be9..a5798064 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala @@ -10,149 +10,139 @@ import scala.collection.immutable.HashMap import scala.collection.mutable import scala.language.postfixOps -object DotSerializer { - - private val charLimit = 50 - - case class Graph( - vertices: List[StoredNode], - edges: List[Edge], - subgraph: Map[String, Seq[StoredNode]] = HashMap.empty[String, Seq[StoredNode]] - ) { - - def ++(other: Graph): Graph = { - Graph((this.vertices ++ other.vertices).distinct, (this.edges ++ other.edges).distinct) - } - - } - case class Edge( - src: StoredNode, - dst: StoredNode, - srcVisible: Boolean = true, - label: String = "", - edgeType: String = "" - ) - - def dotGraph(root: Option[AstNode] = None, graph: Graph, withEdgeTypes: Boolean = false): String = { - val sb = root match { - case Some(r) => namedGraphBegin(r) - case None => defaultGraphBegin() - } - val nodeStrings = graph.vertices.map(nodeToDot) - val edgeStrings = graph.edges.map(e => edgeToDot(e, withEdgeTypes)) - val subgraphStrings = graph.subgraph.zipWithIndex.map { case ((subgraph, nodes), idx) => - nodesToSubGraphs(subgraph, nodes, idx) - } - sb.append((nodeStrings ++ edgeStrings ++ subgraphStrings).mkString("\n")) - graphEnd(sb) - } - - private def namedGraphBegin(root: AstNode): mutable.StringBuilder = { - val sb = new mutable.StringBuilder - val name = escape(root match { - case method: Method => method.name - case _ => "" - }) - sb.append(s"""digraph "$name" { \n""") - } - - private def defaultGraphBegin(): mutable.StringBuilder = { - val sb = new mutable.StringBuilder - 
val name = "CPG" - sb.append(s"""digraph "$name" { \n""") - } - - private def limit(str: String): String = if (str.length > charLimit) { - s"${str.take(charLimit - 3)}..." - } else { - str - } - - private def stringRepr(vertex: StoredNode): String = { - val maybeLineNo: Optional[AnyRef] = vertex.propertyOption(PropertyNames.LINE_NUMBER) - escape(vertex match { - case call: Call => (call.name, limit(call.code)).toString - case contrl: ControlStructure => (contrl.label, contrl.controlStructureType, contrl.code).toString - case expr: Expression => (expr.label, limit(expr.code), limit(toCfgNode(expr).code)).toString - case method: Method => (method.label, method.name).toString - case ret: MethodReturn => (ret.label, ret.typeFullName).toString - case param: MethodParameterIn => ("PARAM", param.code).toString - case local: Local => (local.label, s"${local.code}: ${local.typeFullName}").toString - case target: JumpTarget => (target.label, target.name).toString - case modifier: Modifier => (modifier.label, modifier.modifierType).toString() - case annoAssign: AnnotationParameterAssign => (annoAssign.label, annoAssign.code).toString() - case annoParam: AnnotationParameter => (annoParam.label, annoParam.code).toString() - case typ: Type => (typ.label, typ.name).toString() - case typeDecl: TypeDecl => (typeDecl.label, typeDecl.name).toString() - case member: Member => (member.label, member.name).toString() - case _ => "" - }) + (if (maybeLineNo.isPresent) s"${maybeLineNo.get()}" else "") - } - - private def toCfgNode(node: StoredNode): CfgNode = { - node match { - case node: Identifier => node.parentExpression.get - case node: MethodRef => node.parentExpression.get - case node: Literal => node.parentExpression.get - case node: MethodParameterIn => node.method - case node: MethodParameterOut => node.method.methodReturn - case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => - node.parentExpression.get - case node: CallRepr => node - case node: MethodReturn => 
node - case node: Expression => node - } - } - - private def nodeToDot(node: StoredNode): String = { - s""""${node.id}" [label = <${stringRepr(node)}> ]""".stripMargin - } - - private def edgeToDot(edge: Edge, withEdgeTypes: Boolean): String = { - val edgeLabel = if (withEdgeTypes) { - edge.edgeType + ": " + escape(edge.label) - } else { - escape(edge.label) - } - val labelStr = Some(s""" [ label = "$edgeLabel"] """).filter(_ => edgeLabel != "").getOrElse("") - s""" "${edge.src.id}" -> "${edge.dst.id}" """ + labelStr - } - - def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = { - val escapedName = escape(subgraph) - val childString = children.map { c => s" \"${c.id()}\";" }.mkString("\n") - s""" subgraph cluster_$idx { +object DotSerializer: + + private val charLimit = 50 + + case class Graph( + vertices: List[StoredNode], + edges: List[Edge], + subgraph: Map[String, Seq[StoredNode]] = HashMap.empty[String, Seq[StoredNode]] + ): + + def ++(other: Graph): Graph = + Graph((this.vertices ++ other.vertices).distinct, (this.edges ++ other.edges).distinct) + + case class Edge( + src: StoredNode, + dst: StoredNode, + srcVisible: Boolean = true, + label: String = "", + edgeType: String = "" + ) + + def dotGraph( + root: Option[AstNode] = None, + graph: Graph, + withEdgeTypes: Boolean = false + ): String = + val sb = root match + case Some(r) => namedGraphBegin(r) + case None => defaultGraphBegin() + val nodeStrings = graph.vertices.map(nodeToDot) + val edgeStrings = graph.edges.map(e => edgeToDot(e, withEdgeTypes)) + val subgraphStrings = graph.subgraph.zipWithIndex.map { case ((subgraph, nodes), idx) => + nodesToSubGraphs(subgraph, nodes, idx) + } + sb.append((nodeStrings ++ edgeStrings ++ subgraphStrings).mkString("\n")) + graphEnd(sb) + + private def namedGraphBegin(root: AstNode): mutable.StringBuilder = + val sb = new mutable.StringBuilder + val name = escape(root match + case method: Method => method.name + case _ => "" + ) + 
sb.append(s"""digraph "$name" { \n""") + + private def defaultGraphBegin(): mutable.StringBuilder = + val sb = new mutable.StringBuilder + val name = "CPG" + sb.append(s"""digraph "$name" { \n""") + + private def limit(str: String): String = if str.length > charLimit then + s"${str.take(charLimit - 3)}..." + else + str + + private def stringRepr(vertex: StoredNode): String = + val maybeLineNo: Optional[AnyRef] = vertex.propertyOption(PropertyNames.LINE_NUMBER) + escape(vertex match + case call: Call => (call.name, limit(call.code)).toString + case contrl: ControlStructure => + (contrl.label, contrl.controlStructureType, contrl.code).toString + case expr: Expression => + (expr.label, limit(expr.code), limit(toCfgNode(expr).code)).toString + case method: Method => (method.label, method.name).toString + case ret: MethodReturn => (ret.label, ret.typeFullName).toString + case param: MethodParameterIn => ("PARAM", param.code).toString + case local: Local => (local.label, s"${local.code}: ${local.typeFullName}").toString + case target: JumpTarget => (target.label, target.name).toString + case modifier: Modifier => (modifier.label, modifier.modifierType).toString() + case annoAssign: AnnotationParameterAssign => + (annoAssign.label, annoAssign.code).toString() + case annoParam: AnnotationParameter => (annoParam.label, annoParam.code).toString() + case typ: Type => (typ.label, typ.name).toString() + case typeDecl: TypeDecl => (typeDecl.label, typeDecl.name).toString() + case member: Member => (member.label, member.name).toString() + case _ => "" + ) + (if maybeLineNo.isPresent then s"${maybeLineNo.get()}" else "") + end stringRepr + + private def toCfgNode(node: StoredNode): CfgNode = + node match + case node: Identifier => node.parentExpression.get + case node: MethodRef => node.parentExpression.get + case node: Literal => node.parentExpression.get + case node: MethodParameterIn => node.method + case node: MethodParameterOut => node.method.methodReturn + case node: Call if 
MemberAccess.isGenericMemberAccessName(node.name) => + node.parentExpression.get + case node: CallRepr => node + case node: MethodReturn => node + case node: Expression => node + + private def nodeToDot(node: StoredNode): String = + s""""${node.id}" [label = <${stringRepr(node)}> ]""".stripMargin + + private def edgeToDot(edge: Edge, withEdgeTypes: Boolean): String = + val edgeLabel = if withEdgeTypes then + edge.edgeType + ": " + escape(edge.label) + else + escape(edge.label) + val labelStr = + Some(s""" [ label = "$edgeLabel"] """).filter(_ => edgeLabel != "").getOrElse("") + s""" "${edge.src.id}" -> "${edge.dst.id}" """ + labelStr + + def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = + val escapedName = escape(subgraph) + val childString = children.map { c => s" \"${c.id()}\";" }.mkString("\n") + s""" subgraph cluster_$idx { |$childString | label = "$escapedName"; | } |""".stripMargin - } - - /** Escapes common characters that do not conform to HTML character sets. - * @see - * https://www.w3.org/TR/html4/sgml/entities.html - */ - private def escapedChar(ch: Char): String = ch match { - case '"' => """ - case '<' => "<" - case '>' => ">" - case '&' => "&" - case _ => - if (ch.isControl) "\\0" + Integer.toOctalString(ch.toInt) - else String.valueOf(ch) - } - - private def escape(str: String): String = { - if (str == null) { - "" - } else { - str.flatMap(escapedChar) - } - } - - private def graphEnd(sb: mutable.StringBuilder): String = { - sb.append("\n}\n") - sb.toString - } - -} + + /** Escapes common characters that do not conform to HTML character sets. 
+ * @see + * https://www.w3.org/TR/html4/sgml/entities.html + */ + private def escapedChar(ch: Char): String = ch match + case '"' => """ + case '<' => "<" + case '>' => ">" + case '&' => "&" + case _ => + if ch.isControl then "\\0" + Integer.toOctalString(ch.toInt) + else String.valueOf(ch) + + private def escape(str: String): String = + if str == null then + "" + else + str.flatMap(escapedChar) + + private def graphEnd(sb: mutable.StringBuilder): String = + sb.append("\n}\n") + sb.toString +end DotSerializer diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala index ca07348c..aa3bfa9f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala @@ -2,11 +2,8 @@ package io.shiftleft.semanticcpg.dotgenerator import io.shiftleft.codepropertygraph.Cpg -object DotTypeHierarchyGenerator { +object DotTypeHierarchyGenerator: - def dotTypeHierarchy(cpg: Cpg): Iterator[String] = { - val typeHierarchy = new TypeHierarchyGenerator().generate(cpg) - Iterator(DotSerializer.dotGraph(None, typeHierarchy)) - } - -} + def dotTypeHierarchy(cpg: Cpg): Iterator[String] = + val typeHierarchy = new TypeHierarchyGenerator().generate(cpg) + Iterator(DotSerializer.dotGraph(None, typeHierarchy)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala index e3097cb9..bdd1dc1e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala @@ -3,47 +3,43 @@ package io.shiftleft.semanticcpg.dotgenerator 
import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{StoredNode, Type, TypeDecl} import io.shiftleft.semanticcpg.dotgenerator.DotSerializer.{Edge, Graph} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import scala.collection.mutable -class TypeHierarchyGenerator { +class TypeHierarchyGenerator: - def generate(cpg: Cpg): Graph = { - val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] - val vertices = cpg.typeDecl.l - val typeToIsExternal = vertices.map { t => t.fullName -> t.isExternal }.toMap - val edges = for { - srcTypeDecl <- vertices - srcType <- srcTypeDecl._typeViaRefIn.l - _ = storeInSubgraph(srcType, subgraph, typeToIsExternal) - tgtType <- srcTypeDecl.inheritsFromOut - } yield { - storeInSubgraph(tgtType, subgraph, typeToIsExternal) - Edge(tgtType, srcType) - } - Graph(vertices.flatMap(_._typeViaRefIn.l), edges.distinct, subgraph.toMap) - } + def generate(cpg: Cpg): Graph = + val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] + val vertices = cpg.typeDecl.l + val typeToIsExternal = vertices.map { t => t.fullName -> t.isExternal }.toMap + val edges = + for + srcTypeDecl <- vertices + srcType <- srcTypeDecl._typeViaRefIn.l + _ = storeInSubgraph(srcType, subgraph, typeToIsExternal) + tgtType <- srcTypeDecl.inheritsFromOut + yield + storeInSubgraph(tgtType, subgraph, typeToIsExternal) + Edge(tgtType, srcType) + Graph(vertices.flatMap(_._typeViaRefIn.l), edges.distinct, subgraph.toMap) - def storeInSubgraph( - typ: Type, - subgraph: mutable.Map[String, Seq[StoredNode]], - typeToIsExternal: Map[String, Boolean] - ): Unit = { - if (!typeToIsExternal(typ.fullName)) { - /* + def storeInSubgraph( + typ: Type, + subgraph: mutable.Map[String, Seq[StoredNode]], + typeToIsExternal: Map[String, Boolean] + ): Unit = + if !typeToIsExternal(typ.fullName) then + /* We parse the namespace information instead of looking at the namespace node as types such as inner classes 
may not be attached to a namespace block - */ - val namespace = - if (typ.fullName.contains(".")) - typ.fullName.stripSuffix(s".${typ.name}") + */ + val namespace = + if typ.fullName.contains(".") then + typ.fullName.stripSuffix(s".${typ.name}") + else + typ.fullName.stripSuffix(s"${typ.name}") + subgraph.put(namespace, subgraph.getOrElse(namespace, Seq()) ++ Seq(typ)) else - typ.fullName.stripSuffix(s"${typ.name}") - subgraph.put(namespace, subgraph.getOrElse(namespace, Seq()) ++ Seq(typ)) - } else { - subgraph.put("", subgraph.getOrElse("", Seq()) ++ Seq(typ)) - } - } - -} + subgraph.put("", subgraph.getOrElse("", Seq()) ++ Seq(typ)) +end TypeHierarchyGenerator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala index 0d34c1f5..743bd805 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala @@ -1,149 +1,142 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.{Operators, Properties, PropertyNames} -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.accesspath._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.accesspath.* import org.slf4j.LoggerFactory import scala.jdk.CollectionConverters.IteratorHasAsScala -object AccessPathHandling { +object AccessPathHandling: - def leafToTrackedBaseAndAccessPathInternal(node: StoredNode): Option[(TrackedBase, List[AccessElement])] = { - node match { - case node: MethodParameterIn => Some((TrackedNamedVariable(node.name), Nil)) - case node: MethodParameterOut => Some((TrackedNamedVariable(node.name), Nil)) - case node: Identifier => Some((TrackedNamedVariable(node.name), Nil)) - case node: Literal => Some((TrackedLiteral(node), Nil)) - case 
node: MethodRef => Some((TrackedMethod(node), Nil)) - case node: TypeRef => Some((TrackedTypeRef(node), Nil)) - case _: Return => Some((TrackedFormalReturn, Nil)) - case _: MethodReturn => Some((TrackedFormalReturn, Nil)) - case _: Unknown => Some((TrackedUnknown, Nil)) - case _: ControlStructure => Some((TrackedUnknown, Nil)) - // FieldIdentifiers are only fake arguments, hence should not be tracked - case _: FieldIdentifier => Some((TrackedUnknown, Nil)) - case _ => None - } - } + def leafToTrackedBaseAndAccessPathInternal(node: StoredNode) + : Option[(TrackedBase, List[AccessElement])] = + node match + case node: MethodParameterIn => Some((TrackedNamedVariable(node.name), Nil)) + case node: MethodParameterOut => Some((TrackedNamedVariable(node.name), Nil)) + case node: Identifier => Some((TrackedNamedVariable(node.name), Nil)) + case node: Literal => Some((TrackedLiteral(node), Nil)) + case node: MethodRef => Some((TrackedMethod(node), Nil)) + case node: TypeRef => Some((TrackedTypeRef(node), Nil)) + case _: Return => Some((TrackedFormalReturn, Nil)) + case _: MethodReturn => Some((TrackedFormalReturn, Nil)) + case _: Unknown => Some((TrackedUnknown, Nil)) + case _: ControlStructure => Some((TrackedUnknown, Nil)) + // FieldIdentifiers are only fake arguments, hence should not be tracked + case _: FieldIdentifier => Some((TrackedUnknown, Nil)) + case _ => None - private val logger = LoggerFactory.getLogger(getClass) - private var hasWarnedDeprecations = false + private val logger = LoggerFactory.getLogger(getClass) + private var hasWarnedDeprecations = false - def memberAccessToPath(memberAccess: Call, tail: List[AccessElement]): List[AccessElement] = { - memberAccess.name match { - case Operators.memberAccess | Operators.indirectMemberAccess => - if (!hasWarnedDeprecations) { - logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") - hasWarnedDeprecations = true - } - memberAccess - .argumentOption(2) - .collect { - case node: Literal => 
ConstantAccess(node.code) - case node: Identifier => ConstantAccess(node.name) - case other if other.propertyOption(PropertyNames.NAME).isPresent => - logger.warn(s"unexpected/deprecated node encountered: $other with properties: ${other.propertiesMap()}") - ConstantAccess(other.property(Properties.NAME)) - } - .getOrElse(VariableAccess) :: tail + def memberAccessToPath(memberAccess: Call, tail: List[AccessElement]): List[AccessElement] = + memberAccess.name match + case Operators.memberAccess | Operators.indirectMemberAccess => + if !hasWarnedDeprecations then + logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") + hasWarnedDeprecations = true + memberAccess + .argumentOption(2) + .collect { + case node: Literal => ConstantAccess(node.code) + case node: Identifier => ConstantAccess(node.name) + case other if other.propertyOption(PropertyNames.NAME).isPresent => + logger.warn( + s"unexpected/deprecated node encountered: $other with properties: ${other.propertiesMap()}" + ) + ConstantAccess(other.property(Properties.NAME)) + } + .getOrElse(VariableAccess) :: tail - case Operators.computedMemberAccess | Operators.indirectComputedMemberAccess => - if (!hasWarnedDeprecations) { - logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") - hasWarnedDeprecations = true - } - memberAccess - .argumentOption(2) - .collect { case lit: Literal => - ConstantAccess(lit.code) - } - .getOrElse(VariableAccess) :: tail - case Operators.indirection => - IndirectionAccess :: tail - case Operators.addressOf => - AddressOf :: tail - case Operators.fieldAccess => - extractAccessStringTokenForFieldAccess(memberAccess) :: tail - case Operators.indexAccess => - extractAccessStringToken(memberAccess) :: tail - case Operators.indirectFieldAccess => - // we will reverse the list in the end - extractAccessStringTokenForFieldAccess(memberAccess) :: IndirectionAccess :: tail - case Operators.indirectIndexAccess => - // we will reverse the list in the end - 
IndirectionAccess :: extractAccessIntToken(memberAccess) :: tail - case Operators.pointerShift => - extractAccessIntToken(memberAccess) :: tail - case Operators.getElementPtr => - // we will reverse the list in the end - AddressOf :: extractAccessStringTokenForFieldAccess(memberAccess) :: IndirectionAccess :: tail - } - } + case Operators.computedMemberAccess | Operators.indirectComputedMemberAccess => + if !hasWarnedDeprecations then + logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") + hasWarnedDeprecations = true + memberAccess + .argumentOption(2) + .collect { case lit: Literal => + ConstantAccess(lit.code) + } + .getOrElse(VariableAccess) :: tail + case Operators.indirection => + IndirectionAccess :: tail + case Operators.addressOf => + AddressOf :: tail + case Operators.fieldAccess => + extractAccessStringTokenForFieldAccess(memberAccess) :: tail + case Operators.indexAccess => + extractAccessStringToken(memberAccess) :: tail + case Operators.indirectFieldAccess => + // we will reverse the list in the end + extractAccessStringTokenForFieldAccess(memberAccess) :: IndirectionAccess :: tail + case Operators.indirectIndexAccess => + // we will reverse the list in the end + IndirectionAccess :: extractAccessIntToken(memberAccess) :: tail + case Operators.pointerShift => + extractAccessIntToken(memberAccess) :: tail + case Operators.getElementPtr => + // we will reverse the list in the end + AddressOf :: extractAccessStringTokenForFieldAccess( + memberAccess + ) :: IndirectionAccess :: tail - private def extractAccessStringTokenForFieldAccess(memberAccess: Call): AccessElement = { - memberAccess.argumentOption(2) match { - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." 
+ - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariableAccess - case Some(literal: Literal) => ConstantAccess(literal.code) - case Some(fieldIdentifier: FieldIdentifier) => - ConstantAccess(fieldIdentifier.canonicalName) - case Some(identifier: Identifier) => - // TODO remove this case. - // This is handling a very old CPG format version where IDENTIFIER was used instead of FIELD_IDENTIFIER. - // Sadly we need this for now to support a GO cpg. - ConstantAccess(identifier.name) - case _ => VariableAccess - } - } + private def extractAccessStringTokenForFieldAccess(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." + + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariableAccess + case Some(literal: Literal) => ConstantAccess(literal.code) + case Some(fieldIdentifier: FieldIdentifier) => + ConstantAccess(fieldIdentifier.canonicalName) + case Some(identifier: Identifier) => + // TODO remove this case. + // This is handling a very old CPG format version where IDENTIFIER was used instead of FIELD_IDENTIFIER. + // Sadly we need this for now to support a GO cpg. + ConstantAccess(identifier.name) + case _ => VariableAccess - private def extractAccessStringToken(memberAccess: Call): AccessElement = { - memberAccess.argumentOption(2) match { - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." 
+ - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariableAccess - case Some(literal: Literal) => ConstantAccess(literal.code) - case Some(fieldIdentifier: FieldIdentifier) => - ConstantAccess(fieldIdentifier.canonicalName) - case _ => VariableAccess - } - } + private def extractAccessStringToken(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." + + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariableAccess + case Some(literal: Literal) => ConstantAccess(literal.code) + case Some(fieldIdentifier: FieldIdentifier) => + ConstantAccess(fieldIdentifier.canonicalName) + case _ => VariableAccess - private def extractAccessIntToken(memberAccess: Call): AccessElement = { - memberAccess.argumentOption(2) match { - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." + - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariablePointerShift - case Some(literal: Literal) => - literal.code.toIntOption.map(PointerShift.apply).getOrElse(VariablePointerShift) - case Some(fieldIdentifier: FieldIdentifier) => - fieldIdentifier.canonicalName.toIntOption - .map(PointerShift.apply) - .getOrElse(VariablePointerShift) - case _ => VariablePointerShift - } - } + private def extractAccessIntToken(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." 
+ + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariablePointerShift + case Some(literal: Literal) => + literal.code.toIntOption.map(PointerShift.apply).getOrElse(VariablePointerShift) + case Some(fieldIdentifier: FieldIdentifier) => + fieldIdentifier.canonicalName.toIntOption + .map(PointerShift.apply) + .getOrElse(VariablePointerShift) + case _ => VariablePointerShift - def lastExpressionInBlock(block: Block): Option[Expression] = - block._astOut - .collect { - case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => node - } - .toVector - .sortBy(_.order) - .lastOption - -} + def lastExpressionInBlock(block: Block): Option[Expression] = + block._astOut + .collect { + case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => + node + } + .toVector + .sortBy(_.order) + .lastOption +end AccessPathHandling diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala index 6c111f2f..ba31a3df 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala @@ -2,6 +2,5 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.NewLocation -trait HasLocation extends Any { - def location: NewLocation -} +trait HasLocation extends Any: + def location: NewLocation diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala index 4f2322bf..83e0d12e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala @@ -5,85 +5,77 @@ import 
io.shiftleft.codepropertygraph.generated.nodes.{CallRepr, Method} import scala.collection.mutable import scala.jdk.CollectionConverters.* -trait ICallResolver { - - def getUnresolvedMethodFullNames(callsite: CallRepr): Iterable[String] = { - triggerCallsiteResolution(callsite) - getUnresolvedMethodFullNamesInternal(callsite) - } - - def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] - - /** Get methods called at the given callsite. This internally calls triggerCallsiteResolution. - */ - def getCalledMethods(callsite: CallRepr): Iterable[Method] = { - val combined = mutable.ArrayBuffer.empty[Method] - if (callsite.nonEmpty) { - triggerCallsiteResolution(callsite) - callsite._callOut.foreach(method => combined.append(method.asInstanceOf[Method])) - combined.appendAll(getResolvedCalledMethods(callsite)) - } - combined - } - - /** Same as getCalledMethods but with traversal return type. - */ - def getCalledMethodsAsTraversal(callsite: CallRepr): Iterator[Method] = - getCalledMethods(callsite).iterator - - /** Get callsites of the given method. This internally calls triggerMethodResolution. - */ - def getMethodCallsites(method: Method): Iterable[CallRepr] = { - triggerMethodCallsiteResolution(method) - // The same call sites of a method can be found via static and dynamic lookup. - // This is for example the case for Java virtual call sites which are statically assert - // a certain method which could be overriden. If we are looking for the call sites of - // such a statically asserted method, we find it twice and thus deduplicate here. - val combined = mutable.LinkedHashSet.empty[CallRepr] - method._callIn.foreach(call => combined.add(call.asInstanceOf[CallRepr])) - combined.addAll(getResolvedMethodCallsites(method)) - - combined.toBuffer - } - - /** Same as getMethodCallsites but with traversal return type. 
- */ - def getMethodCallsitesAsTraversal(method: Method): Iterator[CallRepr] = - getMethodCallsites(method).iterator - - /** Starts data flow tracking to find all method which could be called at the given callsite. The result is stored in - * the resolver internal cache. - */ - def triggerCallsiteResolution(callsite: CallRepr): Unit - - /** Starts data flow tracking to find all callsites which could call the given method. The result is stored in the - * resolver internal cache. - */ - def triggerMethodCallsiteResolution(method: Method): Unit - - /** Retrieve results of triggerCallsiteResolution. - */ - def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] - - /** Retrieve results of triggerMethodResolution. - */ - def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] -} - -object NoResolve extends ICallResolver { - def triggerCallsiteResolution(callsite: CallRepr): Unit = {} - - def triggerMethodCallsiteResolution(method: Method): Unit = {} - - override def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] = { - Iterable.empty - } - - override def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] = { - Iterable.empty - } - - override def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] = { - Iterable.empty - } -} +trait ICallResolver: + + def getUnresolvedMethodFullNames(callsite: CallRepr): Iterable[String] = + triggerCallsiteResolution(callsite) + getUnresolvedMethodFullNamesInternal(callsite) + + def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] + + /** Get methods called at the given callsite. This internally calls triggerCallsiteResolution. 
+ */ + def getCalledMethods(callsite: CallRepr): Iterable[Method] = + val combined = mutable.ArrayBuffer.empty[Method] + if callsite.nonEmpty then + triggerCallsiteResolution(callsite) + callsite._callOut.foreach(method => combined.append(method.asInstanceOf[Method])) + combined.appendAll(getResolvedCalledMethods(callsite)) + combined + + /** Same as getCalledMethods but with traversal return type. + */ + def getCalledMethodsAsTraversal(callsite: CallRepr): Iterator[Method] = + getCalledMethods(callsite).iterator + + /** Get callsites of the given method. This internally calls triggerMethodResolution. + */ + def getMethodCallsites(method: Method): Iterable[CallRepr] = + triggerMethodCallsiteResolution(method) + // The same call sites of a method can be found via static and dynamic lookup. + // This is for example the case for Java virtual call sites which are statically assert + // a certain method which could be overriden. If we are looking for the call sites of + // such a statically asserted method, we find it twice and thus deduplicate here. + val combined = mutable.LinkedHashSet.empty[CallRepr] + method._callIn.foreach(call => combined.add(call.asInstanceOf[CallRepr])) + combined.addAll(getResolvedMethodCallsites(method)) + + combined.toBuffer + + /** Same as getMethodCallsites but with traversal return type. + */ + def getMethodCallsitesAsTraversal(method: Method): Iterator[CallRepr] = + getMethodCallsites(method).iterator + + /** Starts data flow tracking to find all method which could be called at the given callsite. + * The result is stored in the resolver internal cache. + */ + def triggerCallsiteResolution(callsite: CallRepr): Unit + + /** Starts data flow tracking to find all callsites which could call the given method. The + * result is stored in the resolver internal cache. + */ + def triggerMethodCallsiteResolution(method: Method): Unit + + /** Retrieve results of triggerCallsiteResolution. 
+ */ + def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] + + /** Retrieve results of triggerMethodResolution. + */ + def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] +end ICallResolver + +object NoResolve extends ICallResolver: + def triggerCallsiteResolution(callsite: CallRepr): Unit = {} + + def triggerMethodCallsiteResolution(method: Method): Unit = {} + + override def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] = + Iterable.empty + + override def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] = + Iterable.empty + + override def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] = + Iterable.empty diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala index 81f4d44a..1e314f9a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala @@ -1,81 +1,77 @@ package io.shiftleft.semanticcpg.language -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import org.slf4j.{Logger, LoggerFactory} -import overflowdb.traversal._ +import overflowdb.traversal.* import scala.annotation.tailrec /* TODO MP: this should be part of the normal steps, rather than matching on the type at runtime * all (and only) steps extending DataFlowObject should/must have `newSink`, `newSource` and `newLocation` */ -object LocationCreator { +object LocationCreator: - private val logger: Logger = LoggerFactory.getLogger(getClass) + private val logger: Logger = LoggerFactory.getLogger(getClass) - def apply(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = { - try { - location(node) - } catch { - case exc @ (_: NoSuchElementException | _: ClassCastException) => - 
logger.debug(s"Cannot determine location for ${node.label} due to broken CPG", exc) - emptyLocation(node.label, Some(node)) - } - } + def apply(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = + try + location(node) + catch + case exc @ (_: NoSuchElementException | _: ClassCastException) => + logger.debug(s"Cannot determine location for ${node.label} due to broken CPG", exc) + emptyLocation(node.label, Some(node)) - private def location(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = { - finder(node) match { - case Some(n: HasLocation) => n.location - case _ => LocationCreator.emptyLocation("", None) - } - } + private def location(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = + finder(node) match + case Some(n: HasLocation) => n.location + case _ => LocationCreator.emptyLocation("", None) - def apply( - node: StoredNode, - symbol: String, - label: String, - lineNumber: Option[Integer], - method: Method - ): NewLocation = { + def apply( + node: StoredNode, + symbol: String, + label: String, + lineNumber: Option[Integer], + method: Method + ): NewLocation = + if method == null then + NewLocation().node(node) + else + val typeOption = methodToTypeDecl(method) + val typeName = typeOption.map(_.fullName).getOrElse("") + val typeShortName = typeOption.map(_.name).getOrElse("") - if (method == null) { - NewLocation().node(node) - } else { - val typeOption = methodToTypeDecl(method) - val typeName = typeOption.map(_.fullName).getOrElse("") - val typeShortName = typeOption.map(_.name).getOrElse("") + val namespaceOption = + for + tpe <- typeOption + namespaceBlock <- tpe.namespaceBlock + namespace <- namespaceBlock._namespaceViaRefOut.nextOption() + yield namespace.name + val namespaceName = namespaceOption.getOrElse("") - val namespaceOption = for { - tpe <- typeOption - namespaceBlock <- tpe.namespaceBlock - namespace <- namespaceBlock._namespaceViaRefOut.nextOption() - } yield namespace.name - val 
namespaceName = namespaceOption.getOrElse("") + NewLocation() + .symbol(symbol) + .methodFullName(method.fullName) + .methodShortName(method.name) + .packageName(namespaceName) + .lineNumber(lineNumber) + .className(typeName) + .classShortName(typeShortName) + .nodeLabel(label) + .filename(if method.filename.isEmpty then "N/A" else method.filename) + .node(node) - NewLocation() - .symbol(symbol) - .methodFullName(method.fullName) - .methodShortName(method.name) - .packageName(namespaceName) - .lineNumber(lineNumber) - .className(typeName) - .classShortName(typeShortName) - .nodeLabel(label) - .filename(if (method.filename.isEmpty) "N/A" else method.filename) - .node(node) - } - } + private def methodToTypeDecl(method: Method): Option[TypeDecl] = + findVertex(method, _.isInstanceOf[TypeDecl]).map(_.asInstanceOf[TypeDecl]) - private def methodToTypeDecl(method: Method): Option[TypeDecl] = - findVertex(method, _.isInstanceOf[TypeDecl]).map(_.asInstanceOf[TypeDecl]) + @tailrec + private def findVertex( + node: StoredNode, + instanceCheck: StoredNode => Boolean + ): Option[StoredNode] = + node._astIn.nextOption() match + case Some(head) if instanceCheck(head) => Some(head) + case Some(head) => findVertex(head, instanceCheck) + case None => None - @tailrec - private def findVertex(node: StoredNode, instanceCheck: StoredNode => Boolean): Option[StoredNode] = - node._astIn.nextOption() match { - case Some(head) if instanceCheck(head) => Some(head) - case Some(head) => findVertex(head, instanceCheck) - case None => None - } - - def emptyLocation(label: String, node: Option[StoredNode]): NewLocation = - NewLocation().nodeLabel(label).node(node) -} + def emptyLocation(label: String, node: Option[StoredNode]): NewLocation = + NewLocation().nodeLabel(label).node(node) +end LocationCreator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala index 
2a4384e9..e83cc720 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala @@ -3,18 +3,15 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.NewNode import overflowdb.BatchedUpdate.DiffGraphBuilder -trait HasStoreMethod { - def store()(implicit diffBuilder: DiffGraphBuilder): Unit -} +trait HasStoreMethod: + def store()(implicit diffBuilder: DiffGraphBuilder): Unit -class NewNodeSteps[A <: NewNode](val traversal: Iterator[A]) extends HasStoreMethod { +class NewNodeSteps[A <: NewNode](val traversal: Iterator[A]) extends HasStoreMethod: - override def store()(implicit diffBuilder: DiffGraphBuilder): Unit = - traversal.sideEffect(storeRecursively).iterate() + override def store()(implicit diffBuilder: DiffGraphBuilder): Unit = + traversal.sideEffect(storeRecursively).iterate() - private def storeRecursively(newNode: NewNode)(implicit diffBuilder: DiffGraphBuilder): Unit = { - diffBuilder.addNode(newNode) - } + private def storeRecursively(newNode: NewNode)(implicit diffBuilder: DiffGraphBuilder): Unit = + diffBuilder.addNode(newNode) - def label: Iterator[String] = traversal.map(_.label) -} + def label: Iterator[String] = traversal.map(_.label) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala index 9248f9e7..be2a258f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala @@ -4,18 +4,15 @@ import io.shiftleft.codepropertygraph.generated.EdgeTypes import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewTagNodePair, StoredNode} import overflowdb.BatchedUpdate.DiffGraphBuilder -class 
NewTagNodePairTraversal(traversal: Iterator[NewTagNodePair]) extends HasStoreMethod { - override def store()(implicit diffGraph: DiffGraphBuilder): Unit = { - traversal.foreach { tagNodePair => - val tag = tagNodePair.tag - val tagValue = tagNodePair.node - diffGraph.addNode(tag.asInstanceOf[NewNode]) - tagValue match { - case tagValue: StoredNode => - diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY) - case tagValue: NewNode => - diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY, Nil) - } - } - } -} +class NewTagNodePairTraversal(traversal: Iterator[NewTagNodePair]) extends HasStoreMethod: + override def store()(implicit diffGraph: DiffGraphBuilder): Unit = + traversal.foreach { tagNodePair => + val tag = tagNodePair.tag + val tagValue = tagNodePair.node + diffGraph.addNode(tag.asInstanceOf[NewNode]) + tagValue match + case tagValue: StoredNode => + diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY) + case tagValue: NewNode => + diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY, Nil) + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala index 6c7c4024..299a0236 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala @@ -1,47 +1,43 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.{ - Call, - Identifier, - Literal, - Local, - Method, - MethodParameterIn, - MethodParameterOut, - MethodRef, - MethodReturn, - StoredNode + Call, + Identifier, + Literal, + Local, + Method, + MethodParameterIn, + MethodParameterOut, + MethodRef, + MethodReturn, + StoredNode } import io.shiftleft.semanticcpg.NodeExtension import 
io.shiftleft.semanticcpg.language.nodemethods.{ - CallMethods, - IdentifierMethods, - LiteralMethods, - LocalMethods, - MethodMethods, - MethodParameterInMethods, - MethodParameterOutMethods, - MethodRefMethods, - MethodReturnMethods + CallMethods, + IdentifierMethods, + LiteralMethods, + LocalMethods, + MethodMethods, + MethodParameterInMethods, + MethodParameterOutMethods, + MethodRefMethods, + MethodReturnMethods } -trait NodeExtensionFinder { - def apply(n: StoredNode): Option[NodeExtension] -} +trait NodeExtensionFinder: + def apply(n: StoredNode): Option[NodeExtension] -object DefaultNodeExtensionFinder extends NodeExtensionFinder { - override def apply(node: StoredNode): Option[NodeExtension] = { - node match { - case n: Method => Some(new MethodMethods(n)) - case n: MethodParameterIn => Some(new MethodParameterInMethods(n)) - case n: MethodParameterOut => Some(new MethodParameterOutMethods(n)) - case n: MethodReturn => Some(new MethodReturnMethods(n)) - case n: Call => Some(new CallMethods(n)) - case n: Identifier => Some(new IdentifierMethods(n)) - case n: Literal => Some(new LiteralMethods(n)) - case n: Local => Some(new LocalMethods(n)) - case n: MethodRef => Some(new MethodRefMethods(n)) - case _ => None - } - } -} +object DefaultNodeExtensionFinder extends NodeExtensionFinder: + override def apply(node: StoredNode): Option[NodeExtension] = + node match + case n: Method => Some(new MethodMethods(n)) + case n: MethodParameterIn => Some(new MethodParameterInMethods(n)) + case n: MethodParameterOut => Some(new MethodParameterOutMethods(n)) + case n: MethodReturn => Some(new MethodReturnMethods(n)) + case n: Call => Some(new CallMethods(n)) + case n: Identifier => Some(new IdentifierMethods(n)) + case n: Literal => Some(new LiteralMethods(n)) + case n: Local => Some(new LocalMethods(n)) + case n: MethodRef => Some(new MethodRefMethods(n)) + case _ => None diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala index de995be0..f0e13235 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala @@ -2,53 +2,48 @@ package io.shiftleft.semanticcpg.language import scala.collection.mutable -object NodeOrdering { - - /** For a given CFG with the entry node `cfgEntry` and an expansion function `expand`, return a map that associates - * each node with an index such that nodes are numbered in post order. - */ - def postOrderNumbering[NodeType]( - cfgEntry: NodeType, - expand: NodeType => Iterator[NodeType] - ): mutable.LinkedHashMap[NodeType, Int] = { - var stack = (cfgEntry, expand(cfgEntry)) :: Nil - val visited = mutable.Set.empty[NodeType] - val numbering = mutable.LinkedHashMap.empty[NodeType, Int] - var nextNumber = 0 - - while (stack.nonEmpty) { - val (node, successors) = stack.head - visited += node - - if (successors.hasNext) { - val successor = successors.next() - if (!visited.contains(successor)) { - stack = (successor, expand(successor)) :: stack - } - } else { - stack = stack.tail - numbering.put(node, nextNumber) - nextNumber += 1 - } - } - numbering - } - - /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes according to number in - * reverse order. - */ - def reverseNodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = { - nodeNumberPairs - .sortBy { case (_, num) => -num } - .map { case (node, _) => node } - } - - /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes according to number. 
- */ - def nodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = { - nodeNumberPairs - .sortBy { case (_, num) => num } - .map { case (node, _) => node } - } - -} +object NodeOrdering: + + /** For a given CFG with the entry node `cfgEntry` and an expansion function `expand`, return a + * map that associates each node with an index such that nodes are numbered in post order. + */ + def postOrderNumbering[NodeType]( + cfgEntry: NodeType, + expand: NodeType => Iterator[NodeType] + ): mutable.LinkedHashMap[NodeType, Int] = + var stack = (cfgEntry, expand(cfgEntry)) :: Nil + val visited = mutable.Set.empty[NodeType] + val numbering = mutable.LinkedHashMap.empty[NodeType, Int] + var nextNumber = 0 + + while stack.nonEmpty do + val (node, successors) = stack.head + visited += node + + if successors.hasNext then + val successor = successors.next() + if !visited.contains(successor) then + stack = (successor, expand(successor)) :: stack + else + stack = stack.tail + numbering.put(node, nextNumber) + nextNumber += 1 + numbering + end postOrderNumbering + + /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes + * according to number in reverse order. + */ + def reverseNodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = + nodeNumberPairs + .sortBy { case (_, num) => -num } + .map { case (node, _) => node } + + /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes + * according to number. 
+ */ + def nodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = + nodeNumberPairs + .sortBy { case (_, num) => num } + .map { case (node, _) => node } +end NodeOrdering diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala index f6ab5321..7aac3442 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala @@ -1,11 +1,11 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.semanticcpg.codedumper.CodeDumper import overflowdb.Node -import overflowdb.traversal._ +import overflowdb.traversal.* import overflowdb.traversal.help.Doc /** Steps for all node types @@ -13,29 +13,31 @@ import overflowdb.traversal.help.Doc * This is the base class for all steps defined on */ @help.Traversal(elementType = classOf[StoredNode]) -class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) extends AnyVal { +class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) extends AnyVal: - @Doc( - info = "The source file this code is in", - longInfo = """ + @Doc( + info = "The source file this code is in", + longInfo = """ |Not all but most node in the graph can be associated with |a specific source file they appear in. `file` provides |the file node that represents that source file. 
|""" - ) - def file: Iterator[File] = - traversal - .choose(_.label) { - case NodeTypes.NAMESPACE => _.in(EdgeTypes.REF).out(EdgeTypes.SOURCE_FILE) - case NodeTypes.COMMENT => _.in(EdgeTypes.AST).hasLabel(NodeTypes.FILE) - case _ => - _.repeat(_.coalesce(_.out(EdgeTypes.SOURCE_FILE), _.in(EdgeTypes.AST)))(_.until(_.hasLabel(NodeTypes.FILE))) - } - .cast[File] + ) + def file: Iterator[File] = + traversal + .choose(_.label) { + case NodeTypes.NAMESPACE => _.in(EdgeTypes.REF).out(EdgeTypes.SOURCE_FILE) + case NodeTypes.COMMENT => _.in(EdgeTypes.AST).hasLabel(NodeTypes.FILE) + case _ => + _.repeat(_.coalesce(_.out(EdgeTypes.SOURCE_FILE), _.in(EdgeTypes.AST)))(_.until( + _.hasLabel(NodeTypes.FILE) + )) + } + .cast[File] - @Doc( - info = "Location, including filename and line number", - longInfo = """ + @Doc( + info = "Location, including filename and line number", + longInfo = """ |Most nodes of the graph can be associated with a specific |location in code, and `location` provides this location. |The return value is an object providing, e.g., filename, @@ -44,54 +46,53 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten |to the line number alone, without requiring any parsing |on the user's side. |""" - ) - def location(implicit finder: NodeExtensionFinder): Iterator[NewLocation] = - traversal.map(_.location) + ) + def location(implicit finder: NodeExtensionFinder): Iterator[NewLocation] = + traversal.map(_.location) - @Doc( - info = "Display code (with syntax highlighting)", - longInfo = """ + @Doc( + info = "Display code (with syntax highlighting)", + longInfo = """ |For methods, dump the method code. For expressions, |dump the method code along with an arrow pointing |to the expression. Uses ansi-color highlighting. |This only works for source frontends. 
|""" - ) - def dump(implicit finder: NodeExtensionFinder): List[String] = - _dump(highlight = true) + ) + def dump(implicit finder: NodeExtensionFinder): List[String] = + _dump(highlight = true) - @Doc( - info = "Display code (without syntax highlighting)", - longInfo = """ + @Doc( + info = "Display code (without syntax highlighting)", + longInfo = """ |For methods, dump the method code. For expressions, |dump the method code along with an arrow pointing |to the expression. No color highlighting. |""" - ) - def dumpRaw(implicit finder: NodeExtensionFinder): List[String] = - _dump(highlight = false) + ) + def dumpRaw(implicit finder: NodeExtensionFinder): List[String] = + _dump(highlight = false) - private def _dump(highlight: Boolean)(implicit finder: NodeExtensionFinder): List[String] = { - // initialized on first element as we need the graph for retrieving the metaData node. - // TODO: there should be a step to retrieve the metaData node for any node - // so we could avoid instantiating a new Cpg everytime using dump - var cpg: Cpg = null - traversal.map { node => - if (cpg == null) cpg = new Cpg(node.graph) - val language = cpg.metaData.language.headOption - val rootPath = cpg.metaData.root.headOption - CodeDumper.dump(node.location, language, rootPath, highlight) - }.l - } + private def _dump(highlight: Boolean)(implicit finder: NodeExtensionFinder): List[String] = + // initialized on first element as we need the graph for retrieving the metaData node. 
+ // TODO: there should be a step to retrieve the metaData node for any node + // so we could avoid instantiating a new Cpg everytime using dump + var cpg: Cpg = null + traversal.map { node => + if cpg == null then cpg = new Cpg(node.graph) + val language = cpg.metaData.language.headOption + val rootPath = cpg.metaData.root.headOption + CodeDumper.dump(node.location, language, rootPath, highlight) + }.l - /* follow the incoming edges of the given type as long as possible */ - protected def walkIn(edgeType: String): Iterator[Node] = - traversal - .repeat(_.in(edgeType))(_.until(_.in(edgeType).countTrav.filter(_ == 0))) + /* follow the incoming edges of the given type as long as possible */ + protected def walkIn(edgeType: String): Iterator[Node] = + traversal + .repeat(_.in(edgeType))(_.until(_.in(edgeType).countTrav.filter(_ == 0))) - @Doc( - info = "Tag node with `tagName`", - longInfo = """ + @Doc( + info = "Tag node with `tagName`", + longInfo = """ |This method can be used to tag nodes in the graph such that |they can later be looked up easily via `cpg.tag`. Tags are |key value pairs, and they can be created with `newTagNodePair`. @@ -99,30 +100,31 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten |utility method `newTagNode(key)`, which is equivalent to |`newTagNode(key, "")`. 
|""", - example = """.newTagNode("foo")""" - ) - def newTagNode(tagName: String): NewTagNodePairTraversal = newTagNodePair(tagName, "") + example = """.newTagNode("foo")""" + ) + def newTagNode(tagName: String): NewTagNodePairTraversal = newTagNodePair(tagName, "") - @Doc(info = "Tag node with (`tagName`, `tagValue`)", longInfo = "", example = """.newTagNodePair("key","val")""") - def newTagNodePair(tagName: String, tagValue: String): NewTagNodePairTraversal = { - new NewTagNodePairTraversal(traversal.map { node => - NewTagNodePair() - .tag(NewTag().name(tagName).value(tagValue)) - .node(node) - }) - } + @Doc( + info = "Tag node with (`tagName`, `tagValue`)", + longInfo = "", + example = """.newTagNodePair("key","val")""" + ) + def newTagNodePair(tagName: String, tagValue: String): NewTagNodePairTraversal = + new NewTagNodePairTraversal(traversal.map { node => + NewTagNodePair() + .tag(NewTag().name(tagName).value(tagValue)) + .node(node) + }) - @Doc(info = "Tags attached to this node") - def tagList: List[List[Tag]] = - traversal.map { taggedNode => - taggedNode.tag.l - }.l + @Doc(info = "Tags attached to this node") + def tagList: List[List[Tag]] = + traversal.map { taggedNode => + taggedNode.tag.l + }.l - @Doc(info = "Tags attached to this node") - def tag: Iterator[Tag] = { - traversal.flatMap { node => - node.tag - } - } - -} + @Doc(info = "Tags attached to this node") + def tag: Iterator[Tag] = + traversal.flatMap { node => + node.tag + } +end NodeSteps diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala index c7e1f599..3bf04173 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala @@ -1,9 +1,9 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.Cpg -import 
io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{NodeTypes, Properties} -import overflowdb._ +import overflowdb.* import overflowdb.traversal.help import overflowdb.traversal.help.Doc import overflowdb.traversal.{InitialTraversal, TraversalSource} @@ -11,308 +11,308 @@ import overflowdb.traversal.{InitialTraversal, TraversalSource} import scala.jdk.CollectionConverters.IteratorHasAsScala @help.TraversalSource -class NodeTypeStarters(cpg: Cpg) extends TraversalSource(cpg.graph) { - - /** Traverse to all nodes. - */ - @Doc(info = "All nodes of the graph") - override def all: Traversal[StoredNode] = - cpg.graph.nodes.asScala.cast[StoredNode] - - /** Traverse to all annotations - */ - def annotation: Traversal[Annotation] = - InitialTraversal.from[Annotation](cpg.graph, NodeTypes.ANNOTATION) - - /** Traverse to all arguments passed to methods - */ - @Doc(info = "All arguments (actual parameters)") - def argument: Traversal[Expression] = - call.argument - - /** Shorthand for `cpg.argument.code(code)` - */ - def argument(code: String): Traversal[Expression] = - argument.code(code) - - @Doc(info = "All breaks (`ControlStructure` nodes)") - def break: Traversal[ControlStructure] = - controlStructure.isBreak - - /** Traverse to all call sites - */ - @Doc(info = "All call sites") - def call: Traversal[Call] = - InitialTraversal.from[Call](cpg.graph, NodeTypes.CALL) - - /** Shorthand for `cpg.call.name(name)` - */ - def call(name: String): Traversal[Call] = - call.name(name) - - /** Traverse to all comments in source-based CPGs. 
- */ - @Doc(info = "All comments in source-based CPGs") - def comment: Traversal[Comment] = - InitialTraversal.from[Comment](cpg.graph, NodeTypes.COMMENT) - - /** Shorthand for `cpg.comment.code(code)` - */ - def comment(code: String): Traversal[Comment] = - comment.has(Properties.CODE -> code) - - /** Traverse to all config files - */ - @Doc(info = "All config files") - def configFile: Traversal[ConfigFile] = - InitialTraversal.from[ConfigFile](cpg.graph, NodeTypes.CONFIG_FILE) - - /** Shorthand for `cpg.configFile.name(name)` - */ - def configFile(name: String): Traversal[ConfigFile] = - configFile.name(name) - - /** Traverse to all dependencies - */ - @Doc(info = "All dependencies") - def dependency: Traversal[Dependency] = - InitialTraversal.from[Dependency](cpg.graph, NodeTypes.DEPENDENCY) - - /** Shorthand for `cpg.dependency.name(name)` - */ - def dependency(name: String): Traversal[Dependency] = - dependency.name(name) - - @Doc(info = "All control structures (source-based frontends)") - def controlStructure: Traversal[ControlStructure] = - InitialTraversal.from[ControlStructure](cpg.graph, NodeTypes.CONTROL_STRUCTURE) - - @Doc(info = "All continues (`ControlStructure` nodes)") - def continue: Traversal[ControlStructure] = - controlStructure.isContinue - - @Doc(info = "All do blocks (`ControlStructure` nodes)") - def doBlock: Traversal[ControlStructure] = - controlStructure.isDo - - @Doc(info = "All else blocks (`ControlStructure` nodes)") - def elseBlock: Traversal[ControlStructure] = - controlStructure.isElse - - @Doc(info = "All throws (`ControlStructure` nodes)") - def throws: Traversal[ControlStructure] = - controlStructure.isThrow - - /** Traverse to all source files - */ - @Doc(info = "All source files") - def file: Traversal[File] = - InitialTraversal.from[File](cpg.graph, NodeTypes.FILE) - - /** Shorthand for `cpg.file.name(name)` - */ - def file(name: String): Traversal[File] = - file.name(name) - - @Doc(info = "All for blocks (`ControlStructure` 
nodes)") - def forBlock: Traversal[ControlStructure] = - controlStructure.isFor - - @Doc(info = "All gotos (`ControlStructure` nodes)") - def goto: Traversal[ControlStructure] = - controlStructure.isGoto - - /** Traverse to all identifiers, e.g., occurrences of local variables or class members in method bodies. - */ - @Doc(info = "All identifier usages") - def identifier: Traversal[Identifier] = - InitialTraversal.from[Identifier](cpg.graph, NodeTypes.IDENTIFIER) - - /** Shorthand for `cpg.identifier.name(name)` - */ - def identifier(name: String): Traversal[Identifier] = - identifier.name(name) - - @Doc(info = "All if blocks (`ControlStructure` nodes)") - def ifBlock: Traversal[ControlStructure] = - controlStructure.isIf - - /** Traverse to all jump targets - */ - @Doc(info = "All jump targets, i.e., labels") - def jumpTarget: Traversal[JumpTarget] = - InitialTraversal.from[JumpTarget](cpg.graph, NodeTypes.JUMP_TARGET) - - /** Traverse to all local variable declarations - */ - @Doc(info = "All local variables") - def local: Traversal[Local] = - InitialTraversal.from[Local](cpg.graph, NodeTypes.LOCAL) - - /** Shorthand for `cpg.local.name` - */ - def local(name: String): Traversal[Local] = - local.name(name) - - /** Traverse to all literals (constant strings and numbers provided directly in the code). 
- */ - @Doc(info = "All literals, e.g., numbers or strings") - def literal: Traversal[Literal] = - InitialTraversal.from[Literal](cpg.graph, NodeTypes.LITERAL) - - /** Shorthand for `cpg.literal.code(code)` - */ - def literal(code: String): Traversal[Literal] = - literal.code(code) - - /** Traverse to all methods - */ - @Doc(info = "All methods") - def method: Traversal[Method] = - InitialTraversal.from[Method](cpg.graph, NodeTypes.METHOD) - - /** Shorthand for `cpg.method.name(name)` - */ - @Doc(info = "All methods with a name that matches the given pattern") - def method(namePattern: String): Traversal[Method] = - method.name(namePattern) - - /** Traverse to all formal return parameters - */ - @Doc(info = "All formal return parameters") - def methodReturn: Traversal[MethodReturn] = - InitialTraversal.from[MethodReturn](cpg.graph, NodeTypes.METHOD_RETURN) - - /** Traverse to all class members - */ - @Doc(info = "All members of complex types (e.g., classes/structures)") - def member: Traversal[Member] = - InitialTraversal.from[Member](cpg.graph, NodeTypes.MEMBER) - - /** Shorthand for `cpg.member.name(name)` - */ - def member(name: String): Traversal[Member] = - member.name(name) - - /** Traverse to all meta data entries - */ - @Doc(info = "Meta data blocks for graph") - def metaData: Traversal[MetaData] = - InitialTraversal.from[MetaData](cpg.graph, NodeTypes.META_DATA) - - /** Traverse to all method references - */ - @Doc(info = "All method references") - def methodRef: Traversal[MethodRef] = - InitialTraversal.from[MethodRef](cpg.graph, NodeTypes.METHOD_REF) - - /** Shorthand for `cpg.methodRef.filter(_.referencedMethod.name(name))` - */ - def methodRef(name: String): Traversal[MethodRef] = - methodRef.where(_.referencedMethod.name(name)) - - /** Traverse to all namespaces, e.g., packages in Java. 
- */ - @Doc(info = "All namespaces") - def namespace: Traversal[Namespace] = - InitialTraversal.from[Namespace](cpg.graph, NodeTypes.NAMESPACE) - - /** Shorthand for `cpg.namespace.name(name)` - */ - def namespace(name: String): Traversal[Namespace] = - namespace.name(name) - - /** Traverse to all namespace blocks, e.g., packages in Java. - */ - def namespaceBlock: Traversal[NamespaceBlock] = - InitialTraversal.from[NamespaceBlock](cpg.graph, NodeTypes.NAMESPACE_BLOCK) - - /** Shorthand for `cpg.namespaceBlock.name(name)` - */ - def namespaceBlock(name: String): Traversal[NamespaceBlock] = - namespaceBlock.name(name) - - /** Traverse to all input parameters - */ - @Doc(info = "All parameters") - def parameter: Traversal[MethodParameterIn] = - InitialTraversal.from[MethodParameterIn](cpg.graph, NodeTypes.METHOD_PARAMETER_IN) - - /** Shorthand for `cpg.parameter.name(name)` - */ - def parameter(name: String): Traversal[MethodParameterIn] = - parameter.name(name) - - /** Traverse to all return expressions - */ - @Doc(info = "All actual return parameters") - def ret: Traversal[Return] = - InitialTraversal.from[Return](cpg.graph, NodeTypes.RETURN) - - /** Shorthand for `returns.code(code)` - */ - def ret(code: String): Traversal[Return] = - ret.code(code) - - @Doc(info = "All imports") - def imports: Traversal[Import] = - InitialTraversal.from[Import](cpg.graph, NodeTypes.IMPORT) - - @Doc(info = "All switch blocks (`ControlStructure` nodes)") - def switchBlock: Traversal[ControlStructure] = - controlStructure.isSwitch - - @Doc(info = "All try blocks (`ControlStructure` nodes)") - def tryBlock: Traversal[ControlStructure] = - controlStructure.isTry - - /** Traverse to all types, e.g., Set - */ - @Doc(info = "All used types") - def typ: Traversal[Type] = - InitialTraversal.from[Type](cpg.graph, NodeTypes.TYPE) - - /** Shorthand for `cpg.typ.name(name)` - */ - @Doc(info = "All used types with given name") - def typ(name: String): Traversal[Type] = - typ.name(name) - - /** 
Traverse to all declarations, e.g., Set - */ - @Doc(info = "All declarations of types") - def typeDecl: Traversal[TypeDecl] = - InitialTraversal.from[TypeDecl](cpg.graph, NodeTypes.TYPE_DECL) - - /** Shorthand for cpg.typeDecl.name(name) - */ - def typeDecl(name: String): Traversal[TypeDecl] = - typeDecl.name(name) - - /** Traverse to all tags - */ - @Doc(info = "All tags") - def tag: Traversal[Tag] = - InitialTraversal.from[Tag](cpg.graph, NodeTypes.TAG) - - @Doc(info = "All tags with given name") - def tag(name: String): Traversal[Tag] = - tag.name(name) - - /** Traverse to all template DOM nodes - */ - @Doc(info = "All template DOM nodes") - def templateDom: Traversal[TemplateDom] = - InitialTraversal.from[TemplateDom](cpg.graph, NodeTypes.TEMPLATE_DOM) - - /** Traverse to all type references - */ - @Doc(info = "All type references") - def typeRef: Traversal[TypeRef] = - InitialTraversal.from[TypeRef](cpg.graph, NodeTypes.TYPE_REF) - - @Doc(info = "All while blocks (`ControlStructure` nodes)") - def whileBlock: Traversal[ControlStructure] = - controlStructure.isWhile - -} +class NodeTypeStarters(cpg: Cpg) extends TraversalSource(cpg.graph): + + /** Traverse to all nodes. 
+ */ + @Doc(info = "All nodes of the graph") + override def all: Traversal[StoredNode] = + cpg.graph.nodes.asScala.cast[StoredNode] + + /** Traverse to all annotations + */ + def annotation: Traversal[Annotation] = + InitialTraversal.from[Annotation](cpg.graph, NodeTypes.ANNOTATION) + + /** Traverse to all arguments passed to methods + */ + @Doc(info = "All arguments (actual parameters)") + def argument: Traversal[Expression] = + call.argument + + /** Shorthand for `cpg.argument.code(code)` + */ + def argument(code: String): Traversal[Expression] = + argument.code(code) + + @Doc(info = "All breaks (`ControlStructure` nodes)") + def break: Traversal[ControlStructure] = + controlStructure.isBreak + + /** Traverse to all call sites + */ + @Doc(info = "All call sites") + def call: Traversal[Call] = + InitialTraversal.from[Call](cpg.graph, NodeTypes.CALL) + + /** Shorthand for `cpg.call.name(name)` + */ + def call(name: String): Traversal[Call] = + call.name(name) + + /** Traverse to all comments in source-based CPGs. 
+ */ + @Doc(info = "All comments in source-based CPGs") + def comment: Traversal[Comment] = + InitialTraversal.from[Comment](cpg.graph, NodeTypes.COMMENT) + + /** Shorthand for `cpg.comment.code(code)` + */ + def comment(code: String): Traversal[Comment] = + comment.has(Properties.CODE -> code) + + /** Traverse to all config files + */ + @Doc(info = "All config files") + def configFile: Traversal[ConfigFile] = + InitialTraversal.from[ConfigFile](cpg.graph, NodeTypes.CONFIG_FILE) + + /** Shorthand for `cpg.configFile.name(name)` + */ + def configFile(name: String): Traversal[ConfigFile] = + configFile.name(name) + + /** Traverse to all dependencies + */ + @Doc(info = "All dependencies") + def dependency: Traversal[Dependency] = + InitialTraversal.from[Dependency](cpg.graph, NodeTypes.DEPENDENCY) + + /** Shorthand for `cpg.dependency.name(name)` + */ + def dependency(name: String): Traversal[Dependency] = + dependency.name(name) + + @Doc(info = "All control structures (source-based frontends)") + def controlStructure: Traversal[ControlStructure] = + InitialTraversal.from[ControlStructure](cpg.graph, NodeTypes.CONTROL_STRUCTURE) + + @Doc(info = "All continues (`ControlStructure` nodes)") + def continue: Traversal[ControlStructure] = + controlStructure.isContinue + + @Doc(info = "All do blocks (`ControlStructure` nodes)") + def doBlock: Traversal[ControlStructure] = + controlStructure.isDo + + @Doc(info = "All else blocks (`ControlStructure` nodes)") + def elseBlock: Traversal[ControlStructure] = + controlStructure.isElse + + @Doc(info = "All throws (`ControlStructure` nodes)") + def throws: Traversal[ControlStructure] = + controlStructure.isThrow + + /** Traverse to all source files + */ + @Doc(info = "All source files") + def file: Traversal[File] = + InitialTraversal.from[File](cpg.graph, NodeTypes.FILE) + + /** Shorthand for `cpg.file.name(name)` + */ + def file(name: String): Traversal[File] = + file.name(name) + + @Doc(info = "All for blocks (`ControlStructure` 
nodes)") + def forBlock: Traversal[ControlStructure] = + controlStructure.isFor + + @Doc(info = "All gotos (`ControlStructure` nodes)") + def goto: Traversal[ControlStructure] = + controlStructure.isGoto + + /** Traverse to all identifiers, e.g., occurrences of local variables or class members in method + * bodies. + */ + @Doc(info = "All identifier usages") + def identifier: Traversal[Identifier] = + InitialTraversal.from[Identifier](cpg.graph, NodeTypes.IDENTIFIER) + + /** Shorthand for `cpg.identifier.name(name)` + */ + def identifier(name: String): Traversal[Identifier] = + identifier.name(name) + + @Doc(info = "All if blocks (`ControlStructure` nodes)") + def ifBlock: Traversal[ControlStructure] = + controlStructure.isIf + + /** Traverse to all jump targets + */ + @Doc(info = "All jump targets, i.e., labels") + def jumpTarget: Traversal[JumpTarget] = + InitialTraversal.from[JumpTarget](cpg.graph, NodeTypes.JUMP_TARGET) + + /** Traverse to all local variable declarations + */ + @Doc(info = "All local variables") + def local: Traversal[Local] = + InitialTraversal.from[Local](cpg.graph, NodeTypes.LOCAL) + + /** Shorthand for `cpg.local.name` + */ + def local(name: String): Traversal[Local] = + local.name(name) + + /** Traverse to all literals (constant strings and numbers provided directly in the code). 
+ */ + @Doc(info = "All literals, e.g., numbers or strings") + def literal: Traversal[Literal] = + InitialTraversal.from[Literal](cpg.graph, NodeTypes.LITERAL) + + /** Shorthand for `cpg.literal.code(code)` + */ + def literal(code: String): Traversal[Literal] = + literal.code(code) + + /** Traverse to all methods + */ + @Doc(info = "All methods") + def method: Traversal[Method] = + InitialTraversal.from[Method](cpg.graph, NodeTypes.METHOD) + + /** Shorthand for `cpg.method.name(name)` + */ + @Doc(info = "All methods with a name that matches the given pattern") + def method(namePattern: String): Traversal[Method] = + method.name(namePattern) + + /** Traverse to all formal return parameters + */ + @Doc(info = "All formal return parameters") + def methodReturn: Traversal[MethodReturn] = + InitialTraversal.from[MethodReturn](cpg.graph, NodeTypes.METHOD_RETURN) + + /** Traverse to all class members + */ + @Doc(info = "All members of complex types (e.g., classes/structures)") + def member: Traversal[Member] = + InitialTraversal.from[Member](cpg.graph, NodeTypes.MEMBER) + + /** Shorthand for `cpg.member.name(name)` + */ + def member(name: String): Traversal[Member] = + member.name(name) + + /** Traverse to all meta data entries + */ + @Doc(info = "Meta data blocks for graph") + def metaData: Traversal[MetaData] = + InitialTraversal.from[MetaData](cpg.graph, NodeTypes.META_DATA) + + /** Traverse to all method references + */ + @Doc(info = "All method references") + def methodRef: Traversal[MethodRef] = + InitialTraversal.from[MethodRef](cpg.graph, NodeTypes.METHOD_REF) + + /** Shorthand for `cpg.methodRef.filter(_.referencedMethod.name(name))` + */ + def methodRef(name: String): Traversal[MethodRef] = + methodRef.where(_.referencedMethod.name(name)) + + /** Traverse to all namespaces, e.g., packages in Java. 
+ */ + @Doc(info = "All namespaces") + def namespace: Traversal[Namespace] = + InitialTraversal.from[Namespace](cpg.graph, NodeTypes.NAMESPACE) + + /** Shorthand for `cpg.namespace.name(name)` + */ + def namespace(name: String): Traversal[Namespace] = + namespace.name(name) + + /** Traverse to all namespace blocks, e.g., packages in Java. + */ + def namespaceBlock: Traversal[NamespaceBlock] = + InitialTraversal.from[NamespaceBlock](cpg.graph, NodeTypes.NAMESPACE_BLOCK) + + /** Shorthand for `cpg.namespaceBlock.name(name)` + */ + def namespaceBlock(name: String): Traversal[NamespaceBlock] = + namespaceBlock.name(name) + + /** Traverse to all input parameters + */ + @Doc(info = "All parameters") + def parameter: Traversal[MethodParameterIn] = + InitialTraversal.from[MethodParameterIn](cpg.graph, NodeTypes.METHOD_PARAMETER_IN) + + /** Shorthand for `cpg.parameter.name(name)` + */ + def parameter(name: String): Traversal[MethodParameterIn] = + parameter.name(name) + + /** Traverse to all return expressions + */ + @Doc(info = "All actual return parameters") + def ret: Traversal[Return] = + InitialTraversal.from[Return](cpg.graph, NodeTypes.RETURN) + + /** Shorthand for `returns.code(code)` + */ + def ret(code: String): Traversal[Return] = + ret.code(code) + + @Doc(info = "All imports") + def imports: Traversal[Import] = + InitialTraversal.from[Import](cpg.graph, NodeTypes.IMPORT) + + @Doc(info = "All switch blocks (`ControlStructure` nodes)") + def switchBlock: Traversal[ControlStructure] = + controlStructure.isSwitch + + @Doc(info = "All try blocks (`ControlStructure` nodes)") + def tryBlock: Traversal[ControlStructure] = + controlStructure.isTry + + /** Traverse to all types, e.g., Set + */ + @Doc(info = "All used types") + def typ: Traversal[Type] = + InitialTraversal.from[Type](cpg.graph, NodeTypes.TYPE) + + /** Shorthand for `cpg.typ.name(name)` + */ + @Doc(info = "All used types with given name") + def typ(name: String): Traversal[Type] = + typ.name(name) + + /** 
Traverse to all declarations, e.g., Set + */ + @Doc(info = "All declarations of types") + def typeDecl: Traversal[TypeDecl] = + InitialTraversal.from[TypeDecl](cpg.graph, NodeTypes.TYPE_DECL) + + /** Shorthand for cpg.typeDecl.name(name) + */ + def typeDecl(name: String): Traversal[TypeDecl] = + typeDecl.name(name) + + /** Traverse to all tags + */ + @Doc(info = "All tags") + def tag: Traversal[Tag] = + InitialTraversal.from[Tag](cpg.graph, NodeTypes.TAG) + + @Doc(info = "All tags with given name") + def tag(name: String): Traversal[Tag] = + tag.name(name) + + /** Traverse to all template DOM nodes + */ + @Doc(info = "All template DOM nodes") + def templateDom: Traversal[TemplateDom] = + InitialTraversal.from[TemplateDom](cpg.graph, NodeTypes.TEMPLATE_DOM) + + /** Traverse to all type references + */ + @Doc(info = "All type references") + def typeRef: Traversal[TypeRef] = + InitialTraversal.from[TypeRef](cpg.graph, NodeTypes.TYPE_REF) + + @Doc(info = "All while blocks (`ControlStructure` nodes)") + def whileBlock: Traversal[ControlStructure] = + controlStructure.isWhile +end NodeTypeStarters diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala index 0194d4dc..30a2fec7 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala @@ -3,40 +3,35 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.NewNode import overflowdb.Node -import scala.jdk.CollectionConverters._ +import scala.jdk.CollectionConverters.* /** Typeclass for (pretty) printing an object */ -trait Show[A] { - def apply(a: A): String -} - -object Show { - def default[A]: Show[A] = Default.asInstanceOf[Show[A]] - - private val Default = new Show[Any] { - override def apply(a: Any): String = a match { - case node: NewNode => - val label = node.label - val 
properties = propsToString(node.properties.toList) - s"($label): $properties" - - case node: Node => - val label = node.label - val id = node.id().toString - val properties = propsToString(node.propertiesMap.asScala.toList) - s"($label,$id): $properties" - - case other => other.toString - } - - private def propsToString(keyValues: List[(String, Any)]): String = { - keyValues - .filter(_._2.toString.nonEmpty) - .sortBy(_._1) - .map { case (key, value) => s"$key: $value" } - .mkString(", ") - } - } - -} +trait Show[A]: + def apply(a: A): String + +object Show: + def default[A]: Show[A] = Default.asInstanceOf[Show[A]] + + private val Default = new Show[Any]: + override def apply(a: Any): String = a match + case node: NewNode => + val label = node.label + val properties = propsToString(node.properties.toList) + s"($label): $properties" + + case node: Node => + val label = node.label + val id = node.id().toString + val properties = propsToString(node.propertiesMap.asScala.toList) + s"($label,$id): $properties" + + case other => other.toString + + private def propsToString(keyValues: List[(String, Any)]): String = + keyValues + .filter(_._2.toString.nonEmpty) + .sortBy(_._1) + .map { case (key, value) => s"$key: $value" } + .mkString(", ") +end Show diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala index 5d2d2acf..fa7f5c17 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala @@ -18,103 +18,98 @@ import me.shadaj.scalapy.interpreter.CPythonInterpreter import java.nio.file.Files -/** Base class for our DSL These are the base steps available in all steps of the query language. There are no - * constraints on the element types, unlike e.g. [[NodeSteps]] +/** Base class for our DSL These are the base steps available in all steps of the query language. 
+ * There are no constraints on the element types, unlike e.g. [[NodeSteps]] */ -class Steps[A](val traversal: Iterator[A]) extends AnyVal { - - /** Execute the traversal and convert it to a mutable buffer - */ - def toBuffer(): mutable.Buffer[A] = traversal.to(mutable.Buffer) - - /** Shorthand for `toBuffer` - */ - def b: mutable.Buffer[A] = toBuffer() - - /** Alias for `toList` - * @deprecated - */ - def exec(): List[A] = traversal.toList - - /** Execute the travel and convert it to a Java stream. - */ - def toStream(): LazyList[A] = traversal.to(LazyList) - - /** Alias for `toStream` - */ - def s: LazyList[A] = toStream() - - /** Execute the traversal and convert it into a Java list (as opposed to the Scala list obtained via `toList`) - */ - def jl: JList[A] = b.asJava - - /** Execute this traversal and pretty print the results. This may mean that not all properties of the node are - * displayed or that some properties have undergone transformations to improve display. A good example is flow - * pretty-printing. This is the only three of the methods which we may modify on a per-node-type basis, typically via - * implicits of type Show[NodeType]. - */ - @Doc(info = "execute this traversal and pretty print the results") - def p(implicit show: Show[A] = Show.default): List[String] = - traversal.toList.map(show.apply) - - @Doc(info = "execute this traversal and print tabular result") - def t(implicit show: Show[A] = Show.default): Unit = { - traversal.toList.map(show.apply) - } - - @Doc(info = "execute this traversal and show the pretty-printed results in `less`") - // uses scala-repl-pp's `#|^` operator which let's `less` inherit stdin and stdout - def browse: Unit = { - given Colors = Colors.Default - traversal #|^ "less" - } - - /** Execute traversal and convert the result to json. `toJson` (export) contains the exact same information as - * `toList`, only in json format. 
Typically, the user will call this method upon inspection of the results of - * `toList` in order to export the data for processing with other tools. - */ - @Doc(info = "execute traversal and convert the result to json") - def toJson: String = toJson(pretty = false) - - /** Execute traversal and convert the result to pretty json. */ - @Doc(info = "execute traversal and convert the result to pretty json") - def toJsonPretty: String = toJson(pretty = true) - - protected def toJson(pretty: Boolean): String = { - implicit val formats: Formats = org.json4s.DefaultFormats + Steps.nodeSerializer - - val results = traversal.toList - if (pretty) writePretty(results) - else write(results) - } - - private def pyJson = py.module("json") - @Doc(info = "execute traversal and convert the result to python object") - def toPy: me.shadaj.scalapy.py.Dynamic = pyJson.loads(toJson(false)) - - def pyg = { - val tmpDir = Files.createTempDirectory("pyg-gml-export").toFile.getAbsolutePath - traversal match { - case methods: Iterator[Method] => { - val exportResult = methods.gml(tmpDir) - exportResult.files.map(Torch.to_pyg) - } - } - } -} - -object Steps { - private lazy val nodeSerializer = new CustomSerializer[AbstractNode](implicit format => - ( - { case _ => ??? }, - { case node: AbstractNode with Product => - val elementMap = (0 until node.productArity).map { i => - val label = node.productElementName(i) - val element = node.productElement(i) - label -> element - }.toMap + ("_label" -> node.label) - Extraction.decompose(elementMap) - } +class Steps[A](val traversal: Iterator[A]) extends AnyVal: + + /** Execute the traversal and convert it to a mutable buffer + */ + def toBuffer(): mutable.Buffer[A] = traversal.to(mutable.Buffer) + + /** Shorthand for `toBuffer` + */ + def b: mutable.Buffer[A] = toBuffer() + + /** Alias for `toList` + * @deprecated + */ + def exec(): List[A] = traversal.toList + + /** Execute the travel and convert it to a Java stream. 
+ */ + def toStream(): LazyList[A] = traversal.to(LazyList) + + /** Alias for `toStream` + */ + def s: LazyList[A] = toStream() + + /** Execute the traversal and convert it into a Java list (as opposed to the Scala list obtained + * via `toList`) + */ + def jl: JList[A] = b.asJava + + /** Execute this traversal and pretty print the results. This may mean that not all properties + * of the node are displayed or that some properties have undergone transformations to improve + * display. A good example is flow pretty-printing. This is the only three of the methods which + * we may modify on a per-node-type basis, typically via implicits of type Show[NodeType]. + */ + @Doc(info = "execute this traversal and pretty print the results") + def p(implicit show: Show[A] = Show.default): List[String] = + traversal.toList.map(show.apply) + + @Doc(info = "execute this traversal and print tabular result") + def t(implicit show: Show[A] = Show.default): Unit = + traversal.toList.map(show.apply) + + @Doc(info = "execute this traversal and show the pretty-printed results in `less`") + // uses scala-repl-pp's `#|^` operator which let's `less` inherit stdin and stdout + def browse: Unit = + given Colors = Colors.Default + traversal #|^ "less" + + /** Execute traversal and convert the result to json. `toJson` (export) contains the exact same + * information as `toList`, only in json format. Typically, the user will call this method upon + * inspection of the results of `toList` in order to export the data for processing with other + * tools. + */ + @Doc(info = "execute traversal and convert the result to json") + def toJson: String = toJson(pretty = false) + + /** Execute traversal and convert the result to pretty json. 
*/ + @Doc(info = "execute traversal and convert the result to pretty json") + def toJsonPretty: String = toJson(pretty = true) + + protected def toJson(pretty: Boolean): String = + implicit val formats: Formats = org.json4s.DefaultFormats + Steps.nodeSerializer + + val results = traversal.toList + if pretty then writePretty(results) + else write(results) + + private def pyJson = py.module("json") + @Doc(info = "execute traversal and convert the result to python object") + def toPy: me.shadaj.scalapy.py.Dynamic = pyJson.loads(toJson(false)) + + def pyg = + val tmpDir = Files.createTempDirectory("pyg-gml-export").toFile.getAbsolutePath + traversal match + case methods: Iterator[Method] => + val exportResult = methods.gml(tmpDir) + exportResult.files.map(Torch.to_pyg) +end Steps + +object Steps: + private lazy val nodeSerializer = new CustomSerializer[AbstractNode](implicit format => + ( + { case _ => ??? }, + { case node: AbstractNode with Product => + val elementMap = (0 until node.productArity).map { i => + val label = node.productElementName(i) + val element = node.productElement(i) + label -> element + }.toMap + ("_label" -> node.label) + Extraction.decompose(elementMap) + } + ) ) - ) -} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala index dc383b59..b4e088db 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala @@ -4,19 +4,18 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import scala.reflect.ClassTag -class TagTraversal(val traversal: Iterator[Tag]) extends AnyVal { +class TagTraversal(val traversal: Iterator[Tag]) extends AnyVal: - def member: Iterator[Member] = tagged[Member] - def method: Iterator[Method] = tagged[Method] - def methodReturn: Iterator[MethodReturn] = tagged[MethodReturn] - def 
parameter: Iterator[MethodParameterIn] = tagged[MethodParameterIn] - def parameterOut: Iterator[MethodParameterOut] = tagged[MethodParameterOut] - def call: Iterator[Call] = tagged[Call] - def identifier: Iterator[Identifier] = tagged[Identifier] - def literal: Iterator[Literal] = tagged[Literal] - def local: Iterator[Local] = tagged[Local] - def file: Iterator[File] = tagged[File] + def member: Iterator[Member] = tagged[Member] + def method: Iterator[Method] = tagged[Method] + def methodReturn: Iterator[MethodReturn] = tagged[MethodReturn] + def parameter: Iterator[MethodParameterIn] = tagged[MethodParameterIn] + def parameterOut: Iterator[MethodParameterOut] = tagged[MethodParameterOut] + def call: Iterator[Call] = tagged[Call] + def identifier: Iterator[Identifier] = tagged[Identifier] + def literal: Iterator[Literal] = tagged[Literal] + def local: Iterator[Local] = tagged[Local] + def file: Iterator[File] = tagged[File] - private def tagged[A <: StoredNode: ClassTag]: Iterator[A] = - traversal._taggedByIn.collectAll[A].sortBy(_.id).iterator -} + private def tagged[A <: StoredNode: ClassTag]: Iterator[A] = + traversal._taggedByIn.collectAll[A].sortBy(_.id).iterator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala index d8edbd4a..fe8f3722 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala @@ -3,102 +3,99 @@ package io.shiftleft.semanticcpg.language.android import io.shiftleft.semanticcpg.utils.SecureXmlParsing import io.shiftleft.codepropertygraph.generated.nodes -class ConfigFileTraversal(val traversal: Iterator[nodes.ConfigFile]) extends AnyVal { - def usesCleartextTraffic = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - 
.map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap { applicationNode => - val activityName = applicationNode.attribute(Constants.androidUri, "usesCleartextTraffic") - activityName.map(_.toString == "true") - } +class ConfigFileTraversal(val traversal: Iterator[nodes.ConfigFile]) extends AnyVal: + def usesCleartextTraffic = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "application") + .flatMap { applicationNode => + val activityName = + applicationNode.attribute(Constants.androidUri, "usesCleartextTraffic") + activityName.map(_.toString == "true") + } - def hasReadExternalStoragePermission = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "uses-permission") - .flatMap { applicationNode => - val activityName = applicationNode.attribute(Constants.androidUri, "name") - activityName match { - case Some(n) if n.toString == "android.permission.READ_EXTERNAL_STORAGE" => Some(true) - case _ => None - } - } + def hasReadExternalStoragePermission = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "uses-permission") + .flatMap { applicationNode => + val activityName = applicationNode.attribute(Constants.androidUri, "name") + activityName match + case Some(n) if n.toString == "android.permission.READ_EXTERNAL_STORAGE" => + Some(true) + case _ => None + } - def exportedAndroidActivityNames = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label 
== "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap(_.child) - .filter(_.label == "activity") - .flatMap { activityNode => - /* + def exportedAndroidActivityNames = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "application") + .flatMap(_.child) + .filter(_.label == "activity") + .flatMap { activityNode => + /* from: https://developer.android.com/guide/components/intents-filters Note: To receive implicit intents, you must include the CATEGORY_DEFAULT category in the intent filter. The methods startActivity() and startActivityForResult() treat all intents as if they declared the CATEGORY_DEFAULT category. If you do not declare this category in your intent filter, no implicit intents will resolve to your activity. - */ - val hasIntentFilterWithDefaultCategory = - activityNode + */ + val hasIntentFilterWithDefaultCategory = + activityNode + .flatMap(_.child) + .filter(_.label == "intent-filter") + .flatMap(_.child) + .filter(_.label == "category") + .exists { node => + val categoryName = node.attribute(Constants.androidUri, "name") + categoryName match + case Some(n) => n.toString == "android.intent.category.DEFAULT" + case None => false + } + if hasIntentFilterWithDefaultCategory then + val activityName = activityNode.attribute(Constants.androidUri, "name") + activityName match + case Some(n) => Some(n.toString) + case None => None + else None + } + + def exportedBroadcastReceiverNames = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") .flatMap(_.child) - .filter(_.label == "intent-filter") + .filter(_.label == "application") .flatMap(_.child) - .filter(_.label == "category") - .exists { node => - val categoryName = node.attribute(Constants.androidUri, "name") - categoryName match 
{ - case Some(n) => n.toString == "android.intent.category.DEFAULT" - case None => false - } + .filter(_.label == "receiver") + .flatMap { receiverNode => + val hasIntentFilter = + receiverNode.flatMap(_.child).filter(_.label == "intent-filter").nonEmpty + if hasIntentFilter then + val isExported = receiverNode.attribute(Constants.androidUri, "exported") + isExported match + case Some(n) if n.toString == "true" => Some(receiverNode) + case _ => None + else None } - if (hasIntentFilterWithDefaultCategory) { - val activityName = activityNode.attribute(Constants.androidUri, "name") - activityName match { - case Some(n) => Some(n.toString) - case None => None - } - } else None - } - - def exportedBroadcastReceiverNames = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap(_.child) - .filter(_.label == "receiver") - .flatMap { receiverNode => - val hasIntentFilter = receiverNode.flatMap(_.child).filter(_.label == "intent-filter").nonEmpty - if (hasIntentFilter) { - val isExported = receiverNode.attribute(Constants.androidUri, "exported") - isExported match { - case Some(n) if n.toString == "true" => Some(receiverNode) - case _ => None - } - } else None - } - .flatMap { node => - val name = node.attribute(Constants.androidUri, "name") - name match { - case Some(n) => Some(n.toString().stripPrefix(".")) - case _ => None - } - } - -} + .flatMap { node => + val name = node.attribute(Constants.androidUri, "name") + name match + case Some(n) => Some(n.toString().stripPrefix(".")) + case _ => None + } +end ConfigFileTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala index 56667643..cef6d9bc 100644 --- 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala @@ -1,6 +1,5 @@ package io.shiftleft.semanticcpg.language.android -object Constants { - val androidUri = "http://schemas.android.com/apk/res/android" - val androidManifestXml = "AndroidManifest.xml" -} +object Constants: + val androidUri = "http://schemas.android.com/apk/res/android" + val androidManifestXml = "AndroidManifest.xml" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala index c4341613..7cbf0fe4 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala @@ -3,24 +3,24 @@ package io.shiftleft.semanticcpg.language.android import io.shiftleft.codepropertygraph.generated.nodes.Local import io.shiftleft.semanticcpg.language.* -class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal { - def callsEnableJS = - traversal - .where( - _.referencingIdentifiers.inCall - .nameExact("getSettings") - .where( - _.inCall - .nameExact("setJavaScriptEnabled") - .argument - .isLiteral - .codeExact("true") - ) - ) +class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal: + def callsEnableJS = + traversal + .where( + _.referencingIdentifiers.inCall + .nameExact("getSettings") + .where( + _.inCall + .nameExact("setJavaScriptEnabled") + .argument + .isLiteral + .codeExact("true") + ) + ) - def loadUrlCalls = - traversal.referencingIdentifiers.inCall.nameExact("loadUrl") + def loadUrlCalls = + traversal.referencingIdentifiers.inCall.nameExact("loadUrl") - def addJavascriptInterfaceCalls = - traversal.referencingIdentifiers.inCall.nameExact("addJavascriptInterface") -} + def addJavascriptInterfaceCalls = 
+ traversal.referencingIdentifiers.inCall.nameExact("addJavascriptInterface") +end LocalTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala index 767a22a9..b4c05be0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala @@ -3,7 +3,6 @@ package io.shiftleft.semanticcpg.language.android import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.semanticcpg.language.* -class MethodTraversal(val traversal: Iterator[nodes.Method]) extends AnyVal { - def exposedToJS = - traversal.where(_.annotation.fullNameExact("android.webkit.JavascriptInterface")) -} +class MethodTraversal(val traversal: Iterator[nodes.Method]) extends AnyVal: + def exposedToJS = + traversal.where(_.annotation.fullNameExact("android.webkit.JavascriptInterface")) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala index b2e650f8..cdb97d8b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala @@ -4,35 +4,37 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -class NodeTypeStarters(cpg: Cpg) { - def webView: Iterator[Local] = - cpg.local.typeFullNameExact("android.webkit.WebView") - - def appManifest: Iterator[ConfigFile] = - cpg.configFile.filter(_.name.endsWith(Constants.androidManifestXml)) - - def getExternalStorageDir: Iterator[Call] = - cpg.call - .nameExact("getExternalStorageDirectory") - 
.where(_.argument(0).isIdentifier.typeFullNameExact("android.os.Environment")) - - def dexClassLoader: Iterator[Local] = - cpg.local.typeFullNameExact("dalvik.system.DexClassLoader") - - def broadcastReceivers: Iterator[TypeDecl] = - cpg.method - .nameExact("onReceive") - .where(_.parameter.index(1).typeFullNameExact("android.content.Context")) - .where(_.parameter.index(2).typeFullNameExact("android.content.Intent")) - .typeDecl - - def registerReceiver: Iterator[Call] = - cpg.call - .nameExact("registerReceiver") - .where(_.argument(2).isIdentifier.typeFullNameExact("android.content.IntentFilter")) - - def registeredBroadcastReceivers = - cpg.broadcastReceivers.filter { broadcastReceiver => - cpg.registerReceiver.argument(1).isIdentifier.typeFullName.exists(_ == broadcastReceiver.fullName) - } -} +class NodeTypeStarters(cpg: Cpg): + def webView: Iterator[Local] = + cpg.local.typeFullNameExact("android.webkit.WebView") + + def appManifest: Iterator[ConfigFile] = + cpg.configFile.filter(_.name.endsWith(Constants.androidManifestXml)) + + def getExternalStorageDir: Iterator[Call] = + cpg.call + .nameExact("getExternalStorageDirectory") + .where(_.argument(0).isIdentifier.typeFullNameExact("android.os.Environment")) + + def dexClassLoader: Iterator[Local] = + cpg.local.typeFullNameExact("dalvik.system.DexClassLoader") + + def broadcastReceivers: Iterator[TypeDecl] = + cpg.method + .nameExact("onReceive") + .where(_.parameter.index(1).typeFullNameExact("android.content.Context")) + .where(_.parameter.index(2).typeFullNameExact("android.content.Intent")) + .typeDecl + + def registerReceiver: Iterator[Call] = + cpg.call + .nameExact("registerReceiver") + .where(_.argument(2).isIdentifier.typeFullNameExact("android.content.IntentFilter")) + + def registeredBroadcastReceivers = + cpg.broadcastReceivers.filter { broadcastReceiver => + cpg.registerReceiver.argument(1).isIdentifier.typeFullName.exists( + _ == broadcastReceiver.fullName + ) + } +end NodeTypeStarters diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala index f0a65a32..6c0dc3c9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala @@ -4,27 +4,26 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{ConfigFile, Literal, Local, Method} /** Language extensions for android. */ -package object android { +package object android: - implicit def toNodeTypeStartersFlows(cpg: Cpg): NodeTypeStarters = - new NodeTypeStarters(cpg) + implicit def toNodeTypeStartersFlows(cpg: Cpg): NodeTypeStarters = + new NodeTypeStarters(cpg) - implicit def singleToLocalExt[A <: Local](a: A): LocalTraversal = - new LocalTraversal(Iterator.single(a)) + implicit def singleToLocalExt[A <: Local](a: A): LocalTraversal = + new LocalTraversal(Iterator.single(a)) - implicit def iterOnceToLocalExt[A <: Local](a: IterableOnce[A]): LocalTraversal = - new LocalTraversal(a.iterator) + implicit def iterOnceToLocalExt[A <: Local](a: IterableOnce[A]): LocalTraversal = + new LocalTraversal(a.iterator) - implicit def singleToConfigFileExt[A <: ConfigFile](a: A): ConfigFileTraversal = - new ConfigFileTraversal(Iterator.single(a)) + implicit def singleToConfigFileExt[A <: ConfigFile](a: A): ConfigFileTraversal = + new ConfigFileTraversal(Iterator.single(a)) - implicit def iterOnceToConfigFileExt[A <: ConfigFile](a: IterableOnce[A]): ConfigFileTraversal = - new ConfigFileTraversal(a.iterator) + implicit def iterOnceToConfigFileExt[A <: ConfigFile](a: IterableOnce[A]): ConfigFileTraversal = + new ConfigFileTraversal(a.iterator) - implicit def singleToMethodExt[A <: Method](a: A): MethodTraversal = - new MethodTraversal(Iterator.single(a)) + implicit def singleToMethodExt[A <: Method](a: A): MethodTraversal = + new 
MethodTraversal(Iterator.single(a)) - implicit def iterOnceToMethodExt[A <: Method](a: IterableOnce[A]): MethodTraversal = - new MethodTraversal(a.iterator) - -} + implicit def iterOnceToMethodExt[A <: Method](a: IterableOnce[A]): MethodTraversal = + new MethodTraversal(a.iterator) +end android diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala index 097813b9..9e222c97 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala @@ -3,15 +3,14 @@ package io.shiftleft.semanticcpg.language.bindingextension import io.shiftleft.codepropertygraph.generated.nodes.{Binding, Method, TypeDecl} import io.shiftleft.semanticcpg.language.* -class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { +class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: - /** Traverse to type decl which have this method bound to it. - */ - def bindingTypeDecl: Iterator[TypeDecl] = - referencingBinding.bindingTypeDecl + /** Traverse to type decl which have this method bound to it. + */ + def bindingTypeDecl: Iterator[TypeDecl] = + referencingBinding.bindingTypeDecl - /** Traverse to bindings which reference to this method. - */ - def referencingBinding: Iterator[Binding] = - traversal.flatMap(_._bindingViaRefIn) -} + /** Traverse to bindings which reference to this method. 
+ */ + def referencingBinding: Iterator[Binding] = + traversal.flatMap(_._bindingViaRefIn) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala index 87cef32b..be9c73df 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala @@ -3,16 +3,14 @@ package io.shiftleft.semanticcpg.language.bindingextension import io.shiftleft.codepropertygraph.generated.nodes.{Binding, Method, TypeDecl} import io.shiftleft.semanticcpg.language.* -class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { +class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal: - /** Traverse to methods bound to this type decl. - */ - def boundMethod: Iterator[Method] = - methodBinding.boundMethod + /** Traverse to methods bound to this type decl. + */ + def boundMethod: Iterator[Method] = + methodBinding.boundMethod - /** Traverse to the method bindings of this type declaration. - */ - def methodBinding: Iterator[Binding] = - traversal.canonicalType.flatMap(_.bindsOut) - -} + /** Traverse to the method bindings of this type declaration. 
+ */ + def methodBinding: Iterator[Binding] = + traversal.canonicalType.flatMap(_.bindsOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala index fc391e95..e4b18426 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala @@ -3,16 +3,14 @@ package io.shiftleft.semanticcpg.language.callgraphextension import io.shiftleft.codepropertygraph.generated.nodes.{Call, Import, Method} import io.shiftleft.semanticcpg.language.* -class CallTraversal(val traversal: Iterator[Call]) extends AnyVal { +class CallTraversal(val traversal: Iterator[Call]) extends AnyVal: - @deprecated("Use callee", "") - def calledMethod(implicit callResolver: ICallResolver): Iterator[Method] = callee + @deprecated("Use callee", "") + def calledMethod(implicit callResolver: ICallResolver): Iterator[Method] = callee - /** The callee method */ - def callee(implicit callResolver: ICallResolver): Iterator[Method] = - traversal.flatMap(callResolver.getCalledMethodsAsTraversal) + /** The callee method */ + def callee(implicit callResolver: ICallResolver): Iterator[Method] = + traversal.flatMap(callResolver.getCalledMethodsAsTraversal) - def referencedImports: Iterator[Import] = - traversal.flatMap(_._importViaIsCallForImportOut) - -} + def referencedImports: Iterator[Import] = + traversal.flatMap(_._importViaIsCallForImportOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala index 6b627ae4..911703c8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala +++ 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala @@ -4,64 +4,69 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, Method} import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { +class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: - /** Intended for internal use! Traverse to direct and transitive callers of the method. - */ - def calledByIncludingSink(sourceTrav: Iterator[Method])(implicit callResolver: ICallResolver): Iterator[Method] = { - val sourceMethods = sourceTrav.toSet - val sinkMethods = traversal.dedup + /** Intended for internal use! Traverse to direct and transitive callers of the method. + */ + def calledByIncludingSink(sourceTrav: Iterator[Method])(implicit + callResolver: ICallResolver + ): Iterator[Method] = + val sourceMethods = sourceTrav.toSet + val sinkMethods = traversal.dedup - if (sourceMethods.isEmpty || sinkMethods.isEmpty) { - Iterator.empty[Method].enablePathTracking - } else { - sinkMethods - .repeat( - _.flatMap(callResolver.getMethodCallsitesAsTraversal)._containsIn // expand to method - )(_.dedup.emit(_.collect { - case method: Method if sourceMethods.contains(method) => method - })) - .cast[Method] - } - } + if sourceMethods.isEmpty || sinkMethods.isEmpty then + Iterator.empty[Method].enablePathTracking + else + sinkMethods + .repeat( + _.flatMap( + callResolver.getMethodCallsitesAsTraversal + )._containsIn // expand to method + )(_.dedup.emit(_.collect { + case method: Method if sourceMethods.contains(method) => method + })) + .cast[Method] - /** Traverse to direct callers of this method - */ - def caller(implicit callResolver: ICallResolver): Iterator[Method] = - callIn(callResolver).method + /** Traverse to direct callers of this method + */ + def caller(implicit callResolver: ICallResolver): Iterator[Method] = + callIn(callResolver).method - /** 
Traverse to methods called by this method - */ - def callee(implicit callResolver: ICallResolver): Iterator[Method] = - call.callee(callResolver) + /** Traverse to methods called by this method + */ + def callee(implicit callResolver: ICallResolver): Iterator[Method] = + call.callee(callResolver) - /** Incoming call sites - */ - def callIn(implicit callResolver: ICallResolver): Iterator[Call] = - traversal.flatMap(method => callResolver.getMethodCallsitesAsTraversal(method).collectAll[Call]) + /** Incoming call sites + */ + def callIn(implicit callResolver: ICallResolver): Iterator[Call] = + traversal.flatMap(method => + callResolver.getMethodCallsitesAsTraversal(method).collectAll[Call] + ) - /** Traverse to direct and transitive callers of the method. - */ - def calledBy(sourceTrav: Iterator[Method])(implicit callResolver: ICallResolver): Iterator[Method] = - caller(callResolver).calledByIncludingSink(sourceTrav)(callResolver) + /** Traverse to direct and transitive callers of the method. + */ + def calledBy(sourceTrav: Iterator[Method])(implicit + callResolver: ICallResolver + ): Iterator[Method] = + caller(callResolver).calledByIncludingSink(sourceTrav)(callResolver) - @deprecated("Use call", "") - def callOut: Iterator[Call] = - call + @deprecated("Use call", "") + def callOut: Iterator[Call] = + call - @deprecated("Use call", "") - def callOutRegex(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = - call(regex) + @deprecated("Use call", "") + def callOutRegex(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = + call(regex) - /** Outgoing call sites to methods where fullName matches `regex`. - */ - def call(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = - call.where(_.callee.fullName(regex)) + /** Outgoing call sites to methods where fullName matches `regex`. 
+ */ + def call(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = + call.where(_.callee.fullName(regex)) - /** Outgoing call sites - */ - @Doc(info = "Call sites (outgoing calls)") - def call: Iterator[Call] = - traversal.flatMap(_._callViaContainsOut) - -} + /** Outgoing call sites + */ + @Doc(info = "Call sites (outgoing calls)") + def call: Iterator[Call] = + traversal.flatMap(_._callViaContainsOut) +end MethodTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala index aa10a624..eee4a717 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala @@ -4,12 +4,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.semanticcpg.dotgenerator.DotAstGenerator import overflowdb.traversal.* -class AstNodeDot[NodeType <: AstNode](val traversal: Iterator[NodeType]) extends AnyVal { +class AstNodeDot[NodeType <: AstNode](val traversal: Iterator[NodeType]) extends AnyVal: - def dotAst: Iterator[String] = DotAstGenerator.dotAst(traversal) + def dotAst: Iterator[String] = DotAstGenerator.dotAst(traversal) - def plotDotAst(implicit viewer: ImageViewer): Unit = { - Shared.plotAndDisplay(dotAst.l, viewer) - } - -} + def plotDotAst(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotAst.l, viewer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala index 107b1903..3da7357f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala @@ -4,18 +4,14 
@@ import io.shiftleft.codepropertygraph.generated.nodes.Method import io.shiftleft.semanticcpg.dotgenerator.{DotCdgGenerator, DotCfgGenerator} import overflowdb.traversal.* -class CfgNodeDot(val traversal: Iterator[Method]) extends AnyVal { +class CfgNodeDot(val traversal: Iterator[Method]) extends AnyVal: - def dotCfg: Iterator[String] = DotCfgGenerator.dotCfg(traversal) + def dotCfg: Iterator[String] = DotCfgGenerator.dotCfg(traversal) - def dotCdg: Iterator[String] = DotCdgGenerator.dotCdg(traversal) + def dotCdg: Iterator[String] = DotCdgGenerator.dotCdg(traversal) - def plotDotCfg(implicit viewer: ImageViewer): Unit = { - Shared.plotAndDisplay(dotCfg.l, viewer) - } + def plotDotCfg(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotCfg.l, viewer) - def plotDotCdg(implicit viewer: ImageViewer): Unit = { - Shared.plotAndDisplay(dotCdg.l, viewer) - } - -} + def plotDotCdg(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotCdg.l, viewer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala index 2b7e5195..fff4c131 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala @@ -3,10 +3,8 @@ package io.shiftleft.semanticcpg.language.dotextension import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.semanticcpg.dotgenerator.{DotCallGraphGenerator, DotTypeHierarchyGenerator} -class InterproceduralNodeDot(val cpg: Cpg) extends AnyVal { +class InterproceduralNodeDot(val cpg: Cpg) extends AnyVal: - def dotCallGraph: Iterator[String] = DotCallGraphGenerator.dotCallGraph(cpg) + def dotCallGraph: Iterator[String] = DotCallGraphGenerator.dotCallGraph(cpg) - def dotTypeHierarchy: Iterator[String] = 
DotTypeHierarchyGenerator.dotTypeHierarchy(cpg) - -} + def dotTypeHierarchy: Iterator[String] = DotTypeHierarchyGenerator.dotTypeHierarchy(cpg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala index f347486c..a58e70c5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala @@ -5,33 +5,36 @@ import better.files.File import scala.sys.process.Process import scala.util.{Failure, Success, Try} -trait ImageViewer { - def view(pathStr: String): Try[String] -} +trait ImageViewer: + def view(pathStr: String): Try[String] -object Shared { +object Shared: - def plotAndDisplay(dotStrings: List[String], viewer: ImageViewer): Unit = { - dotStrings.foreach { dotString => - File.usingTemporaryFile("semanticcpg") { dotFile => - File.usingTemporaryFile("semanticcpg") { svgFile => - dotFile.write(dotString) - createSvgFile(dotFile, svgFile).toOption.foreach(_ => viewer.view(svgFile.path.toAbsolutePath.toString)) + def plotAndDisplay(dotStrings: List[String], viewer: ImageViewer): Unit = + dotStrings.foreach { dotString => + File.usingTemporaryFile("semanticcpg") { dotFile => + File.usingTemporaryFile("semanticcpg") { svgFile => + dotFile.write(dotString) + createSvgFile(dotFile, svgFile).toOption.foreach(_ => + viewer.view(svgFile.path.toAbsolutePath.toString) + ) + } + } } - } - } - } - private def createSvgFile(in: File, out: File): Try[String] = { - Try { - Process(Seq("dot", "-Tsvg", in.path.toAbsolutePath.toString, "-o", out.path.toAbsolutePath.toString)).!! 
- } match { - case Success(v) => Success(v) - case Failure(exc) => - System.err.println("Executing `dot` failed: is `graphviz` installed?") - System.err.println(exc) - Failure(exc) - } - } - -} + private def createSvgFile(in: File, out: File): Try[String] = + Try { + Process(Seq( + "dot", + "-Tsvg", + in.path.toAbsolutePath.toString, + "-o", + out.path.toAbsolutePath.toString + )).!! + } match + case Success(v) => Success(v) + case Failure(exc) => + System.err.println("Executing `dot` failed: is `graphviz` installed?") + System.err.println(exc) + Failure(exc) +end Shared diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala index e8b1de42..1c5a17ce 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala @@ -7,141 +7,135 @@ import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.nodemethods.AstNodeMethods.lastExpressionInBlock import io.shiftleft.semanticcpg.utils.MemberAccess -class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension { +class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension: - /** Indicate whether the AST node represents a control structure, e.g., `if`, `for`, `while`. - */ - def isControlStructure: Boolean = node.isInstanceOf[ControlStructure] + /** Indicate whether the AST node represents a control structure, e.g., `if`, `for`, `while`. 
+ */ + def isControlStructure: Boolean = node.isInstanceOf[ControlStructure] - def isIdentifier: Boolean = node.isInstanceOf[Identifier] + def isIdentifier: Boolean = node.isInstanceOf[Identifier] - def isImport: Boolean = node.isInstanceOf[Import] + def isImport: Boolean = node.isInstanceOf[Import] - def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier] + def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier] - def isFile: Boolean = node.isInstanceOf[File] + def isFile: Boolean = node.isInstanceOf[File] - def isReturn: Boolean = node.isInstanceOf[Return] + def isReturn: Boolean = node.isInstanceOf[Return] - def isLiteral: Boolean = node.isInstanceOf[Literal] + def isLiteral: Boolean = node.isInstanceOf[Literal] - def isLocal: Boolean = node.isInstanceOf[Local] + def isLocal: Boolean = node.isInstanceOf[Local] - def isCall: Boolean = node.isInstanceOf[Call] + def isCall: Boolean = node.isInstanceOf[Call] - def isExpression: Boolean = node.isInstanceOf[Expression] + def isExpression: Boolean = node.isInstanceOf[Expression] - def isMember: Boolean = node.isInstanceOf[Member] + def isMember: Boolean = node.isInstanceOf[Member] - def isMethodRef: Boolean = node.isInstanceOf[MethodRef] + def isMethodRef: Boolean = node.isInstanceOf[MethodRef] - def isMethod: Boolean = node.isInstanceOf[Method] + def isMethod: Boolean = node.isInstanceOf[Method] - def isModifier: Boolean = node.isInstanceOf[Modifier] + def isModifier: Boolean = node.isInstanceOf[Modifier] - def isNamespaceBlock: Boolean = node.isInstanceOf[NamespaceBlock] + def isNamespaceBlock: Boolean = node.isInstanceOf[NamespaceBlock] - def isBlock: Boolean = node.isInstanceOf[Block] + def isBlock: Boolean = node.isInstanceOf[Block] - def isParameter: Boolean = node.isInstanceOf[MethodParameterIn] + def isParameter: Boolean = node.isInstanceOf[MethodParameterIn] - def isTypeDecl: Boolean = node.isInstanceOf[TypeDecl] + def isTypeDecl: Boolean = node.isInstanceOf[TypeDecl] - def depth: Int = 
depth(_ => true) + def depth: Int = depth(_ => true) - /** The depth of the AST rooted in this node. Upon walking the tree to its leaves, the depth is only increased for - * nodes where `p(node)` is true. - */ - def depth(p: AstNode => Boolean): Int = { - val additionalDepth = if (p(node)) { 1 } - else { 0 } + /** The depth of the AST rooted in this node. Upon walking the tree to its leaves, the depth is + * only increased for nodes where `p(node)` is true. + */ + def depth(p: AstNode => Boolean): Int = + val additionalDepth = if p(node) then 1 + else 0 - val childDepths = node.astChildren.map(_.depth(p)).l - additionalDepth + (if (childDepths.isEmpty) { - 0 - } else { - childDepths.max - }) - } + val childDepths = node.astChildren.map(_.depth(p)).l + additionalDepth + (if childDepths.isEmpty then + 0 + else + childDepths.max + ) - def astParent: AstNode = { - try { - node._astIn.onlyChecked.asInstanceOf[AstNode] - } catch { - case _: Throwable => node._astIn.asInstanceOf[AstNode] - } - } + def astParent: AstNode = + try + node._astIn.onlyChecked.asInstanceOf[AstNode] + catch + case _: Throwable => node._astIn.asInstanceOf[AstNode] - /** Direct children of node in the AST. Siblings are ordered by their `order` fields - */ - def astChildren: Iterator[AstNode] = - node._astOut.cast[AstNode].sortBy(_.order).iterator + /** Direct children of node in the AST. Siblings are ordered by their `order` fields + */ + def astChildren: Iterator[AstNode] = + node._astOut.cast[AstNode].sortBy(_.order).iterator - /** Siblings of this node in the AST, ordered by their `order` fields - */ - def astSiblings: Iterator[AstNode] = - astParent.astChildren.filter(_ != node) + /** Siblings of this node in the AST, ordered by their `order` fields + */ + def astSiblings: Iterator[AstNode] = + astParent.astChildren.filter(_ != node) - /** Nodes of the AST rooted in this node, including the node itself. 
- */ - def ast: Iterator[AstNode] = - Iterator.single(node).ast + /** Nodes of the AST rooted in this node, including the node itself. + */ + def ast: Iterator[AstNode] = + Iterator.single(node).ast - /** Textual representation of AST node - */ - def repr: String = - node match { - case method: Method => method.name - case member: Member => member.name - case methodReturn: MethodReturn => methodReturn.code - case expr: Expression => expr.code - case call: CallRepr if !call.isInstanceOf[Call] => call.code - } + /** Textual representation of AST node + */ + def repr: String = + node match + case method: Method => method.name + case member: Member => member.name + case methodReturn: MethodReturn => methodReturn.code + case expr: Expression => expr.code + case call: CallRepr if !call.isInstanceOf[Call] => call.code - def statement: AstNode = - statementInternal(node, _.parentExpression.get) + def statement: AstNode = + statementInternal(node, _.parentExpression.get) - @scala.annotation.tailrec - private def statementInternal(node: AstNode, parentExpansion: Expression => Expression): AstNode = { + @scala.annotation.tailrec + private def statementInternal( + node: AstNode, + parentExpansion: Expression => Expression + ): AstNode = + node match + case node: Identifier => parentExpansion(node) + case node: MethodRef => parentExpansion(node) + case node: TypeRef => parentExpansion(node) + case node: Literal => parentExpansion(node) - node match { - case node: Identifier => parentExpansion(node) - case node: MethodRef => parentExpansion(node) - case node: TypeRef => parentExpansion(node) - case node: Literal => parentExpansion(node) + case member: Member => member + case node: MethodParameterIn => node.method - case member: Member => member - case node: MethodParameterIn => node.method + case node: MethodParameterOut => + node.method.methodReturn - case node: MethodParameterOut => - node.method.methodReturn + case node: Call if 
MemberAccess.isGenericMemberAccessName(node.name) => + parentExpansion(node) - case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => - parentExpansion(node) - - case node: CallRepr => node - case node: MethodReturn => node - case block: Block => - // Just taking the lastExpressionInBlock is not quite correct because a BLOCK could have - // different return expressions. So we would need to expand via CFG. - // But currently the frontends do not even put the BLOCK into the CFG so this is the best - // we can do. - statementInternal(lastExpressionInBlock(block).get, identity) - case node: Expression => node - } - } - -} - -object AstNodeMethods { - - private def lastExpressionInBlock(block: Block): Option[Expression] = - block._astOut - .collect { - case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => node - } - .toVector - .sortBy(_.order) - .lastOption - -} + case node: CallRepr => node + case node: MethodReturn => node + case block: Block => + // Just taking the lastExpressionInBlock is not quite correct because a BLOCK could have + // different return expressions. So we would need to expand via CFG. + // But currently the frontends do not even put the BLOCK into the CFG so this is the best + // we can do. 
+ statementInternal(lastExpressionInBlock(block).get, identity) + case node: Expression => node +end AstNodeMethods + +object AstNodeMethods: + + private def lastExpressionInBlock(block: Block): Option[Expression] = + block._astOut + .collect { + case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => + node + } + .toVector + .sortBy(_.order) + .lastOption diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala index 26bf5b7a..edba3f54 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala @@ -4,28 +4,27 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, Expression, NewLoca import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* -class CallMethods(val node: Call) extends AnyVal with NodeExtension with HasLocation { - def receiver: Iterator[Expression] = - node.receiverOut +class CallMethods(val node: Call) extends AnyVal with NodeExtension with HasLocation: + def receiver: Iterator[Expression] = + node.receiverOut - def arguments(index: Int): Iterator[Expression] = - node._argumentOut - .collect { - case expr: Expression if expr.argumentIndex == index => expr - } + def arguments(index: Int): Iterator[Expression] = + node._argumentOut + .collect { + case expr: Expression if expr.argumentIndex == index => expr + } - def argument: Iterator[Expression] = - node._argumentOut.collectAll[Expression] + def argument: Iterator[Expression] = + node._argumentOut.collectAll[Expression] - def argument(index: Int): Expression = - arguments(index).head + def argument(index: Int): Expression = + arguments(index).head - def argumentOption(index: Int): Option[Expression] = - node._argumentOut.collectFirst { - case expr: 
Expression if expr.argumentIndex == index => expr - } + def argumentOption(index: Int): Option[Expression] = + node._argumentOut.collectFirst { + case expr: Expression if expr.argumentIndex == index => expr + } - override def location: NewLocation = { - LocationCreator(node, node.code, node.label, node.lineNumber, node.method) - } -} + override def location: NewLocation = + LocationCreator(node, node.code, node.label, node.lineNumber, node.method) +end CallMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala index fc61285a..3c6f05a2 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala @@ -7,157 +7,140 @@ import io.shiftleft.semanticcpg.language.* import scala.jdk.CollectionConverters.* -class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension { - - /** Successors in the CFG - */ - def cfgNext: Iterator[CfgNode] = { - Iterator.single(node).cfgNext - } - - /** Maps each node in the traversal to a traversal returning its n successors. - */ - def cfgNext(n: Int): Iterator[CfgNode] = n match { - case 0 => Iterator.empty - case _ => cfgNext.flatMap(x => List(x) ++ x.cfgNext(n - 1)) - } - - /** Maps each node in the traversal to a traversal returning its n predecessors. - */ - def cfgPrev(n: Int): Iterator[CfgNode] = n match { - case 0 => Iterator.empty - case _ => cfgPrev.flatMap(x => List(x) ++ x.cfgPrev(n - 1)) - } - - /** Predecessors in the CFG - */ - def cfgPrev: Iterator[CfgNode] = { - Iterator.single(node).cfgPrev - } - - /** Recursively determine all nodes on which this CFG node is control-dependent. 
- */ - def controlledBy: Iterator[CfgNode] = { - expandExhaustively { v => - v._cdgIn - } - } - - /** Recursively determine all nodes which this CFG node controls - */ - def controls: Iterator[CfgNode] = { - expandExhaustively { v => - v._cdgOut - } - } - - /** Recursively determine all nodes by which this node is dominated - */ - def dominatedBy: Iterator[CfgNode] = { - expandExhaustively { v => - v._dominateIn - } - } - - /** Recursively determine all nodes which are dominated by this node - */ - def dominates: Iterator[CfgNode] = { - expandExhaustively { v => - v._dominateOut - } - } - - /** Recursively determine all nodes by which this node is post dominated - */ - def postDominatedBy: Iterator[CfgNode] = { - expandExhaustively { v => - v._postDominateIn - } - } - - /** Recursively determine all nodes which are post dominated by this node - */ - def postDominates: Iterator[CfgNode] = { - expandExhaustively { v => - v._postDominateOut - } - } - - /** Using the post dominator tree, will determine if this node passes through the included set of nodes and filter it - * in. - * @param included - * the nodes this node must pass through. - * @return - * the traversal of this node if it passes through the included set. - */ - def passes(included: Set[CfgNode]): Iterator[CfgNode] = - Iterator.single(node).filter(_.postDominatedBy.exists(included.contains)) - - /** Using the post dominator tree, will determine if this node passes through the excluded set of nodes and filter it - * out. - * @param excluded - * the nodes this node must not pass through. - * @return - * the traversal of this node if it does not pass through the excluded set. 
- */ - def passesNot(excluded: Set[CfgNode]): Iterator[CfgNode] = - Iterator.single(node).filterNot(_.postDominatedBy.exists(excluded.contains)) - - private def expandExhaustively(expand: CfgNode => Iterator[StoredNode]): Iterator[CfgNode] = { - var controllingNodes = List.empty[CfgNode] - var visited = Set.empty + node - var worklist = node :: Nil - - while (worklist.nonEmpty) { - val vertex = worklist.head - worklist = worklist.tail - - expand(vertex).foreach { case controllingNode: CfgNode => - if (!visited.contains(controllingNode)) { - visited += controllingNode - controllingNodes = controllingNode :: controllingNodes - worklist = controllingNode :: worklist +class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension: + + /** Successors in the CFG + */ + def cfgNext: Iterator[CfgNode] = + Iterator.single(node).cfgNext + + /** Maps each node in the traversal to a traversal returning its n successors. + */ + def cfgNext(n: Int): Iterator[CfgNode] = n match + case 0 => Iterator.empty + case _ => cfgNext.flatMap(x => List(x) ++ x.cfgNext(n - 1)) + + /** Maps each node in the traversal to a traversal returning its n predecessors. + */ + def cfgPrev(n: Int): Iterator[CfgNode] = n match + case 0 => Iterator.empty + case _ => cfgPrev.flatMap(x => List(x) ++ x.cfgPrev(n - 1)) + + /** Predecessors in the CFG + */ + def cfgPrev: Iterator[CfgNode] = + Iterator.single(node).cfgPrev + + /** Recursively determine all nodes on which this CFG node is control-dependent. + */ + def controlledBy: Iterator[CfgNode] = + expandExhaustively { v => + v._cdgIn } - } - } - controllingNodes.iterator - } - - def method: Method = node match { - case node: Method => node - case _: MethodParameterIn | _: MethodParameterOut | _: MethodReturn => - walkUpAst(node) - case _: CallRepr if !node.isInstanceOf[Call] => walkUpAst(node) - case _: Expression | _: JumpTarget => walkUpContains(node) - } - - /** Obtain hexadecimal string representation of lineNumber field. 
- * - * Binary frontends store addresses in the lineNumber field as integers. For interoperability with other binary - * analysis tooling, it is convenient to allow retrieving these as hex strings. - */ - def address: Option[String] = { - node.lineNumber.map(_.toLong.toHexString) - } - - private def walkUpAst(node: CfgNode): Method = - node._astIn.onlyChecked.asInstanceOf[Method] - - private def walkUpContains(node: StoredNode): Method = - node._containsIn.onlyChecked match { - case method: Method => method - case typeDecl: TypeDecl => - typeDecl.astParent match { - case namespaceBlock: NamespaceBlock => - // For Typescript, types may be declared in namespaces which we represent as NamespaceBlocks - namespaceBlock.inAst.collectAll[Method].headOption.orNull - case method: Method => - // For a language such as Javascript, types may be dynamically declared under procedures - method - case _ => - // there are csharp CPGs that have typedecls here, which is invalid. - null + + /** Recursively determine all nodes which this CFG node controls + */ + def controls: Iterator[CfgNode] = + expandExhaustively { v => + v._cdgOut + } + + /** Recursively determine all nodes by which this node is dominated + */ + def dominatedBy: Iterator[CfgNode] = + expandExhaustively { v => + v._dominateIn + } + + /** Recursively determine all nodes which are dominated by this node + */ + def dominates: Iterator[CfgNode] = + expandExhaustively { v => + v._dominateOut + } + + /** Recursively determine all nodes by which this node is post dominated + */ + def postDominatedBy: Iterator[CfgNode] = + expandExhaustively { v => + v._postDominateIn + } + + /** Recursively determine all nodes which are post dominated by this node + */ + def postDominates: Iterator[CfgNode] = + expandExhaustively { v => + v._postDominateOut } - } -} + /** Using the post dominator tree, will determine if this node passes through the included set + * of nodes and filter it in. 
+ * @param included + * the nodes this node must pass through. + * @return + * the traversal of this node if it passes through the included set. + */ + def passes(included: Set[CfgNode]): Iterator[CfgNode] = + Iterator.single(node).filter(_.postDominatedBy.exists(included.contains)) + + /** Using the post dominator tree, will determine if this node passes through the excluded set + * of nodes and filter it out. + * @param excluded + * the nodes this node must not pass through. + * @return + * the traversal of this node if it does not pass through the excluded set. + */ + def passesNot(excluded: Set[CfgNode]): Iterator[CfgNode] = + Iterator.single(node).filterNot(_.postDominatedBy.exists(excluded.contains)) + + private def expandExhaustively(expand: CfgNode => Iterator[StoredNode]): Iterator[CfgNode] = + var controllingNodes = List.empty[CfgNode] + var visited = Set.empty + node + var worklist = node :: Nil + + while worklist.nonEmpty do + val vertex = worklist.head + worklist = worklist.tail + + expand(vertex).foreach { case controllingNode: CfgNode => + if !visited.contains(controllingNode) then + visited += controllingNode + controllingNodes = controllingNode :: controllingNodes + worklist = controllingNode :: worklist + } + controllingNodes.iterator + + def method: Method = node match + case node: Method => node + case _: MethodParameterIn | _: MethodParameterOut | _: MethodReturn => + walkUpAst(node) + case _: CallRepr if !node.isInstanceOf[Call] => walkUpAst(node) + case _: Expression | _: JumpTarget => walkUpContains(node) + + /** Obtain hexadecimal string representation of lineNumber field. + * + * Binary frontends store addresses in the lineNumber field as integers. For interoperability + * with other binary analysis tooling, it is convenient to allow retrieving these as hex + * strings. 
+ */ + def address: Option[String] = + node.lineNumber.map(_.toLong.toHexString) + + private def walkUpAst(node: CfgNode): Method = + node._astIn.onlyChecked.asInstanceOf[Method] + + private def walkUpContains(node: StoredNode): Method = + node._containsIn.onlyChecked match + case method: Method => method + case typeDecl: TypeDecl => + typeDecl.astParent match + case namespaceBlock: NamespaceBlock => + // For Typescript, types may be declared in namespaces which we represent as NamespaceBlocks + namespaceBlock.inAst.collectAll[Method].headOption.orNull + case method: Method => + // For a language such as Javascript, types may be dynamically declared under procedures + method + case _ => + // there are csharp CPGs that have typedecls here, which is invalid. + null +end CfgNodeMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala index 75f80fd5..1df2e6e3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala @@ -13,64 +13,56 @@ import scala.jdk.CollectionConverters.* // Many method of this class should return individual nodes instead of Traversal[...]. // But over time through some opague implicits the versions returning Traversal[...] // got exposed and for now we do not want to break the API. -class ExpressionMethods(val node: Expression) extends AnyVal with NodeExtension { +class ExpressionMethods(val node: Expression) extends AnyVal with NodeExtension: - /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's continuing it's walk - * until it hits an expression that's not a generic "member access operation", e.g., ".memberAccess". 
- */ - def parentExpression: Option[Expression] = _parentExpression(node) + /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's + * continuing it's walk until it hits an expression that's not a generic "member access + * operation", e.g., ".memberAccess". + */ + def parentExpression: Option[Expression] = _parentExpression(node) - @tailrec - private final def _parentExpression(argument: AstNode): Option[Expression] = { - val parent = argument._astIn.onlyChecked - parent match { - case call: Call if MemberAccess.isGenericMemberAccessName(call.name) => - _parentExpression(call) - case expression: Expression => - Some(expression) - case annotationParameterAssign: AnnotationParameterAssign => - _parentExpression(annotationParameterAssign) - case _ => - None - } - } + @tailrec + private final def _parentExpression(argument: AstNode): Option[Expression] = + val parent = argument._astIn.onlyChecked + parent match + case call: Call if MemberAccess.isGenericMemberAccessName(call.name) => + _parentExpression(call) + case expression: Expression => + Some(expression) + case annotationParameterAssign: AnnotationParameterAssign => + _parentExpression(annotationParameterAssign) + case _ => + None - def expressionUp: Iterator[Expression] = { - node._astIn.collectAll[Expression] - } + def expressionUp: Iterator[Expression] = + node._astIn.collectAll[Expression] - def expressionDown: Iterator[Expression] = { - node._astOut.collectAll[Expression] - } + def expressionDown: Iterator[Expression] = + node._astOut.collectAll[Expression] - def receivedCall: Iterator[Call] = { - node._receiverIn.cast[Call] - } + def receivedCall: Iterator[Call] = + node._receiverIn.cast[Call] - def isArgument: Iterator[Expression] = { - if (node._argumentIn.hasNext) Iterator.single(node) - else Iterator.empty - } + def isArgument: Iterator[Expression] = + if node._argumentIn.hasNext then Iterator.single(node) + else Iterator.empty - def inCall: Iterator[Call] = - 
node._argumentIn.headOption match { - case Some(c: Call) => Iterator.single(c) - case _ => Iterator.empty - } + def inCall: Iterator[Call] = + node._argumentIn.headOption match + case Some(c: Call) => Iterator.single(c) + case _ => Iterator.empty - def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = { - // Expressions can have incoming argument edges not just from CallRepr nodes but also - // from Return nodes for which an expansion to parameter makes no sense. So we filter - // for CallRepr. - for { - call <- node._argumentIn if call.isInstanceOf[CallRepr] - calledMethods <- callResolver.getCalledMethods(call.asInstanceOf[CallRepr]) - paramIn <- calledMethods._astOut.collectAll[MethodParameterIn] - if paramIn.index == node.argumentIndex - } yield paramIn - } + def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = + // Expressions can have incoming argument edges not just from CallRepr nodes but also + // from Return nodes for which an expansion to parameter makes no sense. So we filter + // for CallRepr. 
+ for + call <- node._argumentIn if call.isInstanceOf[CallRepr] + calledMethods <- callResolver.getCalledMethods(call.asInstanceOf[CallRepr]) + paramIn <- calledMethods._astOut.collectAll[MethodParameterIn] + if paramIn.index == node.argumentIndex + yield paramIn - def typ: Iterator[Type] = - node._evalTypeOut.cast[Type] - -} + def typ: Iterator[Type] = + node._evalTypeOut.cast[Type] +end ExpressionMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala index 59dde532..3c9e6646 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala @@ -2,10 +2,15 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, _} +import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, *} -class IdentifierMethods(val identifier: Identifier) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(identifier, identifier.name, identifier.label, identifier.lineNumber, identifier.method) - } -} +class IdentifierMethods(val identifier: Identifier) extends AnyVal with NodeExtension + with HasLocation: + override def location: NewLocation = + LocationCreator( + identifier, + identifier.name, + identifier.label, + identifier.lineNumber, + identifier.method + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala index 3bbc9892..c3e78471 100644 --- 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala @@ -2,11 +2,8 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{Literal, NewLocation} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, _} +import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, *} -class LiteralMethods(val literal: Literal) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(literal, literal.code, literal.label, literal.lineNumber, literal.method) - - } -} +class LiteralMethods(val literal: Literal) extends AnyVal with NodeExtension with HasLocation: + override def location: NewLocation = + LocationCreator(literal, literal.code, literal.label, literal.lineNumber, literal.method) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala index 9bfa5eac..d65271d5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala @@ -4,13 +4,11 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Local, Method, NewLocatio import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* -class LocalMethods(val local: Local) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(local, local.name, local.label, local.lineNumber, local.method.head) - } +class LocalMethods(val local: Local) extends AnyVal with NodeExtension with HasLocation: + override def location: NewLocation = + LocationCreator(local, 
local.name, local.label, local.lineNumber, local.method.head) - /** The method hosting this local variable - */ - def method: Iterator[Method] = - Iterator.single(local).method -} + /** The method hosting this local variable + */ + def method: Iterator[Method] = + Iterator.single(local).method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala index 1bb743b6..7d832e4c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala @@ -4,64 +4,61 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* -class MethodMethods(val method: Method) extends AnyVal with NodeExtension with HasLocation { +class MethodMethods(val method: Method) extends AnyVal with NodeExtension with HasLocation: - /** Traverse to annotations of method - */ - def annotation: Iterator[Annotation] = - method._annotationViaAstOut + /** Traverse to annotations of method + */ + def annotation: Iterator[Annotation] = + method._annotationViaAstOut - def local: Iterator[Local] = - method._blockViaContainsOut.local + def local: Iterator[Local] = + method._blockViaContainsOut.local - /** All control structures of this method - */ - def controlStructure: Iterator[ControlStructure] = - method.ast.isControlStructure + /** All control structures of this method + */ + def controlStructure: Iterator[ControlStructure] = + method.ast.isControlStructure - def numberOfLines: Int = { - if (method.lineNumber.isDefined && method.lineNumberEnd.isDefined) { - method.lineNumberEnd.get - method.lineNumber.get + 1 - } else { - 0 - } - } + def numberOfLines: Int = + if method.lineNumber.isDefined && method.lineNumberEnd.isDefined then + 
method.lineNumberEnd.get - method.lineNumber.get + 1 + else + 0 - def isVariadic: Boolean = { - method.parameter.exists(_.isVariadic) - } + def isVariadic: Boolean = + method.parameter.exists(_.isVariadic) - def cfgNode: Iterator[CfgNode] = - method._containsOut.collectAll[CfgNode] + def cfgNode: Iterator[CfgNode] = + method._containsOut.collectAll[CfgNode] - /** List of CFG nodes in reverse post order - */ - def reversePostOrder: Iterator[CfgNode] = { - def expand(x: CfgNode) = { x.cfgNext.iterator } - NodeOrdering.reverseNodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator - } + /** List of CFG nodes in reverse post order + */ + def reversePostOrder: Iterator[CfgNode] = + def expand(x: CfgNode) = x.cfgNext.iterator + NodeOrdering.reverseNodeList( + NodeOrdering.postOrderNumbering(method, expand).toList + ).iterator - /** List of CFG nodes in post order - */ - def postOrder: Iterator[CfgNode] = { - def expand(x: CfgNode) = { x.cfgNext.iterator } - NodeOrdering.nodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator - } + /** List of CFG nodes in post order + */ + def postOrder: Iterator[CfgNode] = + def expand(x: CfgNode) = x.cfgNext.iterator + NodeOrdering.nodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator - /** The type declaration associated with this method, e.g., the class it is defined in. - */ - def definingTypeDecl: Option[TypeDecl] = - Iterator.single(method).definingTypeDecl.headOption + /** The type declaration associated with this method, e.g., the class it is defined in. + */ + def definingTypeDecl: Option[TypeDecl] = + Iterator.single(method).definingTypeDecl.headOption - /** The type declaration associated with this method, e.g., the class it is defined in. Alias for 'definingTypeDecl' - */ - def typeDecl: Option[TypeDecl] = definingTypeDecl + /** The type declaration associated with this method, e.g., the class it is defined in. 
Alias + * for 'definingTypeDecl' + */ + def typeDecl: Option[TypeDecl] = definingTypeDecl - /** Traverse to method body (alias for `block`) */ - def body: Block = - method.block + /** Traverse to method body (alias for `block`) */ + def body: Block = + method.block - override def location: NewLocation = { - LocationCreator(method, method.name, method.label, method.lineNumber, method) - } -} + override def location: NewLocation = + LocationCreator(method, method.name, method.label, method.lineNumber, method) +end MethodMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala index 1f2a0c54..9e87f382 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala @@ -4,8 +4,7 @@ import io.shiftleft.codepropertygraph.generated.nodes.{MethodParameterIn, NewLoc import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} -class MethodParameterInMethods(val paramIn: MethodParameterIn) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(paramIn, paramIn.name, paramIn.label, paramIn.lineNumber, paramIn.method) - } -} +class MethodParameterInMethods(val paramIn: MethodParameterIn) extends AnyVal with NodeExtension + with HasLocation: + override def location: NewLocation = + LocationCreator(paramIn, paramIn.name, paramIn.label, paramIn.lineNumber, paramIn.method) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala index 851c5949..34623ff6 100644 --- 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala @@ -4,8 +4,13 @@ import io.shiftleft.codepropertygraph.generated.nodes.{MethodParameterOut, NewLo import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} -class MethodParameterOutMethods(val paramOut: MethodParameterOut) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(paramOut, paramOut.name, paramOut.label, paramOut.lineNumber, paramOut.method) - } -} +class MethodParameterOutMethods(val paramOut: MethodParameterOut) extends AnyVal with NodeExtension + with HasLocation: + override def location: NewLocation = + LocationCreator( + paramOut, + paramOut.name, + paramOut.label, + paramOut.lineNumber, + paramOut.method + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala index fc971fff..a9ad7d9e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala @@ -4,14 +4,12 @@ import io.shiftleft.codepropertygraph.generated.nodes.{MethodRef, NewLocation} import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} -class MethodRefMethods(val methodRef: MethodRef) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator( - methodRef, - methodRef.code, - methodRef.label, - methodRef.lineNumber, - methodRef._methodViaContainsIn.next() - ) - } -} +class MethodRefMethods(val methodRef: MethodRef) extends AnyVal with NodeExtension 
with HasLocation: + override def location: NewLocation = + LocationCreator( + methodRef, + methodRef.code, + methodRef.label, + methodRef.lineNumber, + methodRef._methodViaContainsIn.next() + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala index 4ea45873..206b926a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala @@ -4,20 +4,18 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, MethodReturn, NewLo import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* -class MethodReturnMethods(val node: MethodReturn) extends AnyVal with NodeExtension with HasLocation { - override def location: NewLocation = { - LocationCreator(node, "$ret", node.label, node.lineNumber, node.method) - } +class MethodReturnMethods(val node: MethodReturn) extends AnyVal with NodeExtension + with HasLocation: + override def location: NewLocation = + LocationCreator(node, "$ret", node.label, node.lineNumber, node.method) - def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = { - val method = node._methodViaAstIn - val callsites = callResolver.getMethodCallsites(method) - // TODO for now we filter away all implicit calls because a change of the - // return type to CallRepr would lead to a break in the API aka - // the DSL steps which are subsequently allowed to be called. Before - // we addressed this we can only return Call instances. 
- callsites.collectAll[Call] - } + def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = + val method = node._methodViaAstIn + val callsites = callResolver.getMethodCallsites(method) + // TODO for now we filter away all implicit calls because a change of the + // return type to CallRepr would lead to a break in the API aka + // the DSL steps which are subsequently allowed to be called. Before + // we addressed this we can only return Call instances. + callsites.collectAll[Call] - def typ: Iterator[Type] = node.evalTypeOut -} + def typ: Iterator[Type] = node.evalTypeOut diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala index 600fe2c1..dfa54eba 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala @@ -2,15 +2,12 @@ package io.shiftleft.semanticcpg.language.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.{NewLocation, StoredNode} import io.shiftleft.semanticcpg.NodeExtension -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.NodeOrDetachedNode -class NodeMethods(val node: NodeOrDetachedNode) extends AnyVal with NodeExtension { +class NodeMethods(val node: NodeOrDetachedNode) extends AnyVal with NodeExtension: - def location(implicit finder: NodeExtensionFinder): NewLocation = - node match { - case storedNode: StoredNode => LocationCreator(storedNode) - case _ => LocationCreator.emptyLocation("", None) - } - -} + def location(implicit finder: NodeExtensionFinder): NewLocation = + node match + case storedNode: StoredNode => LocationCreator(storedNode) + case _ => LocationCreator.emptyLocation("", None) diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala index 09f042d4..658df335 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala @@ -6,13 +6,11 @@ import io.shiftleft.semanticcpg.language.* import scala.jdk.CollectionConverters.* -class StoredNodeMethods(val node: StoredNode) extends AnyVal with NodeExtension { - def tag: Iterator[Tag] = { - node._taggedByOut - .cast[Tag] - .distinctBy(tag => (tag.name, tag.value)) - } +class StoredNodeMethods(val node: StoredNode) extends AnyVal with NodeExtension: + def tag: Iterator[Tag] = + node._taggedByOut + .cast[Tag] + .distinctBy(tag => (tag.name, tag.value)) - def file: Iterator[File] = - Iterator.single(node).file -} + def file: Iterator[File] = + Iterator.single(node).file diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala index cc3b2519..9e153997 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala @@ -1,31 +1,31 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -class ArrayAccessTraversal(val traversal: Iterator[OpNodes.ArrayAccess]) extends AnyVal { +class ArrayAccessTraversal(val traversal: Iterator[OpNodes.ArrayAccess]) extends AnyVal: - @Doc(info = "The 
expression representing the array") - def array: Iterator[Expression] = traversal.map(_.array) + @Doc(info = "The expression representing the array") + def array: Iterator[Expression] = traversal.map(_.array) - @Doc(info = "Offset at which the array is referenced (an expression)") - def offset: Iterator[Expression] = traversal.map(_.offset) + @Doc(info = "Offset at which the array is referenced (an expression)") + def offset: Iterator[Expression] = traversal.map(_.offset) - @Doc(info = "All identifiers that are part of the offset") - def subscript: Iterator[Identifier] = traversal.flatMap(_.subscript) + @Doc(info = "All identifiers that are part of the offset") + def subscript: Iterator[Identifier] = traversal.flatMap(_.subscript) - @Doc( - info = "Determine whether array access has constant offset", - longInfo = """ + @Doc( + info = "Determine whether array access has constant offset", + longInfo = """ Determine if array access is at constant numeric offset, e.g., `buf[10]` but not `buf[i + 10]`, and for simplicity, not even `buf[1+2]`, `buf[PROBABLY_A_CONSTANT]` or `buf[PROBABLY_A_CONSTANT + 1]`, or even `buf[PROBABLY_A_CONSTANT]`. 
""" - ) - def usesConstantOffset: Iterator[OpNodes.ArrayAccess] = traversal.filter(_.usesConstantOffset) + ) + def usesConstantOffset: Iterator[OpNodes.ArrayAccess] = traversal.filter(_.usesConstantOffset) - @Doc(info = "If `array` is a lone identifier, return its name") - def simpleName: Iterator[String] = traversal.flatMap(_.simpleName) -} + @Doc(info = "If `array` is a lone identifier, return its name") + def simpleName: Iterator[String] = traversal.flatMap(_.simpleName) +end ArrayAccessTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala index 25cbca5c..4b0966c7 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala @@ -6,11 +6,10 @@ import overflowdb.traversal.help import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[OpNodes.Assignment]) -class AssignmentTraversal(val traversal: Iterator[OpNodes.Assignment]) extends AnyVal { +class AssignmentTraversal(val traversal: Iterator[OpNodes.Assignment]) extends AnyVal: - @Doc(info = "Left-hand sides of assignments") - def target: Iterator[Expression] = traversal.map(_.target) + @Doc(info = "Left-hand sides of assignments") + def target: Iterator[Expression] = traversal.map(_.target) - @Doc(info = "Right-hand sides of assignments") - def source: Iterator[Expression] = traversal.map(_.source) -} + @Doc(info = "Right-hand sides of assignments") + def source: Iterator[Expression] = traversal.map(_.source) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala index 
6bb06f0c..1e58d4cc 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala @@ -4,21 +4,19 @@ import io.shiftleft.codepropertygraph.generated.nodes.{FieldIdentifier, Member, import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -class FieldAccessTraversal(val traversal: Iterator[OpNodes.FieldAccess]) extends AnyVal { +class FieldAccessTraversal(val traversal: Iterator[OpNodes.FieldAccess]) extends AnyVal: - @Doc(info = "Attempts to resolve the type declaration for this field access") - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.typeDecl) + @Doc(info = "Attempts to resolve the type declaration for this field access") + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.typeDecl) - // TODO there are cases for the C++ frontend where argument(2) is a CALL or IDENTIFIER, - // and we are not handling them at the moment + // TODO there are cases for the C++ frontend where argument(2) is a CALL or IDENTIFIER, + // and we are not handling them at the moment - @Doc(info = "The identifier of the referenced field (right-hand side)") - def fieldIdentifier: Iterator[FieldIdentifier] = - traversal.flatMap(_.fieldIdentifier) + @Doc(info = "The identifier of the referenced field (right-hand side)") + def fieldIdentifier: Iterator[FieldIdentifier] = + traversal.flatMap(_.fieldIdentifier) - @Doc(info = "Attempts to resolve the member referenced by this field access") - def member: Iterator[Member] = - traversal.flatMap(_.member) - -} + @Doc(info = "Attempts to resolve the member referenced by this field access") + def member: Iterator[Member] = + traversal.flatMap(_.member) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala index bfea44db..8b30c56c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala @@ -2,29 +2,33 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Expression} -import io.shiftleft.semanticcpg.language.operatorextension.nodemethods._ - -trait Implicits { - implicit def toNodeTypeStartersOperatorExtension(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) - - implicit def toArrayAccessExt(arrayAccess: OpNodes.ArrayAccess): ArrayAccessMethods = - new ArrayAccessMethods(arrayAccess) - implicit def toArrayAccessTrav(steps: Iterator[OpNodes.ArrayAccess]): ArrayAccessTraversal = - new ArrayAccessTraversal(steps) - - implicit def toFieldAccessExt(fieldAccess: OpNodes.FieldAccess): FieldAccessMethods = - new FieldAccessMethods(fieldAccess) - implicit def toFieldAccessTrav(steps: Iterator[OpNodes.FieldAccess]): FieldAccessTraversal = - new FieldAccessTraversal(steps) - - implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = new AssignmentMethods(assignment) - implicit def toAssignmentTrav(steps: Iterator[OpNodes.Assignment]): AssignmentTraversal = - new AssignmentTraversal(steps) - - implicit def toTargetExt(call: Expression): TargetMethods = new TargetMethods(call) - implicit def toTargetTrav(steps: Iterator[Expression]): TargetTraversal = new TargetTraversal(steps) - - implicit def toOpAstNodeExt[A <: AstNode](node: A): OpAstNodeMethods[A] = new OpAstNodeMethods(node) - implicit def toOpAstNodeTrav[A <: AstNode](steps: Iterator[A]): OpAstNodeTraversal[A] = new OpAstNodeTraversal(steps) - -} +import io.shiftleft.semanticcpg.language.operatorextension.nodemethods.* + +trait 
Implicits: + implicit def toNodeTypeStartersOperatorExtension(cpg: Cpg): NodeTypeStarters = + new NodeTypeStarters(cpg) + + implicit def toArrayAccessExt(arrayAccess: OpNodes.ArrayAccess): ArrayAccessMethods = + new ArrayAccessMethods(arrayAccess) + implicit def toArrayAccessTrav(steps: Iterator[OpNodes.ArrayAccess]): ArrayAccessTraversal = + new ArrayAccessTraversal(steps) + + implicit def toFieldAccessExt(fieldAccess: OpNodes.FieldAccess): FieldAccessMethods = + new FieldAccessMethods(fieldAccess) + implicit def toFieldAccessTrav(steps: Iterator[OpNodes.FieldAccess]): FieldAccessTraversal = + new FieldAccessTraversal(steps) + + implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = + new AssignmentMethods(assignment) + implicit def toAssignmentTrav(steps: Iterator[OpNodes.Assignment]): AssignmentTraversal = + new AssignmentTraversal(steps) + + implicit def toTargetExt(call: Expression): TargetMethods = new TargetMethods(call) + implicit def toTargetTrav(steps: Iterator[Expression]): TargetTraversal = + new TargetTraversal(steps) + + implicit def toOpAstNodeExt[A <: AstNode](node: A): OpAstNodeMethods[A] = + new OpAstNodeMethods(node) + implicit def toOpAstNodeTrav[A <: AstNode](steps: Iterator[A]): OpAstNodeTraversal[A] = + new OpAstNodeTraversal(steps) +end Implicits diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala index 46b35d1e..74a50f60 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala @@ -7,29 +7,32 @@ import overflowdb.traversal.help.{Doc, TraversalSource} /** Steps that allow traversing from `cpg` to operators. 
*/ @TraversalSource -class NodeTypeStarters(cpg: Cpg) { - - @Doc(info = "All assignments, including shorthand assignments that perform arithmetic (e.g., '+=')") - def assignment: Iterator[OpNodes.Assignment] = - callsWithNameIn(allAssignmentTypes) - .map(new OpNodes.Assignment(_)) - - @Doc(info = "All arithmetic operations, including shorthand assignments that perform arithmetic (e.g., '+=')") - def arithmetic: Iterator[OpNodes.Arithmetic] = - callsWithNameIn(allArithmeticTypes) - .map(new OpNodes.Arithmetic(_)) - - @Doc(info = "All array accesses") - def arrayAccess: Iterator[OpNodes.ArrayAccess] = - callsWithNameIn(allArrayAccessTypes) - .map(new OpNodes.ArrayAccess(_)) - - @Doc(info = "Field accesses, both direct and indirect") - def fieldAccess: Iterator[OpNodes.FieldAccess] = - callsWithNameIn(allFieldAccessTypes) - .map(new OpNodes.FieldAccess(_)) - - private def callsWithNameIn(set: Set[String]) = - cpg.call.filter(x => set.contains(x.name)) - -} +class NodeTypeStarters(cpg: Cpg): + + @Doc(info = + "All assignments, including shorthand assignments that perform arithmetic (e.g., '+=')" + ) + def assignment: Iterator[OpNodes.Assignment] = + callsWithNameIn(allAssignmentTypes) + .map(new OpNodes.Assignment(_)) + + @Doc(info = + "All arithmetic operations, including shorthand assignments that perform arithmetic (e.g., '+=')" + ) + def arithmetic: Iterator[OpNodes.Arithmetic] = + callsWithNameIn(allArithmeticTypes) + .map(new OpNodes.Arithmetic(_)) + + @Doc(info = "All array accesses") + def arrayAccess: Iterator[OpNodes.ArrayAccess] = + callsWithNameIn(allArrayAccessTypes) + .map(new OpNodes.ArrayAccess(_)) + + @Doc(info = "Field accesses, both direct and indirect") + def fieldAccess: Iterator[OpNodes.FieldAccess] = + callsWithNameIn(allFieldAccessTypes) + .map(new OpNodes.FieldAccess(_)) + + private def callsWithNameIn(set: Set[String]) = + cpg.call.filter(x => set.contains(x.name)) +end NodeTypeStarters diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala index d5c61829..3a7f04da 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala @@ -4,31 +4,30 @@ import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -class OpAstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { +class OpAstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: - @Doc(info = "Any assignments that this node is a part of (traverse up)") - def assignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.assignment) + @Doc(info = "Any assignments that this node is a part of (traverse up)") + def assignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.assignment) - @Doc(info = "Arithmetic expressions nested in this tree") - def arithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.arithmetic) + @Doc(info = "Arithmetic expressions nested in this tree") + def arithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.arithmetic) - @Doc(info = "All array accesses") - def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) + @Doc(info = "All array accesses") + def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) - @Doc(info = "Field accesses, both direct and indirect") - def fieldAccess: Iterator[OpNodes.FieldAccess] = - traversal.flatMap(_.fieldAccess) + @Doc(info = "Field accesses, both direct and indirect") + def fieldAccess: Iterator[OpNodes.FieldAccess] = + traversal.flatMap(_.fieldAccess) - @Doc(info = "Any assignments that this node is a part of (traverse up)") - def 
inAssignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.inAssignment) + @Doc(info = "Any assignments that this node is a part of (traverse up)") + def inAssignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.inAssignment) - @Doc(info = "Any arithmetic expression that this node is a part of (traverse up)") - def inArithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.inArithmetic) + @Doc(info = "Any arithmetic expression that this node is a part of (traverse up)") + def inArithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.inArithmetic) - @Doc(info = "Any array access that this node is a part of (traverse up)") - def inArrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.inArrayAccess) + @Doc(info = "Any array access that this node is a part of (traverse up)") + def inArrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.inArrayAccess) - @Doc(info = "Any field access that this node is a part of (traverse up)") - def inFieldAccess: Iterator[OpNodes.FieldAccess] = traversal.flatMap(_.inFieldAccess) - -} + @Doc(info = "Any field access that this node is a part of (traverse up)") + def inFieldAccess: Iterator[OpNodes.FieldAccess] = traversal.flatMap(_.inFieldAccess) +end OpAstNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala index ad6c2c43..67ed7bdc 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala @@ -2,9 +2,8 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.generated.nodes.Call -object OpNodes { - class Assignment(call: Call) extends Call(call.graph, call.id) - class Arithmetic(call: Call) extends Call(call.graph, call.id) - class 
ArrayAccess(call: Call) extends Call(call.graph, call.id) - class FieldAccess(call: Call) extends Call(call.graph, call.id) -} +object OpNodes: + class Assignment(call: Call) extends Call(call.graph, call.id) + class Arithmetic(call: Call) extends Call(call.graph, call.id) + class ArrayAccess(call: Call) extends Call(call.graph, call.id) + class FieldAccess(call: Call) extends Call(call.graph, call.id) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala index 3a17608b..01a8c5a7 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala @@ -4,18 +4,16 @@ import io.shiftleft.codepropertygraph.generated.nodes.Expression import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -class TargetTraversal(val traversal: Iterator[Expression]) extends AnyVal { +class TargetTraversal(val traversal: Iterator[Expression]) extends AnyVal: - @Doc( - info = "Outer-most array access", - longInfo = """ + @Doc( + info = "Outer-most array access", + longInfo = """ Array access at highest level , e.g., in a(b(c)), the entire expression is returned, but not b(c) alone. 
""" - ) - def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) + ) + def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) - @Doc(info = "Returns 'pointer' in assignments of the form *(pointer) = x") - def pointer: Iterator[Expression] = traversal.flatMap(_.pointer) - -} + @Doc(info = "Returns 'pointer' in assignments of the form *(pointer) = x") + def pointer: Iterator[Expression] = traversal.flatMap(_.pointer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala index 86755916..1ba8115a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala @@ -4,28 +4,24 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Expression, Identifier} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes -class ArrayAccessMethods(val arrayAccess: OpNodes.ArrayAccess) extends AnyVal { +class ArrayAccessMethods(val arrayAccess: OpNodes.ArrayAccess) extends AnyVal: - def array: Expression = - arrayAccess.argument(1) + def array: Expression = + arrayAccess.argument(1) - def offset: Expression = arrayAccess.argument(2) + def offset: Expression = arrayAccess.argument(2) - def subscript: Iterator[Identifier] = - offset.ast.isIdentifier + def subscript: Iterator[Identifier] = + offset.ast.isIdentifier - def usesConstantOffset: Boolean = { - offset.ast.isIdentifier.nonEmpty || { - val literalIndices = offset.ast.isLiteral.l - literalIndices.size == 1 - } - } + def usesConstantOffset: Boolean = + offset.ast.isIdentifier.nonEmpty || { + val literalIndices = offset.ast.isLiteral.l + 
literalIndices.size == 1 + } - def simpleName: Iterator[String] = { - arrayAccess.array match { - case id: Identifier => Iterator.single(id.name) - case _ => Iterator.empty - } - } - -} + def simpleName: Iterator[String] = + arrayAccess.array match + case id: Identifier => Iterator.single(id.name) + case _ => Iterator.empty +end ArrayAccessMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala index 4faa0722..0fefd0fd 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala @@ -1,18 +1,17 @@ package io.shiftleft.semanticcpg.language.operatorextension.nodemethods import io.shiftleft.codepropertygraph.generated.nodes.Expression -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes -class AssignmentMethods(val assignment: OpNodes.Assignment) extends AnyVal { +class AssignmentMethods(val assignment: OpNodes.Assignment) extends AnyVal: - def target: Expression = assignment.argument(1) + def target: Expression = assignment.argument(1) - def source: Expression = { - assignment.argument.size match { - case 1 => assignment.argument(1) - case 2 => assignment.argument(2) - case numberOfArguments => throw new RuntimeException(s"Assignment statement with $numberOfArguments arguments") - } - } -} + def source: Expression = + assignment.argument.size match + case 1 => assignment.argument(1) + case 2 => assignment.argument(2) + case numberOfArguments => throw new RuntimeException( + s"Assignment statement with $numberOfArguments arguments" + ) diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala index 64525909..50c62e70 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala @@ -4,22 +4,18 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes -class FieldAccessMethods(val arrayAccess: OpNodes.FieldAccess) extends AnyVal { +class FieldAccessMethods(val arrayAccess: OpNodes.FieldAccess) extends AnyVal: - def typeDecl: Iterator[TypeDecl] = resolveTypeDecl(arrayAccess.argument(1)) + def typeDecl: Iterator[TypeDecl] = resolveTypeDecl(arrayAccess.argument(1)) - private def resolveTypeDecl(expr: Expression): Iterator[TypeDecl] = { - expr match { - case x: Identifier => x.typ.referencedTypeDecl - case x: Literal => x.typ.referencedTypeDecl - case x: Call => x.fieldAccess.member.typ.referencedTypeDecl - case _ => Iterator.empty - } - } + private def resolveTypeDecl(expr: Expression): Iterator[TypeDecl] = + expr match + case x: Identifier => x.typ.referencedTypeDecl + case x: Literal => x.typ.referencedTypeDecl + case x: Call => x.fieldAccess.member.typ.referencedTypeDecl + case _ => Iterator.empty - def fieldIdentifier: Iterator[FieldIdentifier] = arrayAccess.start.argument(2).isFieldIdentifier + def fieldIdentifier: Iterator[FieldIdentifier] = arrayAccess.start.argument(2).isFieldIdentifier - def member: Option[Member] = - arrayAccess.referencedMember.headOption - -} + def member: Option[Member] = + arrayAccess.referencedMember.headOption diff --git 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala index f221f045..6fa6dcd3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala @@ -4,41 +4,40 @@ import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call} import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.* -class OpAstNodeMethods[A <: AstNode](val node: A) extends AnyVal { +class OpAstNodeMethods[A <: AstNode](val node: A) extends AnyVal: - def assignment: Iterator[OpNodes.Assignment] = - astDown(allAssignmentTypes).map(new OpNodes.Assignment(_)) + def assignment: Iterator[OpNodes.Assignment] = + astDown(allAssignmentTypes).map(new OpNodes.Assignment(_)) - def arithmetic: Iterator[OpNodes.Arithmetic] = - astDown(allArithmeticTypes).map(new OpNodes.Arithmetic(_)) + def arithmetic: Iterator[OpNodes.Arithmetic] = + astDown(allArithmeticTypes).map(new OpNodes.Arithmetic(_)) - def arrayAccess: Iterator[OpNodes.ArrayAccess] = - astDown(allArrayAccessTypes).map(new OpNodes.ArrayAccess(_)) + def arrayAccess: Iterator[OpNodes.ArrayAccess] = + astDown(allArrayAccessTypes).map(new OpNodes.ArrayAccess(_)) - def fieldAccess: Iterator[OpNodes.FieldAccess] = - astDown(allFieldAccessTypes).map(new OpNodes.FieldAccess(_)) + def fieldAccess: Iterator[OpNodes.FieldAccess] = + astDown(allFieldAccessTypes).map(new OpNodes.FieldAccess(_)) - private def astDown(callNames: Set[String]): Iterator[Call] = - node.ast.isCall.filter(x => callNames.contains(x.name)) + private def astDown(callNames: Set[String]): Iterator[Call] = + node.ast.isCall.filter(x => callNames.contains(x.name)) - def inAssignment: 
Iterator[OpNodes.Assignment] = - astUp(allAssignmentTypes) - .map(new OpNodes.Assignment(_)) + def inAssignment: Iterator[OpNodes.Assignment] = + astUp(allAssignmentTypes) + .map(new OpNodes.Assignment(_)) - def inArithmetic: Iterator[OpNodes.Arithmetic] = - astUp(allArithmeticTypes) - .map(new OpNodes.Arithmetic(_)) + def inArithmetic: Iterator[OpNodes.Arithmetic] = + astUp(allArithmeticTypes) + .map(new OpNodes.Arithmetic(_)) - def inArrayAccess: Iterator[OpNodes.ArrayAccess] = - astUp(allArrayAccessTypes) - .map(new OpNodes.ArrayAccess(_)) + def inArrayAccess: Iterator[OpNodes.ArrayAccess] = + astUp(allArrayAccessTypes) + .map(new OpNodes.ArrayAccess(_)) - def inFieldAccess: Iterator[OpNodes.FieldAccess] = - astUp(allFieldAccessTypes) - .map(new OpNodes.FieldAccess(_)) + def inFieldAccess: Iterator[OpNodes.FieldAccess] = + astUp(allFieldAccessTypes) + .map(new OpNodes.FieldAccess(_)) - private def astUp(strings: Set[String]): Iterator[Call] = - node.inAstMinusLeaf.isCall - .filter(x => strings.contains(x.name)) - -} + private def astUp(strings: Set[String]): Iterator[Call] = + node.inAstMinusLeaf.isCall + .filter(x => strings.contains(x.name)) +end OpAstNodeMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala index 78bf6873..f0a90fb9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala @@ -2,19 +2,17 @@ package io.shiftleft.semanticcpg.language.operatorextension.nodemethods import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Call, Expression} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* 
import io.shiftleft.semanticcpg.language.operatorextension.{OpNodes, allArrayAccessTypes} -class TargetMethods(val expr: Expression) extends AnyVal { +class TargetMethods(val expr: Expression) extends AnyVal: - def arrayAccess: Option[OpNodes.ArrayAccess] = - expr.ast.isCall - .collectFirst { case x if allArrayAccessTypes.contains(x.name) => x } - .map(new OpNodes.ArrayAccess(_)) + def arrayAccess: Option[OpNodes.ArrayAccess] = + expr.ast.isCall + .collectFirst { case x if allArrayAccessTypes.contains(x.name) => x } + .map(new OpNodes.ArrayAccess(_)) - def pointer: Option[Expression] = - Option(expr).collect { - case call: Call if call.name == Operators.indirection => call.argument(1) - } - -} + def pointer: Option[Expression] = + Option(expr).collect { + case call: Call if call.name == Operators.indirection => call.argument(1) + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala index f2b7cef8..8f5d9e12 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala @@ -2,57 +2,56 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.Operators -package object operatorextension { - - /** All operators that perform both assignments and arithmetic. - */ - val assignmentAndArithmetic: Set[String] = Set( - Operators.assignmentDivision, - Operators.assignmentExponentiation, - Operators.assignmentPlus, - Operators.assignmentMinus, - Operators.assignmentModulo, - Operators.assignmentMultiplication, - Operators.preIncrement, - Operators.preDecrement, - Operators.postIncrement, - Operators.postIncrement - ) - - /** All operators that carry out assignments. 
- */ - val allAssignmentTypes: Set[String] = Set( - Operators.assignment, - Operators.assignmentOr, - Operators.assignmentAnd, - Operators.assignmentXor, - Operators.assignmentArithmeticShiftRight, - Operators.assignmentLogicalShiftRight, - Operators.assignmentShiftLeft - ) ++ assignmentAndArithmetic - - /** All operators representing arithmetic. - */ - val allArithmeticTypes: Set[String] = Set( - Operators.addition, - Operators.subtraction, - Operators.division, - Operators.multiplication, - Operators.exponentiation, - Operators.modulo - ) ++ assignmentAndArithmetic - - /** All operators representing array accesses. - */ - val allArrayAccessTypes: Set[String] = Set( - Operators.computedMemberAccess, - Operators.indirectComputedMemberAccess, - Operators.indexAccess, - Operators.indirectIndexAccess - ) - - /** All operators representing direct or indirect accesses to fields of data structures - */ - val allFieldAccessTypes: Set[String] = Set(Operators.fieldAccess, Operators.indirectFieldAccess) - -} +package object operatorextension: + + /** All operators that perform both assignments and arithmetic. + */ + val assignmentAndArithmetic: Set[String] = Set( + Operators.assignmentDivision, + Operators.assignmentExponentiation, + Operators.assignmentPlus, + Operators.assignmentMinus, + Operators.assignmentModulo, + Operators.assignmentMultiplication, + Operators.preIncrement, + Operators.preDecrement, + Operators.postIncrement, + Operators.postIncrement + ) + + /** All operators that carry out assignments. + */ + val allAssignmentTypes: Set[String] = Set( + Operators.assignment, + Operators.assignmentOr, + Operators.assignmentAnd, + Operators.assignmentXor, + Operators.assignmentArithmeticShiftRight, + Operators.assignmentLogicalShiftRight, + Operators.assignmentShiftLeft + ) ++ assignmentAndArithmetic + + /** All operators representing arithmetic. 
+ */ + val allArithmeticTypes: Set[String] = Set( + Operators.addition, + Operators.subtraction, + Operators.division, + Operators.multiplication, + Operators.exponentiation, + Operators.modulo + ) ++ assignmentAndArithmetic + + /** All operators representing array accesses. + */ + val allArrayAccessTypes: Set[String] = Set( + Operators.computedMemberAccess, + Operators.indirectComputedMemberAccess, + Operators.indexAccess, + Operators.indirectIndexAccess + ) + + /** All operators representing direct or indirect accesses to fields of data structures + */ + val allFieldAccessTypes: Set[String] = Set(Operators.fieldAccess, Operators.indirectFieldAccess) +end operatorextension diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala index 9f462586..d76df0e8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala @@ -1,283 +1,320 @@ package io.shiftleft.semanticcpg import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.traversal.NodeTraversalImplicits import io.shiftleft.semanticcpg.language.bindingextension.{ - MethodTraversal => BindingMethodTraversal, - TypeDeclTraversal => BindingTypeDeclTraversal + MethodTraversal as BindingMethodTraversal, + TypeDeclTraversal as BindingTypeDeclTraversal } import io.shiftleft.semanticcpg.language.callgraphextension.{CallTraversal, MethodTraversal} -import io.shiftleft.semanticcpg.language.dotextension.{AstNodeDot, CfgNodeDot, InterproceduralNodeDot} -import io.shiftleft.semanticcpg.language.nodemethods._ +import io.shiftleft.semanticcpg.language.dotextension.{ + AstNodeDot, + CfgNodeDot, + InterproceduralNodeDot +} +import 
io.shiftleft.semanticcpg.language.nodemethods.* import io.shiftleft.semanticcpg.language.types.expressions.generalizations.{ - AstNodeTraversal, - CfgNodeTraversal, - DeclarationTraversal, - ExpressionTraversal + AstNodeTraversal, + CfgNodeTraversal, + DeclarationTraversal, + ExpressionTraversal } -import io.shiftleft.semanticcpg.language.types.expressions.{CallTraversal => OriginalCall, _} -import io.shiftleft.semanticcpg.language.types.propertyaccessors._ -import io.shiftleft.semanticcpg.language.types.structure.{MethodTraversal => OriginalMethod, _} +import io.shiftleft.semanticcpg.language.types.expressions.{CallTraversal as OriginalCall, *} +import io.shiftleft.semanticcpg.language.types.propertyaccessors.* +import io.shiftleft.semanticcpg.language.types.structure.{MethodTraversal as OriginalMethod, *} import overflowdb.NodeOrDetachedNode /** Language for traversing the code property graph * - * Implicit conversions to specific steps, based on the node at hand. Automatically in scope when using anything in the - * `steps` package, e.g. `Steps` + * Implicit conversions to specific steps, based on the node at hand. Automatically in scope when + * using anything in the `steps` package, e.g. `Steps` */ -package object language extends operatorextension.Implicits with LowPrioImplicits with NodeTraversalImplicits { - // Implicit conversions from generated node types. We use these to add methods - // to generated node types. 
- - implicit def cfgNodeToAsNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toExtendedNode(node: NodeOrDetachedNode): NodeMethods = new NodeMethods(node) - implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = new StoredNodeMethods(node) - implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) - implicit def toExpressionMethods(node: Expression): ExpressionMethods = new ExpressionMethods(node) - implicit def toMethodMethods(node: Method): MethodMethods = new MethodMethods(node) - implicit def toMethodReturnMethods(node: MethodReturn): MethodReturnMethods = new MethodReturnMethods(node) - implicit def toCallMethods(node: Call): CallMethods = new CallMethods(node) - implicit def toMethodParamInMethods(node: MethodParameterIn): MethodParameterInMethods = - new MethodParameterInMethods(node) - implicit def toMethodParamOutMethods(node: MethodParameterOut): MethodParameterOutMethods = - new MethodParameterOutMethods(node) - implicit def toIdentifierMethods(node: Identifier): IdentifierMethods = new IdentifierMethods(node) - implicit def toLiteralMethods(node: Literal): LiteralMethods = new LiteralMethods(node) - implicit def toLocalMethods(node: Local): LocalMethods = new LocalMethods(node) - implicit def toMethodRefMethods(node: MethodRef): MethodRefMethods = new MethodRefMethods(node) - - // Implicit conversions from Step[NodeType, Label] to corresponding Step classes. - // If you introduce a new Step-type, that is, one that inherits from `Steps[NodeType]`, - // then you need to add an implicit conversion from `Steps[NodeType]` to your type - // here. 
- - implicit def singleToTypeTrav[A <: Type](a: A): TypeTraversal = - new TypeTraversal(Iterator.single(a)) - implicit def iterOnceToTypeTrav[A <: Type](a: IterableOnce[A]): TypeTraversal = - new TypeTraversal(a.iterator) - - implicit def singleToTypeDeclTrav[A <: TypeDecl](a: A): TypeDeclTraversal = - new TypeDeclTraversal(Iterator.single(a)) - implicit def iterOnceToTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): TypeDeclTraversal = - new TypeDeclTraversal(a.iterator) - - implicit def iterOnceToOriginalCallTrav[A <: Call](a: IterableOnce[A]): OriginalCall = - new OriginalCall(a.iterator) - - implicit def singleToControlStructureTrav[A <: ControlStructure](a: A): ControlStructureTraversal = - new ControlStructureTraversal(Iterator.single(a)) - implicit def iterOnceToControlStructureTrav[A <: ControlStructure](a: IterableOnce[A]): ControlStructureTraversal = - new ControlStructureTraversal(a.iterator) - - implicit def singleToIdentifierTrav[A <: Identifier](a: A): IdentifierTraversal = - new IdentifierTraversal(Iterator.single(a)) - implicit def iterOnceToIdentifierTrav[A <: Identifier](a: IterableOnce[A]): IdentifierTraversal = - new IdentifierTraversal(a.iterator) - - implicit def singleToAnnotationTrav[A <: Annotation](a: A): AnnotationTraversal = - new AnnotationTraversal(Iterator.single(a)) - implicit def iterOnceToAnnotationTrav[A <: Annotation](a: IterableOnce[A]): AnnotationTraversal = - new AnnotationTraversal(a.iterator) - - implicit def singleToDependencyTrav[A <: Dependency](a: A): DependencyTraversal = - new DependencyTraversal(Iterator.single(a)) - - implicit def iterToDependencyTrav[A <: Dependency](a: IterableOnce[A]): DependencyTraversal = - new DependencyTraversal(a.iterator) - - implicit def singleToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( - a: A - ): AnnotationParameterAssignTraversal = - new AnnotationParameterAssignTraversal(Iterator.single(a)) - implicit def iterOnceToAnnotationParameterAssignTrav[A <: 
AnnotationParameterAssign]( - a: IterableOnce[A] - ): AnnotationParameterAssignTraversal = - new AnnotationParameterAssignTraversal(a.iterator) - - implicit def toMember(traversal: IterableOnce[Member]): MemberTraversal = new MemberTraversal(traversal.iterator) - implicit def toLocal(traversal: IterableOnce[Local]): LocalTraversal = new LocalTraversal(traversal.iterator) - implicit def toMethod(traversal: IterableOnce[Method]): OriginalMethod = new OriginalMethod(traversal.iterator) - - implicit def singleToMethodParameterInTrav[A <: MethodParameterIn](a: A): MethodParameterTraversal = - new MethodParameterTraversal(Iterator.single(a)) - implicit def iterOnceToMethodParameterInTrav[A <: MethodParameterIn](a: IterableOnce[A]): MethodParameterTraversal = - new MethodParameterTraversal(a.iterator) - - implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A): MethodParameterOutTraversal = - new MethodParameterOutTraversal(Iterator.single(a)) - implicit def iterOnceToMethodParameterOutTrav[A <: MethodParameterOut]( - a: IterableOnce[A] - ): MethodParameterOutTraversal = - new MethodParameterOutTraversal(a.iterator) - - implicit def iterOnceToMethodReturnTrav[A <: MethodReturn](a: IterableOnce[A]): MethodReturnTraversal = - new MethodReturnTraversal(a.iterator) - - implicit def singleToNamespaceTrav[A <: Namespace](a: A): NamespaceTraversal = - new NamespaceTraversal(Iterator.single(a)) - implicit def iterOnceToNamespaceTrav[A <: Namespace](a: IterableOnce[A]): NamespaceTraversal = - new NamespaceTraversal(a.iterator) - - implicit def singleToNamespaceBlockTrav[A <: NamespaceBlock](a: A): NamespaceBlockTraversal = - new NamespaceBlockTraversal(Iterator.single(a)) - implicit def iterOnceToNamespaceBlockTrav[A <: NamespaceBlock](a: IterableOnce[A]): NamespaceBlockTraversal = - new NamespaceBlockTraversal(a.iterator) - - implicit def singleToFileTrav[A <: File](a: A): FileTraversal = - new FileTraversal(Iterator.single(a)) - implicit def 
iterOnceToFileTrav[A <: File](a: IterableOnce[A]): FileTraversal = - new FileTraversal(a.iterator) - - implicit def singleToImportTrav[A <: Import](a: A): ImportTraversal = - new ImportTraversal(Iterator.single(a)) - - implicit def iterToImportTrav[A <: Import](a: IterableOnce[A]): ImportTraversal = - new ImportTraversal(a.iterator) - - // Call graph extension - implicit def singleToMethodTravCallGraphExt[A <: Method](a: A): MethodTraversal = - new MethodTraversal(Iterator.single(a)) - implicit def iterOnceToMethodTravCallGraphExt[A <: Method](a: IterableOnce[A]): MethodTraversal = - new MethodTraversal(a.iterator) - implicit def singleToCallTrav[A <: Call](a: A): CallTraversal = - new CallTraversal(Iterator.single(a)) - implicit def iterOnceToCallTrav[A <: Call](a: IterableOnce[A]): CallTraversal = - new CallTraversal(a.iterator) - // / Call graph extension - - // Binding extensions - implicit def singleToBindingMethodTrav[A <: Method](a: A): BindingMethodTraversal = - new BindingMethodTraversal(Iterator.single(a)) - implicit def iterOnceToBindingMethodTrav[A <: Method](a: IterableOnce[A]): BindingMethodTraversal = - new BindingMethodTraversal(a.iterator) - - implicit def singleToBindingTypeDeclTrav[A <: TypeDecl](a: A): BindingTypeDeclTraversal = - new BindingTypeDeclTraversal(Iterator.single(a)) - implicit def iterOnceToBindingTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): BindingTypeDeclTraversal = - new BindingTypeDeclTraversal(a.iterator) - - implicit def singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = - new AstNodeDot(Iterator.single(a)) - implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = - new AstNodeDot(a.iterator) - - implicit def singleToCfgNodeDot[A <: Method](a: A): CfgNodeDot = - new CfgNodeDot(Iterator.single(a)) - implicit def iterOnceToCfgNodeDot[A <: Method](a: IterableOnce[A]): CfgNodeDot = - new CfgNodeDot(a.iterator) - - implicit def graphToInterproceduralDot(cpg: Cpg): InterproceduralNodeDot = - new 
InterproceduralNodeDot(cpg) - - /** Warning: implicitly lifting `Node -> Traversal` opens a broad space with a lot of accidental complexity and is - * considered a historical accident. We only keep it around because we want to preserve `reachableBy(Node*)`, which - * unfortunately (due to type erasure) can't be an overload of `reachableBy(Traversal*)`. - * - * In most places you should explicitly call `Iterator.single` instead of relying on this implicit. - */ - implicit def toTraversal[NodeType <: StoredNode](node: NodeType): Iterator[NodeType] = - Iterator.single(node) - - implicit def iterableOnceToSteps[A](iterableOnce: IterableOnce[A]): Steps[A] = - new Steps(iterableOnce.iterator) - - implicit def traversalToSteps[A](trav: Iterator[A]): Steps[A] = - new Steps(trav) - implicit def iterOnceToNodeSteps[A <: StoredNode](a: IterableOnce[A]): NodeSteps[A] = - new NodeSteps[A](a.iterator) - - implicit def toNewNodeTrav[NodeType <: NewNode](trav: Iterator[NodeType]): NewNodeSteps[NodeType] = - new NewNodeSteps[NodeType](trav) - - implicit def toNodeTypeStarters(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) - implicit def toTagTraversal(trav: Iterator[Tag]): TagTraversal = new TagTraversal(trav) - - // ~ EvalType accessors - implicit def singleToEvalTypeAccessorsLocal[A <: Local](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsLocal[A <: Local](a: IterableOnce[A]): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMember[A <: Member](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsMember[A <: Member](a: IterableOnce[A]): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMethod[A <: Method](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def 
iterOnceToEvalTypeAccessorsMethod[A <: Method](a: IterableOnce[A]): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsParameterIn[A <: MethodParameterIn](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsParameterIn[A <: MethodParameterIn]( - a: IterableOnce[A] - ): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsParameterOut[A <: MethodParameterOut](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsParameterOut[A <: MethodParameterOut]( - a: IterableOnce[A] - ): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: IterableOnce[A]): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsExpression[A <: Expression](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsExpression[A <: Expression](a: IterableOnce[A]): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - // EvalType accessors ~ - - // ~ Modifier accessors - implicit def singleToModifierAccessorsMember[A <: Member](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsMember[A <: Member](a: IterableOnce[A]): ModifierAccessors[A] = - new ModifierAccessors[A](a.iterator) - - implicit def singleToModifierAccessorsMethod[A <: Method](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsMethod[A <: Method](a: IterableOnce[A]): ModifierAccessors[A] = - new 
ModifierAccessors[A](a.iterator) - - implicit def singleToModifierAccessorsTypeDecl[A <: TypeDecl](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsTypeDecl[A <: TypeDecl](a: IterableOnce[A]): ModifierAccessors[A] = - new ModifierAccessors[A](a.iterator) - // Modifier accessors ~ - - implicit class NewNodeTypeDeco[NodeType <: NewNode](val node: NodeType) extends AnyVal { - - /** Start a new traversal from this node +package object language extends operatorextension.Implicits with LowPrioImplicits + with NodeTraversalImplicits: + // Implicit conversions from generated node types. We use these to add methods + // to generated node types. + + implicit def cfgNodeToAsNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toExtendedNode(node: NodeOrDetachedNode): NodeMethods = new NodeMethods(node) + implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = + new StoredNodeMethods(node) + implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) + implicit def toExpressionMethods(node: Expression): ExpressionMethods = + new ExpressionMethods(node) + implicit def toMethodMethods(node: Method): MethodMethods = new MethodMethods(node) + implicit def toMethodReturnMethods(node: MethodReturn): MethodReturnMethods = + new MethodReturnMethods(node) + implicit def toCallMethods(node: Call): CallMethods = new CallMethods(node) + implicit def toMethodParamInMethods(node: MethodParameterIn): MethodParameterInMethods = + new MethodParameterInMethods(node) + implicit def toMethodParamOutMethods(node: MethodParameterOut): MethodParameterOutMethods = + new MethodParameterOutMethods(node) + implicit def toIdentifierMethods(node: Identifier): IdentifierMethods = + new IdentifierMethods(node) + implicit def toLiteralMethods(node: Literal): LiteralMethods = new 
LiteralMethods(node) + implicit def toLocalMethods(node: Local): LocalMethods = new LocalMethods(node) + implicit def toMethodRefMethods(node: MethodRef): MethodRefMethods = new MethodRefMethods(node) + + // Implicit conversions from Step[NodeType, Label] to corresponding Step classes. + // If you introduce a new Step-type, that is, one that inherits from `Steps[NodeType]`, + // then you need to add an implicit conversion from `Steps[NodeType]` to your type + // here. + + implicit def singleToTypeTrav[A <: Type](a: A): TypeTraversal = + new TypeTraversal(Iterator.single(a)) + implicit def iterOnceToTypeTrav[A <: Type](a: IterableOnce[A]): TypeTraversal = + new TypeTraversal(a.iterator) + + implicit def singleToTypeDeclTrav[A <: TypeDecl](a: A): TypeDeclTraversal = + new TypeDeclTraversal(Iterator.single(a)) + implicit def iterOnceToTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): TypeDeclTraversal = + new TypeDeclTraversal(a.iterator) + + implicit def iterOnceToOriginalCallTrav[A <: Call](a: IterableOnce[A]): OriginalCall = + new OriginalCall(a.iterator) + + implicit def singleToControlStructureTrav[A <: ControlStructure](a: A) + : ControlStructureTraversal = + new ControlStructureTraversal(Iterator.single(a)) + implicit def iterOnceToControlStructureTrav[A <: ControlStructure](a: IterableOnce[A]) + : ControlStructureTraversal = + new ControlStructureTraversal(a.iterator) + + implicit def singleToIdentifierTrav[A <: Identifier](a: A): IdentifierTraversal = + new IdentifierTraversal(Iterator.single(a)) + implicit def iterOnceToIdentifierTrav[A <: Identifier](a: IterableOnce[A]) + : IdentifierTraversal = + new IdentifierTraversal(a.iterator) + + implicit def singleToAnnotationTrav[A <: Annotation](a: A): AnnotationTraversal = + new AnnotationTraversal(Iterator.single(a)) + implicit def iterOnceToAnnotationTrav[A <: Annotation](a: IterableOnce[A]) + : AnnotationTraversal = + new AnnotationTraversal(a.iterator) + + implicit def singleToDependencyTrav[A <: 
Dependency](a: A): DependencyTraversal = + new DependencyTraversal(Iterator.single(a)) + + implicit def iterToDependencyTrav[A <: Dependency](a: IterableOnce[A]): DependencyTraversal = + new DependencyTraversal(a.iterator) + + implicit def singleToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( + a: A + ): AnnotationParameterAssignTraversal = + new AnnotationParameterAssignTraversal(Iterator.single(a)) + implicit def iterOnceToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( + a: IterableOnce[A] + ): AnnotationParameterAssignTraversal = + new AnnotationParameterAssignTraversal(a.iterator) + + implicit def toMember(traversal: IterableOnce[Member]): MemberTraversal = + new MemberTraversal(traversal.iterator) + implicit def toLocal(traversal: IterableOnce[Local]): LocalTraversal = + new LocalTraversal(traversal.iterator) + implicit def toMethod(traversal: IterableOnce[Method]): OriginalMethod = + new OriginalMethod(traversal.iterator) + + implicit def singleToMethodParameterInTrav[A <: MethodParameterIn](a: A) + : MethodParameterTraversal = + new MethodParameterTraversal(Iterator.single(a)) + implicit def iterOnceToMethodParameterInTrav[A <: MethodParameterIn](a: IterableOnce[A]) + : MethodParameterTraversal = + new MethodParameterTraversal(a.iterator) + + implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A) + : MethodParameterOutTraversal = + new MethodParameterOutTraversal(Iterator.single(a)) + implicit def iterOnceToMethodParameterOutTrav[A <: MethodParameterOut]( + a: IterableOnce[A] + ): MethodParameterOutTraversal = + new MethodParameterOutTraversal(a.iterator) + + implicit def iterOnceToMethodReturnTrav[A <: MethodReturn](a: IterableOnce[A]) + : MethodReturnTraversal = + new MethodReturnTraversal(a.iterator) + + implicit def singleToNamespaceTrav[A <: Namespace](a: A): NamespaceTraversal = + new NamespaceTraversal(Iterator.single(a)) + implicit def iterOnceToNamespaceTrav[A <: Namespace](a: IterableOnce[A]): 
NamespaceTraversal = + new NamespaceTraversal(a.iterator) + + implicit def singleToNamespaceBlockTrav[A <: NamespaceBlock](a: A): NamespaceBlockTraversal = + new NamespaceBlockTraversal(Iterator.single(a)) + implicit def iterOnceToNamespaceBlockTrav[A <: NamespaceBlock](a: IterableOnce[A]) + : NamespaceBlockTraversal = + new NamespaceBlockTraversal(a.iterator) + + implicit def singleToFileTrav[A <: File](a: A): FileTraversal = + new FileTraversal(Iterator.single(a)) + implicit def iterOnceToFileTrav[A <: File](a: IterableOnce[A]): FileTraversal = + new FileTraversal(a.iterator) + + implicit def singleToImportTrav[A <: Import](a: A): ImportTraversal = + new ImportTraversal(Iterator.single(a)) + + implicit def iterToImportTrav[A <: Import](a: IterableOnce[A]): ImportTraversal = + new ImportTraversal(a.iterator) + + // Call graph extension + implicit def singleToMethodTravCallGraphExt[A <: Method](a: A): MethodTraversal = + new MethodTraversal(Iterator.single(a)) + implicit def iterOnceToMethodTravCallGraphExt[A <: Method](a: IterableOnce[A]) + : MethodTraversal = + new MethodTraversal(a.iterator) + implicit def singleToCallTrav[A <: Call](a: A): CallTraversal = + new CallTraversal(Iterator.single(a)) + implicit def iterOnceToCallTrav[A <: Call](a: IterableOnce[A]): CallTraversal = + new CallTraversal(a.iterator) + // / Call graph extension + + // Binding extensions + implicit def singleToBindingMethodTrav[A <: Method](a: A): BindingMethodTraversal = + new BindingMethodTraversal(Iterator.single(a)) + implicit def iterOnceToBindingMethodTrav[A <: Method](a: IterableOnce[A]) + : BindingMethodTraversal = + new BindingMethodTraversal(a.iterator) + + implicit def singleToBindingTypeDeclTrav[A <: TypeDecl](a: A): BindingTypeDeclTraversal = + new BindingTypeDeclTraversal(Iterator.single(a)) + implicit def iterOnceToBindingTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]) + : BindingTypeDeclTraversal = + new BindingTypeDeclTraversal(a.iterator) + + implicit def 
singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = + new AstNodeDot(Iterator.single(a)) + implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = + new AstNodeDot(a.iterator) + + implicit def singleToCfgNodeDot[A <: Method](a: A): CfgNodeDot = + new CfgNodeDot(Iterator.single(a)) + implicit def iterOnceToCfgNodeDot[A <: Method](a: IterableOnce[A]): CfgNodeDot = + new CfgNodeDot(a.iterator) + + implicit def graphToInterproceduralDot(cpg: Cpg): InterproceduralNodeDot = + new InterproceduralNodeDot(cpg) + + /** Warning: implicitly lifting `Node -> Traversal` opens a broad space with a lot of accidental + * complexity and is considered a historical accident. We only keep it around because we want + * to preserve `reachableBy(Node*)`, which unfortunately (due to type erasure) can't be an + * overload of `reachableBy(Traversal*)`. + * + * In most places you should explicitly call `Iterator.single` instead of relying on this + * implicit. */ - def start: Iterator[NodeType] = - Iterator.single(node) - } - - implicit def toExpression[A <: Expression](a: IterableOnce[A]): ExpressionTraversal[A] = - new ExpressionTraversal[A](a.iterator) -} - -trait LowPrioImplicits extends overflowdb.traversal.Implicits { - implicit def singleToCfgNodeTraversal[A <: CfgNode](a: A): CfgNodeTraversal[A] = - new CfgNodeTraversal[A](Iterator.single(a)) - implicit def iterOnceToCfgNodeTraversal[A <: CfgNode](a: IterableOnce[A]): CfgNodeTraversal[A] = - new CfgNodeTraversal[A](a.iterator) - - implicit def singleToAstNodeTraversal[A <: AstNode](a: A): AstNodeTraversal[A] = - new AstNodeTraversal[A](Iterator.single(a)) - implicit def iterOnceToAstNodeTraversal[A <: AstNode](a: IterableOnce[A]): AstNodeTraversal[A] = - new AstNodeTraversal[A](a.iterator) - - implicit def singleToDeclarationNodeTraversal[A <: Declaration](a: A): DeclarationTraversal[A] = - new DeclarationTraversal[A](Iterator.single(a)) - implicit def iterOnceToDeclarationNodeTraversal[A <: 
Declaration](a: IterableOnce[A]): DeclarationTraversal[A] = - new DeclarationTraversal[A](a.iterator) -} + implicit def toTraversal[NodeType <: StoredNode](node: NodeType): Iterator[NodeType] = + Iterator.single(node) + + implicit def iterableOnceToSteps[A](iterableOnce: IterableOnce[A]): Steps[A] = + new Steps(iterableOnce.iterator) + + implicit def traversalToSteps[A](trav: Iterator[A]): Steps[A] = + new Steps(trav) + implicit def iterOnceToNodeSteps[A <: StoredNode](a: IterableOnce[A]): NodeSteps[A] = + new NodeSteps[A](a.iterator) + + implicit def toNewNodeTrav[NodeType <: NewNode](trav: Iterator[NodeType]) + : NewNodeSteps[NodeType] = + new NewNodeSteps[NodeType](trav) + + implicit def toNodeTypeStarters(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) + implicit def toTagTraversal(trav: Iterator[Tag]): TagTraversal = new TagTraversal(trav) + + // ~ EvalType accessors + implicit def singleToEvalTypeAccessorsLocal[A <: Local](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsLocal[A <: Local](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMember[A <: Member](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMember[A <: Member](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMethod[A <: Method](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMethod[A <: Method](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsParameterIn[A <: MethodParameterIn](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsParameterIn[A <: MethodParameterIn]( + a: 
IterableOnce[A] + ): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsParameterOut[A <: MethodParameterOut](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsParameterOut[A <: MethodParameterOut]( + a: IterableOnce[A] + ): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsExpression[A <: Expression](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsExpression[A <: Expression](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + // EvalType accessors ~ + + // ~ Modifier accessors + implicit def singleToModifierAccessorsMember[A <: Member](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsMember[A <: Member](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + + implicit def singleToModifierAccessorsMethod[A <: Method](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsMethod[A <: Method](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + + implicit def singleToModifierAccessorsTypeDecl[A <: TypeDecl](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsTypeDecl[A <: TypeDecl](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + // Modifier accessors ~ + + implicit class 
NewNodeTypeDeco[NodeType <: NewNode](val node: NodeType) extends AnyVal: + + /** Start a new traversal from this node + */ + def start: Iterator[NodeType] = + Iterator.single(node) + + implicit def toExpression[A <: Expression](a: IterableOnce[A]): ExpressionTraversal[A] = + new ExpressionTraversal[A](a.iterator) +end language + +trait LowPrioImplicits extends overflowdb.traversal.Implicits: + implicit def singleToCfgNodeTraversal[A <: CfgNode](a: A): CfgNodeTraversal[A] = + new CfgNodeTraversal[A](Iterator.single(a)) + implicit def iterOnceToCfgNodeTraversal[A <: CfgNode](a: IterableOnce[A]): CfgNodeTraversal[A] = + new CfgNodeTraversal[A](a.iterator) + + implicit def singleToAstNodeTraversal[A <: AstNode](a: A): AstNodeTraversal[A] = + new AstNodeTraversal[A](Iterator.single(a)) + implicit def iterOnceToAstNodeTraversal[A <: AstNode](a: IterableOnce[A]): AstNodeTraversal[A] = + new AstNodeTraversal[A](a.iterator) + + implicit def singleToDeclarationNodeTraversal[A <: Declaration](a: A): DeclarationTraversal[A] = + new DeclarationTraversal[A](Iterator.single(a)) + implicit def iterOnceToDeclarationNodeTraversal[A <: Declaration](a: IterableOnce[A]) + : DeclarationTraversal[A] = + new DeclarationTraversal[A](a.iterator) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala index ef233713..49f41804 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala @@ -5,38 +5,37 @@ import io.shiftleft.semanticcpg.language.* /** A call site */ -class CallTraversal(val traversal: Iterator[Call]) extends AnyVal { - - /** Only statically dispatched calls - */ - def isStatic: Iterator[Call] = - traversal.dispatchType("STATIC_DISPATCH") - - /** Only dynamically 
dispatched calls - */ - def isDynamic: Iterator[Call] = - traversal.dispatchType("DYNAMIC_DISPATCH") - - /** The receiver of a call if the call has a receiver associated. - */ - def receiver: Iterator[Expression] = - traversal.flatMap(_.receiver) - - /** Arguments of the call - */ - def argument: Iterator[Expression] = - traversal.flatMap(_.argument) - - /** `i'th` arguments of the call - */ - def argument(i: Integer): Iterator[Expression] = - traversal.flatMap(_.arguments(i)) - - /** To formal method return parameter - */ - def toMethodReturn(implicit callResolver: ICallResolver): Iterator[MethodReturn] = - traversal - .flatMap(callResolver.getCalledMethodsAsTraversal) - .flatMap(_.methodReturn) - -} +class CallTraversal(val traversal: Iterator[Call]) extends AnyVal: + + /** Only statically dispatched calls + */ + def isStatic: Iterator[Call] = + traversal.dispatchType("STATIC_DISPATCH") + + /** Only dynamically dispatched calls + */ + def isDynamic: Iterator[Call] = + traversal.dispatchType("DYNAMIC_DISPATCH") + + /** The receiver of a call if the call has a receiver associated. 
+ */ + def receiver: Iterator[Expression] = + traversal.flatMap(_.receiver) + + /** Arguments of the call + */ + def argument: Iterator[Expression] = + traversal.flatMap(_.argument) + + /** `i'th` arguments of the call + */ + def argument(i: Integer): Iterator[Expression] = + traversal.flatMap(_.arguments(i)) + + /** To formal method return parameter + */ + def toMethodReturn(implicit callResolver: ICallResolver): Iterator[MethodReturn] = + traversal + .flatMap(callResolver.getCalledMethodsAsTraversal) + .flatMap(_.methodReturn) +end CallTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala index 6d70ded8..1d9c750d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala @@ -5,72 +5,70 @@ import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Properti import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc -object ControlStructureTraversal { - val secondChildIndex = 2 - val thirdChildIndex = 3 -} +object ControlStructureTraversal: + val secondChildIndex = 2 + val thirdChildIndex = 3 -class ControlStructureTraversal(val traversal: Iterator[ControlStructure]) extends AnyVal { - import ControlStructureTraversal.* +class ControlStructureTraversal(val traversal: Iterator[ControlStructure]) extends AnyVal: + import ControlStructureTraversal.* - @Doc(info = "The condition associated with this control structure") - def condition: Iterator[Expression] = - traversal.flatMap(_.conditionOut).collectAll[Expression] + @Doc(info = "The condition associated with this control structure") + def condition: Iterator[Expression] = + 
traversal.flatMap(_.conditionOut).collectAll[Expression] - @Doc(info = "Control structures where condition.code matches regex") - def condition(regex: String): Iterator[ControlStructure] = - traversal.where(_.condition.code(regex)) + @Doc(info = "Control structures where condition.code matches regex") + def condition(regex: String): Iterator[ControlStructure] = + traversal.where(_.condition.code(regex)) - @Doc(info = "Sub tree taken when condition evaluates to true") - def whenTrue: Iterator[AstNode] = - traversal.out.has(Properties.ORDER, secondChildIndex: Int).cast[AstNode] + @Doc(info = "Sub tree taken when condition evaluates to true") + def whenTrue: Iterator[AstNode] = + traversal.out.has(Properties.ORDER, secondChildIndex: Int).cast[AstNode] - @Doc(info = "Sub tree taken when condition evaluates to false") - def whenFalse: Iterator[AstNode] = - traversal.out.has(Properties.ORDER, thirdChildIndex).cast[AstNode] + @Doc(info = "Sub tree taken when condition evaluates to false") + def whenFalse: Iterator[AstNode] = + traversal.out.has(Properties.ORDER, thirdChildIndex).cast[AstNode] - @Doc(info = "Only `Try` control structures") - def isTry: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.TRY) + @Doc(info = "Only `Try` control structures") + def isTry: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.TRY) - @Doc(info = "Only `If` control structures") - def isIf: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.IF) + @Doc(info = "Only `If` control structures") + def isIf: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.IF) - @Doc(info = "Only `Else` control structures") - def isElse: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.ELSE) + @Doc(info = "Only `Else` control structures") + def isElse: Iterator[ControlStructure] = + 
traversal.controlStructureTypeExact(ControlStructureTypes.ELSE) - @Doc(info = "Only `Switch` control structures") - def isSwitch: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.SWITCH) + @Doc(info = "Only `Switch` control structures") + def isSwitch: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.SWITCH) - @Doc(info = "Only `Do` control structures") - def isDo: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.DO) + @Doc(info = "Only `Do` control structures") + def isDo: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.DO) - @Doc(info = "Only `For` control structures") - def isFor: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.FOR) + @Doc(info = "Only `For` control structures") + def isFor: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.FOR) - @Doc(info = "Only `While` control structures") - def isWhile: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.WHILE) + @Doc(info = "Only `While` control structures") + def isWhile: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.WHILE) - @Doc(info = "Only `Goto` control structures") - def isGoto: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.GOTO) + @Doc(info = "Only `Goto` control structures") + def isGoto: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.GOTO) - @Doc(info = "Only `Break` control structures") - def isBreak: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.BREAK) + @Doc(info = "Only `Break` control structures") + def isBreak: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.BREAK) - @Doc(info = "Only `Continue` control 
structures") - def isContinue: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.CONTINUE) + @Doc(info = "Only `Continue` control structures") + def isContinue: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.CONTINUE) - @Doc(info = "Only `Throw` control structures") - def isThrow: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.THROW) - -} + @Doc(info = "Only `Throw` control structures") + def isThrow: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.THROW) +end ControlStructureTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala index 8b11e563..bfc40a6c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala @@ -5,12 +5,9 @@ import io.shiftleft.semanticcpg.language.toTraversalSugarExt /** An identifier, e.g., an instance of a local variable, or a temporary variable */ -class IdentifierTraversal(val traversal: Iterator[Identifier]) extends AnyVal { +class IdentifierTraversal(val traversal: Iterator[Identifier]) extends AnyVal: - /** Traverse to all declarations of this identifier - */ - def refsTo: Iterator[Declaration] = { - traversal.flatMap(_.refOut).cast[Declaration] - } - -} + /** Traverse to all declarations of this identifier + */ + def refsTo: Iterator[Declaration] = + traversal.flatMap(_.refOut).cast[Declaration] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala index f28be54d..1f64a3d5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala @@ -7,211 +7,219 @@ import overflowdb.traversal.help import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[AstNode]) -class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { - - /** Nodes of the AST rooted in this node, including the node itself. - */ - @Doc(info = "All nodes of the abstract syntax tree") - def ast: Iterator[AstNode] = { - traversal.repeat(_.out(EdgeTypes.AST))(_.emit).cast[AstNode] - } - - /** All nodes of the abstract syntax tree rooted in this node, which match `predicate`. Equivalent of `match` in the - * original CPG paper. - */ - def ast(predicate: AstNode => Boolean): Iterator[AstNode] = - ast.filter(predicate) - - def containsCallTo(regex: String): Iterator[A] = - traversal.filter(_.ast.isCall.name(regex).nonEmpty) - - @Doc(info = "Depth of the abstract syntax tree") - def depth: Iterator[Int] = - traversal.map(_.depth) - - def depth(p: AstNode => Boolean): Iterator[Int] = - traversal.map(_.depth(p)) - - def isCallTo(regex: String): Iterator[Call] = - isCall.name(regex) - - /** Nodes of the AST rooted in this node, minus the node itself - */ - def astMinusRoot: Iterator[AstNode] = - traversal.repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst).cast[AstNode] - - /** Direct children of node in the AST. 
Siblings are ordered by their `order` fields - */ - def astChildren: Iterator[AstNode] = - traversal.flatMap(_.astChildren).sortBy(_.order).iterator - - /** Parent AST node - */ - def astParent: Iterator[AstNode] = - traversal.in(EdgeTypes.AST).cast[AstNode] - - /** Siblings of this node in the AST, ordered by their `order` fields - */ - def astSiblings: Iterator[AstNode] = - traversal.flatMap(_.astSiblings) - - /** Traverses up the AST and returns the first block node. - */ - def parentBlock: Iterator[Block] = - traversal.repeat(_.in(EdgeTypes.AST))(_.emit.until(_.hasLabel(NodeTypes.BLOCK))).collectAll[Block] - - /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached - */ - def inAst: Iterator[AstNode] = - inAst(null) - - /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached, minus this node - */ - def inAstMinusLeaf: Iterator[AstNode] = - inAstMinusLeaf(null) - - /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is reached - */ - def inAst(root: AstNode): Iterator[AstNode] = - traversal - .repeat(_.in(EdgeTypes.AST))( - _.emit - .until(_.or(_.hasLabel(NodeTypes.METHOD), _.filter(n => root != null && root == n))) - ) - .cast[AstNode] - - /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is reached, minus this - * node - */ - def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = - traversal - .repeat(_.in(EdgeTypes.AST))( - _.emitAllButFirst - .until(_.or(_.hasLabel(NodeTypes.METHOD), _.filter(n => root != null && root == n))) - ) - .cast[AstNode] - - /** Traverse only to those AST nodes that are also control flow graph nodes - */ - def isCfgNode: Iterator[CfgNode] = - traversal.collectAll[CfgNode] - - def isAnnotation: Iterator[Annotation] = - traversal.collectAll[Annotation] - - def isAnnotationLiteral: Iterator[AnnotationLiteral] = - traversal.collectAll[AnnotationLiteral] - - def 
isArrayInitializer: Iterator[ArrayInitializer] = - traversal.collectAll[ArrayInitializer] - - /** Traverse only to those AST nodes that are blocks - */ - def isBlock: Iterator[Block] = - traversal.collectAll[Block] - - /** Traverse only to those AST nodes that are control structures - */ - def isControlStructure: Iterator[ControlStructure] = - traversal.collectAll[ControlStructure] - - /** Traverse only to AST nodes that are expressions - */ - def isExpression: Iterator[Expression] = - traversal.collectAll[Expression] - - /** Traverse only to AST nodes that are calls - */ - def isCall: Iterator[Call] = - traversal.collectAll[Call] - - /** Cast to call if applicable and filter on call code `calleeRegex` - */ - def isCall(calleeRegex: String): Iterator[Call] = - isCall.where(_.code(calleeRegex)) - - /** Traverse only to AST nodes that are literals - */ - def isLiteral: Iterator[Literal] = - traversal.collectAll[Literal] - - def isLocal: Iterator[Local] = - traversal.collectAll[Local] - - /** Traverse only to AST nodes that are identifier - */ - def isIdentifier: Iterator[Identifier] = - traversal.collectAll[Identifier] - - /** Traverse only to AST nodes that are IMPORT nodes - */ - def isImport: Iterator[Import] = - traversal.collectAll[Import] - - /** Traverse only to FILE AST nodes - */ - def isFile: Iterator[File] = - traversal.collectAll[File] - - /** Traverse only to AST nodes that are field identifier - */ - def isFieldIdentifier: Iterator[FieldIdentifier] = - traversal.collectAll[FieldIdentifier] - - /** Traverse only to AST nodes that are return nodes - */ - def isReturn: Iterator[Return] = - traversal.collectAll[Return] - - /** Traverse only to AST nodes that are MEMBER - */ - def isMember: Iterator[Member] = - traversal.collectAll[Member] - - /** Traverse only to AST nodes that are method reference - */ - def isMethodRef: Iterator[MethodRef] = - traversal.collectAll[MethodRef] - - /** Traverse only to AST nodes that are type reference - */ - def isTypeRef: 
Iterator[TypeRef] = - traversal.collectAll[TypeRef] - - /** Traverse only to AST nodes that are METHOD - */ - def isMethod: Iterator[Method] = - traversal.collectAll[Method] - - /** Traverse only to AST nodes that are MODIFIER - */ - def isModifier: Iterator[Modifier] = - traversal.collectAll[Modifier] - - /** Traverse only to AST nodes that are NAMESPACE_BLOCK - */ - def isNamespaceBlock: Iterator[NamespaceBlock] = - traversal.collectAll[NamespaceBlock] - - /** Traverse only to AST nodes that are METHOD_PARAMETER_IN - */ - def isParameter: Iterator[MethodParameterIn] = - traversal.collectAll[MethodParameterIn] - - /** Traverse only to AST nodes that are TemplateDom nodes - */ - def isTemplateDom: Iterator[TemplateDom] = - traversal.collectAll[TemplateDom] - - /** Traverse only to AST nodes that are TYPE_DECL - */ - def isTypeDecl: Iterator[TypeDecl] = - traversal.collectAll[TypeDecl] - - def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = - traversal - .repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst.until(_.hasLabel(labels: _*))) - .dedup - .cast[StoredNode] - -} +class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: + + /** Nodes of the AST rooted in this node, including the node itself. + */ + @Doc(info = "All nodes of the abstract syntax tree") + def ast: Iterator[AstNode] = + traversal.repeat(_.out(EdgeTypes.AST))(_.emit).cast[AstNode] + + /** All nodes of the abstract syntax tree rooted in this node, which match `predicate`. + * Equivalent of `match` in the original CPG paper. 
+ */ + def ast(predicate: AstNode => Boolean): Iterator[AstNode] = + ast.filter(predicate) + + def containsCallTo(regex: String): Iterator[A] = + traversal.filter(_.ast.isCall.name(regex).nonEmpty) + + @Doc(info = "Depth of the abstract syntax tree") + def depth: Iterator[Int] = + traversal.map(_.depth) + + def depth(p: AstNode => Boolean): Iterator[Int] = + traversal.map(_.depth(p)) + + def isCallTo(regex: String): Iterator[Call] = + isCall.name(regex) + + /** Nodes of the AST rooted in this node, minus the node itself + */ + def astMinusRoot: Iterator[AstNode] = + traversal.repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst).cast[AstNode] + + /** Direct children of node in the AST. Siblings are ordered by their `order` fields + */ + def astChildren: Iterator[AstNode] = + traversal.flatMap(_.astChildren).sortBy(_.order).iterator + + /** Parent AST node + */ + def astParent: Iterator[AstNode] = + traversal.in(EdgeTypes.AST).cast[AstNode] + + /** Siblings of this node in the AST, ordered by their `order` fields + */ + def astSiblings: Iterator[AstNode] = + traversal.flatMap(_.astSiblings) + + /** Traverses up the AST and returns the first block node. 
+ */ + def parentBlock: Iterator[Block] = + traversal.repeat(_.in(EdgeTypes.AST))(_.emit.until(_.hasLabel(NodeTypes.BLOCK))).collectAll[ + Block + ] + + /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached + */ + def inAst: Iterator[AstNode] = + inAst(null) + + /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached, + * minus this node + */ + def inAstMinusLeaf: Iterator[AstNode] = + inAstMinusLeaf(null) + + /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root + * is reached + */ + def inAst(root: AstNode): Iterator[AstNode] = + traversal + .repeat(_.in(EdgeTypes.AST))( + _.emit + .until(_.or( + _.hasLabel(NodeTypes.METHOD), + _.filter(n => root != null && root == n) + )) + ) + .cast[AstNode] + + /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root + * is reached, minus this node + */ + def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = + traversal + .repeat(_.in(EdgeTypes.AST))( + _.emitAllButFirst + .until(_.or( + _.hasLabel(NodeTypes.METHOD), + _.filter(n => root != null && root == n) + )) + ) + .cast[AstNode] + + /** Traverse only to those AST nodes that are also control flow graph nodes + */ + def isCfgNode: Iterator[CfgNode] = + traversal.collectAll[CfgNode] + + def isAnnotation: Iterator[Annotation] = + traversal.collectAll[Annotation] + + def isAnnotationLiteral: Iterator[AnnotationLiteral] = + traversal.collectAll[AnnotationLiteral] + + def isArrayInitializer: Iterator[ArrayInitializer] = + traversal.collectAll[ArrayInitializer] + + /** Traverse only to those AST nodes that are blocks + */ + def isBlock: Iterator[Block] = + traversal.collectAll[Block] + + /** Traverse only to those AST nodes that are control structures + */ + def isControlStructure: Iterator[ControlStructure] = + traversal.collectAll[ControlStructure] + + /** Traverse only to AST nodes that are expressions + */ + def 
isExpression: Iterator[Expression] = + traversal.collectAll[Expression] + + /** Traverse only to AST nodes that are calls + */ + def isCall: Iterator[Call] = + traversal.collectAll[Call] + + /** Cast to call if applicable and filter on call code `calleeRegex` + */ + def isCall(calleeRegex: String): Iterator[Call] = + isCall.where(_.code(calleeRegex)) + + /** Traverse only to AST nodes that are literals + */ + def isLiteral: Iterator[Literal] = + traversal.collectAll[Literal] + + def isLocal: Iterator[Local] = + traversal.collectAll[Local] + + /** Traverse only to AST nodes that are identifier + */ + def isIdentifier: Iterator[Identifier] = + traversal.collectAll[Identifier] + + /** Traverse only to AST nodes that are IMPORT nodes + */ + def isImport: Iterator[Import] = + traversal.collectAll[Import] + + /** Traverse only to FILE AST nodes + */ + def isFile: Iterator[File] = + traversal.collectAll[File] + + /** Traverse only to AST nodes that are field identifier + */ + def isFieldIdentifier: Iterator[FieldIdentifier] = + traversal.collectAll[FieldIdentifier] + + /** Traverse only to AST nodes that are return nodes + */ + def isReturn: Iterator[Return] = + traversal.collectAll[Return] + + /** Traverse only to AST nodes that are MEMBER + */ + def isMember: Iterator[Member] = + traversal.collectAll[Member] + + /** Traverse only to AST nodes that are method reference + */ + def isMethodRef: Iterator[MethodRef] = + traversal.collectAll[MethodRef] + + /** Traverse only to AST nodes that are type reference + */ + def isTypeRef: Iterator[TypeRef] = + traversal.collectAll[TypeRef] + + /** Traverse only to AST nodes that are METHOD + */ + def isMethod: Iterator[Method] = + traversal.collectAll[Method] + + /** Traverse only to AST nodes that are MODIFIER + */ + def isModifier: Iterator[Modifier] = + traversal.collectAll[Modifier] + + /** Traverse only to AST nodes that are NAMESPACE_BLOCK + */ + def isNamespaceBlock: Iterator[NamespaceBlock] = + 
traversal.collectAll[NamespaceBlock] + + /** Traverse only to AST nodes that are METHOD_PARAMETER_IN + */ + def isParameter: Iterator[MethodParameterIn] = + traversal.collectAll[MethodParameterIn] + + /** Traverse only to AST nodes that are TemplateDom nodes + */ + def isTemplateDom: Iterator[TemplateDom] = + traversal.collectAll[TemplateDom] + + /** Traverse only to AST nodes that are TYPE_DECL + */ + def isTypeDecl: Iterator[TypeDecl] = + traversal.collectAll[TypeDecl] + + def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = + traversal + .repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst.until(_.hasLabel(labels*))) + .dedup + .cast[StoredNode] +end AstNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala index 3ac6a86b..230ffea3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala @@ -1,101 +1,101 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[CfgNode]) -class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal { - - /** Textual representation of CFG node - */ - @Doc(info = "Textual representation of CFG node") - def repr: Iterator[String] = - traversal.map(_.repr) - - /** Traverse to enclosing method - */ - def method: Iterator[Method] = - traversal.map(_.method) // refers to 
`semanticcpg.language.nodemethods.CfgNodeMethods.method` - - /** Traverse to next expression in CFG. - */ - - @Doc(info = "Nodes directly reachable via outgoing CFG edges") - def cfgNext: Iterator[CfgNode] = - traversal._cfgOut - .filterNot(_.isInstanceOf[MethodReturn]) - .cast[CfgNode] - - /** Traverse to previous expression in CFG. - */ - @Doc(info = "Nodes directly reachable via incoming CFG edges") - def cfgPrev: Iterator[CfgNode] = - traversal._cfgIn - .filterNot(_.isInstanceOf[MethodReturn]) - .cast[CfgNode] - - /** All nodes reachable in the CFG by up to n forward expansions - */ - def cfgNext(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgNext(n)) - - /** All nodes reachable in the CFG by up to n backward expansions - */ - def cfgPrev(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgPrev(n)) - - /** Recursively determine all nodes on which any of the nodes in this traversal are control dependent - */ - @Doc(info = "All nodes on which this node is control dependent") - def controlledBy: Iterator[CfgNode] = - traversal.flatMap(_.controlledBy) - - /** Recursively determine all nodes which are control dependent on this node - */ - @Doc(info = "All nodes control dependent on this node") - def controls: Iterator[CfgNode] = - traversal.flatMap(_.controls) - - /** Recursively determine all nodes by which this node is dominated - */ - @Doc(info = "All nodes by which this node is dominated") - def dominatedBy: Iterator[CfgNode] = - traversal.flatMap(_.dominatedBy) - - /** Recursively determine all nodes which this node dominates - */ - @Doc(info = "All nodes that are dominated by this node") - def dominates: Iterator[CfgNode] = - traversal.flatMap(_.dominates) - - /** Recursively determine all nodes by which this node is post dominated - */ - @Doc(info = "All nodes by which this node is post dominated") - def postDominatedBy: Iterator[CfgNode] = - traversal.flatMap(_.postDominatedBy) - - /** Recursively determine all nodes which this node post dominates - */ 
- @Doc(info = "All nodes that are post dominated by this node") - def postDominates: Iterator[CfgNode] = - traversal.flatMap(_.postDominates) - - /** Obtain hexadecimal string representation of lineNumber field. - */ - @Doc(info = "Address of the code (for binary code)") - def address: Iterator[Option[String]] = - traversal.map(_.address) - - @Doc(info = "Filters in paths that pass though the given traversal") - def passes(included: Iterator[CfgNode]): Iterator[CfgNode] = { - val in = included.toSet - traversal.flatMap(_.passes(in)) - } - - @Doc(info = "Filters out paths that pass though the given traversal") - def passesNot(excluded: Iterator[CfgNode]): Iterator[CfgNode] = { - val ex = excluded.toSet - traversal.flatMap(_.passesNot(ex)) - } - -} +class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal: + + /** Textual representation of CFG node + */ + @Doc(info = "Textual representation of CFG node") + def repr: Iterator[String] = + traversal.map(_.repr) + + /** Traverse to enclosing method + */ + def method: Iterator[Method] = + traversal.map( + _.method + ) // refers to `semanticcpg.language.nodemethods.CfgNodeMethods.method` + + /** Traverse to next expression in CFG. + */ + + @Doc(info = "Nodes directly reachable via outgoing CFG edges") + def cfgNext: Iterator[CfgNode] = + traversal._cfgOut + .filterNot(_.isInstanceOf[MethodReturn]) + .cast[CfgNode] + + /** Traverse to previous expression in CFG. 
+ */ + @Doc(info = "Nodes directly reachable via incoming CFG edges") + def cfgPrev: Iterator[CfgNode] = + traversal._cfgIn + .filterNot(_.isInstanceOf[MethodReturn]) + .cast[CfgNode] + + /** All nodes reachable in the CFG by up to n forward expansions + */ + def cfgNext(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgNext(n)) + + /** All nodes reachable in the CFG by up to n backward expansions + */ + def cfgPrev(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgPrev(n)) + + /** Recursively determine all nodes on which any of the nodes in this traversal are control + * dependent + */ + @Doc(info = "All nodes on which this node is control dependent") + def controlledBy: Iterator[CfgNode] = + traversal.flatMap(_.controlledBy) + + /** Recursively determine all nodes which are control dependent on this node + */ + @Doc(info = "All nodes control dependent on this node") + def controls: Iterator[CfgNode] = + traversal.flatMap(_.controls) + + /** Recursively determine all nodes by which this node is dominated + */ + @Doc(info = "All nodes by which this node is dominated") + def dominatedBy: Iterator[CfgNode] = + traversal.flatMap(_.dominatedBy) + + /** Recursively determine all nodes which this node dominates + */ + @Doc(info = "All nodes that are dominated by this node") + def dominates: Iterator[CfgNode] = + traversal.flatMap(_.dominates) + + /** Recursively determine all nodes by which this node is post dominated + */ + @Doc(info = "All nodes by which this node is post dominated") + def postDominatedBy: Iterator[CfgNode] = + traversal.flatMap(_.postDominatedBy) + + /** Recursively determine all nodes which this node post dominates + */ + @Doc(info = "All nodes that are post dominated by this node") + def postDominates: Iterator[CfgNode] = + traversal.flatMap(_.postDominates) + + /** Obtain hexadecimal string representation of lineNumber field. 
+ */ + @Doc(info = "Address of the code (for binary code)") + def address: Iterator[Option[String]] = + traversal.map(_.address) + + @Doc(info = "Filters in paths that pass though the given traversal") + def passes(included: Iterator[CfgNode]): Iterator[CfgNode] = + val in = included.toSet + traversal.flatMap(_.passes(in)) + + @Doc(info = "Filters out paths that pass though the given traversal") + def passesNot(excluded: Iterator[CfgNode]): Iterator[CfgNode] = + val ex = excluded.toSet + traversal.flatMap(_.passesNot(ex)) +end CfgNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala index 9cd466c2..c4018300 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala @@ -1,24 +1,31 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations -import io.shiftleft.codepropertygraph.generated.nodes.{ClosureBinding, Declaration, MethodRef, TypeRef} +import io.shiftleft.codepropertygraph.generated.nodes.{ + ClosureBinding, + Declaration, + MethodRef, + TypeRef +} import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help /** A declaration, such as a local or parameter. 
*/ @help.Traversal(elementType = classOf[Declaration]) -class DeclarationTraversal[NodeType <: Declaration](val traversal: Iterator[NodeType]) extends AnyVal { - - /** The closure binding node referenced by this declaration - */ - def closureBinding: Iterator[ClosureBinding] = traversal.flatMap(_._refIn).collectAll[ClosureBinding] +class DeclarationTraversal[NodeType <: Declaration](val traversal: Iterator[NodeType]) + extends AnyVal: - /** Methods that capture this declaration - */ - def capturedByMethodRef: Iterator[MethodRef] = closureBinding.flatMap(_._captureIn).collectAll[MethodRef] + /** The closure binding node referenced by this declaration + */ + def closureBinding: Iterator[ClosureBinding] = + traversal.flatMap(_._refIn).collectAll[ClosureBinding] - /** Types that capture this declaration - */ - def capturedByTypeRef: Iterator[TypeRef] = closureBinding.flatMap(_._captureIn).collectAll[TypeRef] + /** Methods that capture this declaration + */ + def capturedByMethodRef: Iterator[MethodRef] = + closureBinding.flatMap(_._captureIn).collectAll[MethodRef] -} + /** Types that capture this declaration + */ + def capturedByTypeRef: Iterator[TypeRef] = + closureBinding.flatMap(_._captureIn).collectAll[TypeRef] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala index 8a44e947..a297ad14 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala @@ -1,73 +1,74 @@ package io.shiftleft.semanticcpg.language.types.expressions.generalizations -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import 
io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* /** An expression (base type) */ -class ExpressionTraversal[NodeType <: Expression](val traversal: Iterator[NodeType]) extends AnyVal { +class ExpressionTraversal[NodeType <: Expression](val traversal: Iterator[NodeType]) extends AnyVal: - /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's continuing it's walk - * until it hits an expression that's not a generic "member access operation", e.g., ".memberAccess". - */ - def parentExpression: Iterator[Expression] = - traversal.flatMap(_.parentExpression) + /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's + * continuing it's walk until it hits an expression that's not a generic "member access + * operation", e.g., ".memberAccess". + */ + def parentExpression: Iterator[Expression] = + traversal.flatMap(_.parentExpression) - /** Traverse to enclosing expression - */ - def expressionUp: Iterator[Expression] = - traversal.flatMap(_.expressionUp) + /** Traverse to enclosing expression + */ + def expressionUp: Iterator[Expression] = + traversal.flatMap(_.expressionUp) - /** Traverse to sub expressions - */ - def expressionDown: Iterator[Expression] = - traversal.flatMap(_.expressionDown) + /** Traverse to sub expressions + */ + def expressionDown: Iterator[Expression] = + traversal.flatMap(_.expressionDown) - /** If the expression is used as receiver for a call, this traverses to the call. - */ - def receivedCall: Iterator[Call] = - traversal.flatMap(_.receivedCall) + /** If the expression is used as receiver for a call, this traverses to the call. 
+ */ + def receivedCall: Iterator[Call] = + traversal.flatMap(_.receivedCall) - /** Only those expressions which are (direct) arguments of a call - */ - def isArgument: Iterator[Expression] = - traversal.flatMap(_.isArgument) + /** Only those expressions which are (direct) arguments of a call + */ + def isArgument: Iterator[Expression] = + traversal.flatMap(_.isArgument) - /** Traverse to surrounding call - */ - def inCall: Iterator[Call] = - traversal.flatMap(_.inCall) + /** Traverse to surrounding call + */ + def inCall: Iterator[Call] = + traversal.flatMap(_.inCall) - /** Traverse to surrounding call - */ - @deprecated("Use inCall") - def call: Iterator[Call] = - inCall + /** Traverse to surrounding call + */ + @deprecated("Use inCall") + def call: Iterator[Call] = + inCall - /** Traverse to related parameter - */ - @deprecated("", "October 2019") - def toParameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = parameter + /** Traverse to related parameter + */ + @deprecated("", "October 2019") + def toParameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = parameter - /** Traverse to related parameter, if the expression is an argument to a call and the call can be resolved. - */ - def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = - traversal.flatMap(_.parameter) + /** Traverse to related parameter, if the expression is an argument to a call and the call can + * be resolved. 
+ */ + def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = + traversal.flatMap(_.parameter) - /** Traverse to enclosing method - */ - def method: Iterator[Method] = - traversal._containsIn - .flatMap { - case x: Method => x.start - case x: TypeDecl => x.astParent - } - .collectAll[Method] + /** Traverse to enclosing method + */ + def method: Iterator[Method] = + traversal._containsIn + .flatMap { + case x: Method => x.start + case x: TypeDecl => x.astParent + } + .collectAll[Method] - /** Traverse to expression evaluation type - */ - def typ: Iterator[Type] = - traversal.flatMap(_._evalTypeOut).cast[Type] - -} + /** Traverse to expression evaluation type + */ + def typ: Iterator[Type] = + traversal.flatMap(_._evalTypeOut).cast[Type] +end ExpressionTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala index 27af5cc6..c3133f55 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala @@ -4,43 +4,38 @@ import io.shiftleft.codepropertygraph.generated.Properties import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.semanticcpg.language.* -class EvalTypeAccessors[A <: AstNode](val traversal: Iterator[A]) extends AnyVal { - - def evalType: Iterator[String] = - evalType(traversal) - - def evalType(regex: String): Iterator[A] = { - traversal.where(evalType(_).filter(_.matches(regex))) - } - - def evalType(regexes: String*): Iterator[A] = - if (regexes.isEmpty) Iterator.empty - else { - val regexes0 = regexes.map(_.r).toSet - traversal.where(evalType(_).filter(value => regexes0.exists(_.matches(value)))) - } - - def evalTypeExact(value: String): 
Iterator[A] = - traversal.where(evalType(_).filter(_ == value)) - - def evalTypeExact(values: String*): Iterator[A] = - if (values.isEmpty) Iterator.empty - else { - val valuesSet = values.to(Set) - traversal.where(evalType(_).filter(valuesSet.contains)) - } - - def evalTypeNot(regex: String): Iterator[A] = - traversal.where(evalType(_).filterNot(_.matches(regex))) - - def evalTypeNot(regexes: String*): Iterator[A] = - if (regexes.isEmpty) Iterator.empty - else { - val regexes0 = regexes.map(_.r).toSet - traversal.where(evalType(_).filter(value => !regexes0.exists(_.matches(value)))) - } - - private def evalType(traversal: Iterator[A]): Iterator[String] = - traversal.flatMap(_._evalTypeOut).flatMap(_._refOut).property(Properties.FULL_NAME) - -} +class EvalTypeAccessors[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: + + def evalType: Iterator[String] = + evalType(traversal) + + def evalType(regex: String): Iterator[A] = + traversal.where(evalType(_).filter(_.matches(regex))) + + def evalType(regexes: String*): Iterator[A] = + if regexes.isEmpty then Iterator.empty + else + val regexes0 = regexes.map(_.r).toSet + traversal.where(evalType(_).filter(value => regexes0.exists(_.matches(value)))) + + def evalTypeExact(value: String): Iterator[A] = + traversal.where(evalType(_).filter(_ == value)) + + def evalTypeExact(values: String*): Iterator[A] = + if values.isEmpty then Iterator.empty + else + val valuesSet = values.to(Set) + traversal.where(evalType(_).filter(valuesSet.contains)) + + def evalTypeNot(regex: String): Iterator[A] = + traversal.where(evalType(_).filterNot(_.matches(regex))) + + def evalTypeNot(regexes: String*): Iterator[A] = + if regexes.isEmpty then Iterator.empty + else + val regexes0 = regexes.map(_.r).toSet + traversal.where(evalType(_).filter(value => !regexes0.exists(_.matches(value)))) + + private def evalType(traversal: Iterator[A]): Iterator[String] = + 
traversal.flatMap(_._evalTypeOut).flatMap(_._refOut).property(Properties.FULL_NAME) +end EvalTypeAccessors diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala index 0b474e56..fa3d378a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala @@ -6,45 +6,45 @@ import io.shiftleft.codepropertygraph.generated.traversal.toModifierTraversalExt import io.shiftleft.semanticcpg.language.* import overflowdb.* -class ModifierAccessors[A <: Node](val traversal: Iterator[A]) extends AnyVal { +class ModifierAccessors[A <: Node](val traversal: Iterator[A]) extends AnyVal: - /** Filter: only `public` nodes */ - def isPublic: Iterator[A] = - hasModifier(ModifierTypes.PUBLIC) + /** Filter: only `public` nodes */ + def isPublic: Iterator[A] = + hasModifier(ModifierTypes.PUBLIC) - /** Filter: only `private` nodes */ - def isPrivate: Iterator[A] = - hasModifier(ModifierTypes.PRIVATE) + /** Filter: only `private` nodes */ + def isPrivate: Iterator[A] = + hasModifier(ModifierTypes.PRIVATE) - /** Filter: only `protected` nodes */ - def isProtected: Iterator[A] = - hasModifier(ModifierTypes.PROTECTED) + /** Filter: only `protected` nodes */ + def isProtected: Iterator[A] = + hasModifier(ModifierTypes.PROTECTED) - /** Filter: only `abstract` nodes */ - def isAbstract: Iterator[A] = - hasModifier(ModifierTypes.ABSTRACT) + /** Filter: only `abstract` nodes */ + def isAbstract: Iterator[A] = + hasModifier(ModifierTypes.ABSTRACT) - /** Filter: only `static` nodes */ - def isStatic: Iterator[A] = - hasModifier(ModifierTypes.STATIC) + /** Filter: only `static` nodes */ + def isStatic: Iterator[A] = + hasModifier(ModifierTypes.STATIC) - /** 
Filter: only `native` nodes */ - def isNative: Iterator[A] = - hasModifier(ModifierTypes.NATIVE) + /** Filter: only `native` nodes */ + def isNative: Iterator[A] = + hasModifier(ModifierTypes.NATIVE) - /** Filter: only `constructor` nodes */ - def isConstructor: Iterator[A] = - hasModifier(ModifierTypes.CONSTRUCTOR) + /** Filter: only `constructor` nodes */ + def isConstructor: Iterator[A] = + hasModifier(ModifierTypes.CONSTRUCTOR) - /** Filter: only `virtual` nodes */ - def isVirtual: Iterator[A] = - hasModifier(ModifierTypes.VIRTUAL) + /** Filter: only `virtual` nodes */ + def isVirtual: Iterator[A] = + hasModifier(ModifierTypes.VIRTUAL) - def hasModifier(modifier: String): Iterator[A] = - traversal.where(_.out.collectAll[Modifier].modifierType(modifier)) + def hasModifier(modifier: String): Iterator[A] = + traversal.where(_.out.collectAll[Modifier].modifierType(modifier)) - /** Traverse to modifiers, e.g., "static", "public". - */ - def modifier: Iterator[Modifier] = - traversal.out.collectAll[Modifier] -} + /** Traverse to modifiers, e.g., "static", "public". 
+ */ + def modifier: Iterator[Modifier] = + traversal.out.collectAll[Modifier] +end ModifierAccessors diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala index 684571b0..00dd5e87 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala @@ -5,18 +5,18 @@ import io.shiftleft.semanticcpg.language.* /** An annotation parameter-assignment, e.g., `foo=value` in @Test(foo=value) */ -class AnnotationParameterAssignTraversal(val traversal: Iterator[AnnotationParameterAssign]) extends AnyVal { +class AnnotationParameterAssignTraversal(val traversal: Iterator[AnnotationParameterAssign]) + extends AnyVal: - /** Traverse to all annotation parameters - */ - def parameter: Iterator[AnnotationParameter] = - traversal.flatMap(_._annotationParameterViaAstOut) + /** Traverse to all annotation parameters + */ + def parameter: Iterator[AnnotationParameter] = + traversal.flatMap(_._annotationParameterViaAstOut) - /** Traverse to all values of annotation parameters - */ - def value: Iterator[Expression] = - traversal - .flatMap(_.astOut) - .filterNot(_.isInstanceOf[AnnotationParameter]) - .cast[Expression] -} + /** Traverse to all values of annotation parameters + */ + def value: Iterator[Expression] = + traversal + .flatMap(_.astOut) + .filterNot(_.isInstanceOf[AnnotationParameter]) + .cast[Expression] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala index a597e668..3df4f8c0 100644 --- 
a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala @@ -1,34 +1,34 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes -import overflowdb.traversal._ +import overflowdb.traversal.* /** An (Java-) annotation, e.g., @Test. */ -class AnnotationTraversal(val traversal: Iterator[nodes.Annotation]) extends AnyVal { +class AnnotationTraversal(val traversal: Iterator[nodes.Annotation]) extends AnyVal: - /** Traverse to parameter assignments - */ - def parameterAssign: Iterator[nodes.AnnotationParameterAssign] = - traversal.flatMap(_._annotationParameterAssignViaAstOut) + /** Traverse to parameter assignments + */ + def parameterAssign: Iterator[nodes.AnnotationParameterAssign] = + traversal.flatMap(_._annotationParameterAssignViaAstOut) - /** Traverse to methods annotated with this annotation. - */ - def method: Iterator[nodes.Method] = - traversal.flatMap(_._methodViaAstIn) + /** Traverse to methods annotated with this annotation. 
+ */ + def method: Iterator[nodes.Method] = + traversal.flatMap(_._methodViaAstIn) - /** Traverse to type declarations annotated by this annotation - */ - def typeDecl: Iterator[nodes.TypeDecl] = - traversal.flatMap(_._typeDeclViaAstIn) + /** Traverse to type declarations annotated by this annotation + */ + def typeDecl: Iterator[nodes.TypeDecl] = + traversal.flatMap(_._typeDeclViaAstIn) - /** Traverse to member annotated by this annotation - */ - def member: Iterator[nodes.Member] = - traversal.flatMap(_._memberViaAstIn) + /** Traverse to member annotated by this annotation + */ + def member: Iterator[nodes.Member] = + traversal.flatMap(_._memberViaAstIn) - /** Traverse to parameter annotated by this annotation - */ - def parameter: Iterator[nodes.MethodParameterIn] = - traversal.flatMap(_._methodParameterInViaAstIn) -} + /** Traverse to parameter annotated by this annotation + */ + def parameter: Iterator[nodes.MethodParameterIn] = + traversal.flatMap(_._methodParameterInViaAstIn) +end AnnotationTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala index 36309ff4..82339ab6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala @@ -4,6 +4,5 @@ import io.shiftleft.codepropertygraph.generated.nodes.Import import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.semanticcpg.language.* -class DependencyTraversal(val traversal: Iterator[nodes.Dependency]) extends AnyVal { - def imports: Iterator[Import] = traversal.in(EdgeTypes.IMPORTS).cast[Import] -} +class DependencyTraversal(val traversal: Iterator[nodes.Dependency]) extends AnyVal: + def imports: Iterator[Import] = 
traversal.in(EdgeTypes.IMPORTS).cast[Import] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala index c9aec567..99345de9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala @@ -1,17 +1,14 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* /** A compilation unit */ -class FileTraversal(val traversal: Iterator[File]) extends AnyVal { +class FileTraversal(val traversal: Iterator[File]) extends AnyVal: - def namespace: Iterator[Namespace] = - traversal.flatMap(_.namespaceBlock).flatMap(_._namespaceViaRefOut) + def namespace: Iterator[Namespace] = + traversal.flatMap(_.namespaceBlock).flatMap(_._namespaceViaRefOut) -} - -object FileTraversal { - val UNKNOWN = "" -} +object FileTraversal: + val UNKNOWN = "" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala index e5131658..116e104d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala @@ -3,13 +3,11 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes.{Call, Import, NamespaceBlock} import io.shiftleft.semanticcpg.language.* -class ImportTraversal(val traversal: Iterator[Import]) extends AnyVal { +class ImportTraversal(val traversal: Iterator[Import]) extends 
AnyVal: - /** Traverse to the call that represents the import in the AST - */ - def call: Iterator[Call] = traversal._isCallForImportIn.cast[Call] + /** Traverse to the call that represents the import in the AST + */ + def call: Iterator[Call] = traversal._isCallForImportIn.cast[Call] - def namespaceBlock: Iterator[NamespaceBlock] = - call.method.namespaceBlock - -} + def namespaceBlock: Iterator[NamespaceBlock] = + call.method.namespaceBlock diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala index 26c73041..b7cde45c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala @@ -1,20 +1,17 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeTypes} import io.shiftleft.semanticcpg.language.* /** A local variable */ -class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal { +class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal: - /** The method hosting this local variable - */ - def method: Iterator[Method] = { - // TODO The following line of code is here for backwards compatibility. - // Use the lower commented out line once not required anymore. - traversal.repeat(_.in(EdgeTypes.AST))(_.until(_.hasLabel(NodeTypes.METHOD))).cast[Method] - // definingBlock.method - } - -} + /** The method hosting this local variable + */ + def method: Iterator[Method] = + // TODO The following line of code is here for backwards compatibility. + // Use the lower commented out line once not required anymore. 
+ traversal.repeat(_.in(EdgeTypes.AST))(_.until(_.hasLabel(NodeTypes.METHOD))).cast[Method] + // definingBlock.method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala index d83a7062..e63310aa 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala @@ -1,21 +1,19 @@ package io.shiftleft.semanticcpg.language.types.structure -import io.shiftleft.codepropertygraph.generated._ +import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.{Call, Member} import io.shiftleft.semanticcpg.language.* /** A member variable of a class/type. */ -class MemberTraversal(val traversal: Iterator[Member]) extends AnyVal { +class MemberTraversal(val traversal: Iterator[Member]) extends AnyVal: - /** Traverse to annotations of member - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) + /** Traverse to annotations of member + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) - /** Places where - */ - def ref: Iterator[Call] = - traversal.flatMap(_._callViaRefIn) - -} + /** Places where + */ + def ref: Iterator[Call] = + traversal.flatMap(_._callViaRefIn) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala index abb9778a..ea21c91a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala +++ 
b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala @@ -5,31 +5,30 @@ import io.shiftleft.semanticcpg.language.* import scala.jdk.CollectionConverters.* -class MethodParameterOutTraversal(val traversal: Iterator[MethodParameterOut]) extends AnyVal { - - def paramIn: Iterator[MethodParameterIn] = traversal.flatMap(_.parameterLinkIn.headOption) - - /* method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def index(num: Int): Iterator[MethodParameterOut] = - traversal.filter { _.index == num } - - /* get all parameters from (and including) - * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def indexFrom(num: Int): Iterator[MethodParameterOut] = - traversal.filter(_.index >= num) - - /* get all parameters up to (and including) - * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def indexTo(num: Int): Iterator[MethodParameterOut] = - traversal.filter(_.index <= num) - - def argument: Iterator[Expression] = - for { - paramOut <- traversal - method = paramOut.method - call <- method.callIn - arg <- call.argumentOut.collectAll[Expression] - if paramOut.parameterLinkIn.index.headOption.contains(arg.argumentIndex) - } yield arg - -} +class MethodParameterOutTraversal(val traversal: Iterator[MethodParameterOut]) extends AnyVal: + + def paramIn: Iterator[MethodParameterIn] = traversal.flatMap(_.parameterLinkIn.headOption) + + /* method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ + def index(num: Int): Iterator[MethodParameterOut] = + traversal.filter { _.index == num } + + /* get all parameters from (and including) + * method parameter indexes are based, i.e. 
first parameter has index (that's how java2cpg generates it) */ + def indexFrom(num: Int): Iterator[MethodParameterOut] = + traversal.filter(_.index >= num) + + /* get all parameters up to (and including) + * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ + def indexTo(num: Int): Iterator[MethodParameterOut] = + traversal.filter(_.index <= num) + + def argument: Iterator[Expression] = + for + paramOut <- traversal + method = paramOut.method + call <- method.callIn + arg <- call.argumentOut.collectAll[Expression] + if paramOut.parameterLinkIn.index.headOption.contains(arg.argumentIndex) + yield arg +end MethodParameterOutTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala index 2f2af5ea..e1a90d96 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala @@ -9,31 +9,30 @@ import scala.jdk.CollectionConverters.* /** Formal method input parameter */ @help.Traversal(elementType = classOf[MethodParameterIn]) -class MethodParameterTraversal(val traversal: Iterator[MethodParameterIn]) extends AnyVal { - - /** Traverse to parameter annotations - */ - def annotation: Iterator[Annotation] = - traversal.flatMap(_._annotationViaAstOut) - - /** Traverse to all parameters with index greater or equal than `num` - */ - def indexFrom(num: Int): Iterator[MethodParameterIn] = - traversal.filter(_.index >= num) - - /** Traverse to all parameters with index smaller or equal than `num` - */ - def indexTo(num: Int): Iterator[MethodParameterIn] = - traversal.filter(_.index <= num) - - /** Traverse to arguments (actual parameters) associated with this formal parameter - */ - 
def argument(implicit callResolver: ICallResolver): Iterator[Expression] = - for { - paramIn <- traversal - call <- callResolver.getMethodCallsites(paramIn.method) - arg <- call._argumentOut.collectAll[Expression] - if arg.argumentIndex == paramIn.index - } yield arg - -} +class MethodParameterTraversal(val traversal: Iterator[MethodParameterIn]) extends AnyVal: + + /** Traverse to parameter annotations + */ + def annotation: Iterator[Annotation] = + traversal.flatMap(_._annotationViaAstOut) + + /** Traverse to all parameters with index greater or equal than `num` + */ + def indexFrom(num: Int): Iterator[MethodParameterIn] = + traversal.filter(_.index >= num) + + /** Traverse to all parameters with index smaller or equal than `num` + */ + def indexTo(num: Int): Iterator[MethodParameterIn] = + traversal.filter(_.index <= num) + + /** Traverse to arguments (actual parameters) associated with this formal parameter + */ + def argument(implicit callResolver: ICallResolver): Iterator[Expression] = + for + paramIn <- traversal + call <- callResolver.getMethodCallsites(paramIn.method) + arg <- call._argumentOut.collectAll[Expression] + if arg.argumentIndex == paramIn.index + yield arg +end MethodParameterTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala index a96b5fb9..90f1c4e5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala @@ -6,24 +6,24 @@ import overflowdb.traversal.help import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[MethodReturn]) -class MethodReturnTraversal(val traversal: Iterator[MethodReturn]) extends AnyVal { +class MethodReturnTraversal(val traversal: Iterator[MethodReturn]) 
extends AnyVal: - @Doc(info = "traverse to parent method") - def method: Iterator[Method] = - traversal.flatMap(_._methodViaAstIn) + @Doc(info = "traverse to parent method") + def method: Iterator[Method] = + traversal.flatMap(_._methodViaAstIn) - def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = - traversal.flatMap(_.returnUser) + def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = + traversal.flatMap(_.returnUser) - /** Traverse to last expressions in CFG. Can be multiple. - */ - @Doc(info = "traverse to last expressions in CFG (can be multiple)") - def cfgLast: Iterator[CfgNode] = - traversal.flatMap(_.cfgIn) + /** Traverse to last expressions in CFG. Can be multiple. + */ + @Doc(info = "traverse to last expressions in CFG (can be multiple)") + def cfgLast: Iterator[CfgNode] = + traversal.flatMap(_.cfgIn) - /** Traverse to return type - */ - @Doc(info = "traverse to return type") - def typ: Iterator[Type] = - traversal.flatMap(_.evalTypeOut) -} + /** Traverse to return type + */ + @Doc(info = "traverse to return type") + def typ: Iterator[Type] = + traversal.flatMap(_.evalTypeOut) +end MethodReturnTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala index 4ec6fa12..a16b576b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala @@ -13,215 +13,212 @@ import overflowdb.formats.graphml.GraphMLExporter import scala.jdk.CollectionConverters.* import java.nio.file.{Files, Path, Paths} -case class MethodSubGraph(methodName: String, methodFullName: String, nodes: Set[Node]) { - def edges: Set[Edge] = { - for { - node <- nodes - edge <- node.bothE.asScala - if nodes.contains(edge.inNode) && 
nodes.contains(edge.outNode) - } yield edge - } -} - -def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = { - ExportResult( - nodeCount = resultA.nodeCount + resultB.nodeCount, - edgeCount = resultA.edgeCount + resultB.edgeCount, - files = resultA.files ++ resultB.files, - additionalInfo = resultA.additionalInfo - ) -} +case class MethodSubGraph(methodName: String, methodFullName: String, nodes: Set[Node]): + def edges: Set[Edge] = + for + node <- nodes + edge <- node.bothE.asScala + if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) + yield edge + +def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = + ExportResult( + nodeCount = resultA.nodeCount + resultB.nodeCount, + edgeCount = resultA.edgeCount + resultB.edgeCount, + files = resultA.files ++ resultB.files, + additionalInfo = resultA.additionalInfo + ) /** A method, function, or procedure */ @help.Traversal(elementType = classOf[Method]) -class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal { - - /** Traverse to annotations of method - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) - - /** All control structures of this method - */ - @Doc(info = "Control structures (source frontends only)") - def controlStructure: Iterator[ControlStructure] = - traversal.ast.isControlStructure - - /** Shorthand to traverse to control structures where condition matches `regex` - */ - def controlStructure(regex: String): Iterator[ControlStructure] = - traversal.ast.isControlStructure.code(regex) - - @Doc(info = "All try blocks (`ControlStructure` nodes)") - def tryBlock: Iterator[ControlStructure] = - controlStructure.isTry - - @Doc(info = "All if blocks (`ControlStructure` nodes)") - def ifBlock: Iterator[ControlStructure] = - controlStructure.isIf - - @Doc(info = "All else blocks (`ControlStructure` nodes)") - def elseBlock: Iterator[ControlStructure] = - controlStructure.isElse - - @Doc(info = "All switch blocks 
(`ControlStructure` nodes)") - def switchBlock: Iterator[ControlStructure] = - controlStructure.isSwitch - - @Doc(info = "All do blocks (`ControlStructure` nodes)") - def doBlock: Iterator[ControlStructure] = - controlStructure.isDo - - @Doc(info = "All for blocks (`ControlStructure` nodes)") - def forBlock: Iterator[ControlStructure] = - controlStructure.isFor - - @Doc(info = "All while blocks (`ControlStructure` nodes)") - def whileBlock: Iterator[ControlStructure] = - controlStructure.isWhile - - @Doc(info = "All gotos (`ControlStructure` nodes)") - def goto: Iterator[ControlStructure] = - controlStructure.isGoto - - @Doc(info = "All breaks (`ControlStructure` nodes)") - def break: Iterator[ControlStructure] = - controlStructure.isBreak - - @Doc(info = "All continues (`ControlStructure` nodes)") - def continue: Iterator[ControlStructure] = - controlStructure.isContinue - - @Doc(info = "All throws (`ControlStructure` nodes)") - def throws: Iterator[ControlStructure] = - controlStructure.isThrow - - /** The type declaration associated with this method, e.g., the class it is defined in. - */ - @Doc(info = "Type this method is defined in") - def definingTypeDecl: Iterator[TypeDecl] = - traversal - .repeat(_._astIn)(_.until(_.collectAll[TypeDecl])) - .cast[TypeDecl] - - /** The type declaration associated with this method, e.g., the class it is defined in. Alias for 'definingTypeDecl' - */ - @Doc(info = "Type this method is defined in - alias for 'definingTypeDecl'") - def typeDecl: Iterator[TypeDecl] = definingTypeDecl - - /** The method in which this method is defined - */ - @Doc(info = "Method this method is defined in") - def definingMethod: Iterator[Method] = - traversal - .repeat(_._astIn)(_.until(_.collectAll[Method])) - .cast[Method] - - /** Traverse only to methods that are stubs, e.g., their code is not available or the method body is empty. 
- */ - def isStub: Iterator[Method] = - traversal.where(_.not(_._cfgOut.not(_.collectAll[MethodReturn]))) - - /** Traverse only to methods that are not stubs. - */ - def isNotStub: Iterator[Method] = - traversal.where(_._cfgOut.not(_.collectAll[MethodReturn])) - - /** Traverse only to methods that accept variadic arguments. - */ - def isVariadic: Iterator[Method] = { - traversal.filter(_.isVariadic) - } - - /** Traverse to external methods, that is, methods not present but only referenced in the CPG. - */ - @Doc(info = "External methods (called, but no body available)") - def external: Iterator[Method] = - traversal.isExternal(true) - - /** Traverse to internal methods, that is, methods for which code is included in this CPG. - */ - @Doc(info = "Internal methods, i.e., a body is available") - def internal: Iterator[Method] = - traversal.isExternal(false) - - /** Traverse to the methods local variables - */ - @Doc(info = "Local variables declared in the method") - def local: Iterator[Local] = - traversal.block.ast.isLocal - - @Doc(info = "Top level expressions (\"Statements\")") - def topLevelExpressions: Iterator[Expression] = - traversal._astOut - .collectAll[Block] - ._astOut - .not(_.collectAll[Local]) - .cast[Expression] - - @Doc(info = "Control flow graph nodes") - def cfgNode: Iterator[CfgNode] = - traversal.flatMap(_.cfgNode) - - /** Traverse to last expression in CFG. 
- */ - @Doc(info = "Last control flow graph node") - def cfgLast: Iterator[CfgNode] = - traversal.methodReturn.cfgLast - - /** Traverse to method body (alias for `block`) */ - @Doc(info = "Alias for `block`") - def body: Iterator[Block] = - traversal.block - - /** Traverse to namespace */ - @Doc(info = "Namespace this method is declared in") - def namespace: Iterator[Namespace] = { - traversal.namespaceBlock.namespace - } - - /** Traverse to namespace block */ - @Doc(info = "Namespace block this method is declared in") - def namespaceBlock: Iterator[NamespaceBlock] = { - traversal.flatMap { m => - m.astIn.headOption match { - // some language frontends don't have a TYPE_DECL for a METHOD - case Some(namespaceBlock: NamespaceBlock) => namespaceBlock.start - // other language frontends always embed their method in a TYPE_DECL - case _ => m.definingTypeDecl.namespaceBlock - } - } - } - - def numberOfLines: Iterator[Int] = traversal.map(_.numberOfLines) - - @Doc(info = "Export the methods to graphml") - def gml(gmlDir: String = null): ExportResult = { - var pathToUse = gmlDir - traversal - .map { method => - MethodSubGraph(methodName = method.name, methodFullName = method.fullName, nodes = method.ast.toSet) - } - .map { case subGraph @ MethodSubGraph(methodName, methodFullName, nodes) => - val methodHash = Fingerprinting.calculate_hash(methodFullName) - try { - if (pathToUse == null) { - pathToUse = Files.createTempDirectory("gml-export").toAbsolutePath.toString - } else { - Paths.get(pathToUse).toFile.mkdirs() - } - } catch { - case exc: Exception => +class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: + + /** Traverse to annotations of method + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) + + /** All control structures of this method + */ + @Doc(info = "Control structures (source frontends only)") + def controlStructure: Iterator[ControlStructure] = + traversal.ast.isControlStructure + + /** 
Shorthand to traverse to control structures where condition matches `regex` + */ + def controlStructure(regex: String): Iterator[ControlStructure] = + traversal.ast.isControlStructure.code(regex) + + @Doc(info = "All try blocks (`ControlStructure` nodes)") + def tryBlock: Iterator[ControlStructure] = + controlStructure.isTry + + @Doc(info = "All if blocks (`ControlStructure` nodes)") + def ifBlock: Iterator[ControlStructure] = + controlStructure.isIf + + @Doc(info = "All else blocks (`ControlStructure` nodes)") + def elseBlock: Iterator[ControlStructure] = + controlStructure.isElse + + @Doc(info = "All switch blocks (`ControlStructure` nodes)") + def switchBlock: Iterator[ControlStructure] = + controlStructure.isSwitch + + @Doc(info = "All do blocks (`ControlStructure` nodes)") + def doBlock: Iterator[ControlStructure] = + controlStructure.isDo + + @Doc(info = "All for blocks (`ControlStructure` nodes)") + def forBlock: Iterator[ControlStructure] = + controlStructure.isFor + + @Doc(info = "All while blocks (`ControlStructure` nodes)") + def whileBlock: Iterator[ControlStructure] = + controlStructure.isWhile + + @Doc(info = "All gotos (`ControlStructure` nodes)") + def goto: Iterator[ControlStructure] = + controlStructure.isGoto + + @Doc(info = "All breaks (`ControlStructure` nodes)") + def break: Iterator[ControlStructure] = + controlStructure.isBreak + + @Doc(info = "All continues (`ControlStructure` nodes)") + def continue: Iterator[ControlStructure] = + controlStructure.isContinue + + @Doc(info = "All throws (`ControlStructure` nodes)") + def throws: Iterator[ControlStructure] = + controlStructure.isThrow + + /** The type declaration associated with this method, e.g., the class it is defined in. 
+ */ + @Doc(info = "Type this method is defined in") + def definingTypeDecl: Iterator[TypeDecl] = + traversal + .repeat(_._astIn)(_.until(_.collectAll[TypeDecl])) + .cast[TypeDecl] + + /** The type declaration associated with this method, e.g., the class it is defined in. Alias + * for 'definingTypeDecl' + */ + @Doc(info = "Type this method is defined in - alias for 'definingTypeDecl'") + def typeDecl: Iterator[TypeDecl] = definingTypeDecl + + /** The method in which this method is defined + */ + @Doc(info = "Method this method is defined in") + def definingMethod: Iterator[Method] = + traversal + .repeat(_._astIn)(_.until(_.collectAll[Method])) + .cast[Method] + + /** Traverse only to methods that are stubs, e.g., their code is not available or the method + * body is empty. + */ + def isStub: Iterator[Method] = + traversal.where(_.not(_._cfgOut.not(_.collectAll[MethodReturn]))) + + /** Traverse only to methods that are not stubs. + */ + def isNotStub: Iterator[Method] = + traversal.where(_._cfgOut.not(_.collectAll[MethodReturn])) + + /** Traverse only to methods that accept variadic arguments. + */ + def isVariadic: Iterator[Method] = + traversal.filter(_.isVariadic) + + /** Traverse to external methods, that is, methods not present but only referenced in the CPG. + */ + @Doc(info = "External methods (called, but no body available)") + def external: Iterator[Method] = + traversal.isExternal(true) + + /** Traverse to internal methods, that is, methods for which code is included in this CPG. 
+ */ + @Doc(info = "Internal methods, i.e., a body is available") + def internal: Iterator[Method] = + traversal.isExternal(false) + + /** Traverse to the methods local variables + */ + @Doc(info = "Local variables declared in the method") + def local: Iterator[Local] = + traversal.block.ast.isLocal + + @Doc(info = "Top level expressions (\"Statements\")") + def topLevelExpressions: Iterator[Expression] = + traversal._astOut + .collectAll[Block] + ._astOut + .not(_.collectAll[Local]) + .cast[Expression] + + @Doc(info = "Control flow graph nodes") + def cfgNode: Iterator[CfgNode] = + traversal.flatMap(_.cfgNode) + + /** Traverse to last expression in CFG. + */ + @Doc(info = "Last control flow graph node") + def cfgLast: Iterator[CfgNode] = + traversal.methodReturn.cfgLast + + /** Traverse to method body (alias for `block`) */ + @Doc(info = "Alias for `block`") + def body: Iterator[Block] = + traversal.block + + /** Traverse to namespace */ + @Doc(info = "Namespace this method is declared in") + def namespace: Iterator[Namespace] = + traversal.namespaceBlock.namespace + + /** Traverse to namespace block */ + @Doc(info = "Namespace block this method is declared in") + def namespaceBlock: Iterator[NamespaceBlock] = + traversal.flatMap { m => + m.astIn.headOption match + // some language frontends don't have a TYPE_DECL for a METHOD + case Some(namespaceBlock: NamespaceBlock) => namespaceBlock.start + // other language frontends always embed their method in a TYPE_DECL + case _ => m.definingTypeDecl.namespaceBlock } - GraphMLExporter.runExport( - nodes, - subGraph.edges, - Paths.get(pathToUse, s"${methodName}-${methodHash.getOrElse("")}.graphml") - ) - } - .reduce(plus) - } - - def gml: ExportResult = gml(null) -} + + def numberOfLines: Iterator[Int] = traversal.map(_.numberOfLines) + + @Doc(info = "Export the methods to graphml") + def gml(gmlDir: String = null): ExportResult = + var pathToUse = gmlDir + traversal + .map { method => + MethodSubGraph( + methodName = 
method.name, + methodFullName = method.fullName, + nodes = method.ast.toSet + ) + } + .map { case subGraph @ MethodSubGraph(methodName, methodFullName, nodes) => + val methodHash = Fingerprinting.calculate_hash(methodFullName) + try + if pathToUse == null then + pathToUse = Files.createTempDirectory("gml-export").toAbsolutePath.toString + else + Paths.get(pathToUse).toFile.mkdirs() + catch + case exc: Exception => + GraphMLExporter.runExport( + nodes, + subGraph.edges, + Paths.get(pathToUse, s"${methodName}-${methodHash.getOrElse("")}.graphml") + ) + } + .reduce(plus) + end gml + + def gml: ExportResult = gml(null) +end MethodTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala index 7f95dee3..1cbe0397 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala @@ -3,18 +3,17 @@ package io.shiftleft.semanticcpg.language.types.structure import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -class NamespaceBlockTraversal(val traversal: Iterator[NamespaceBlock]) extends AnyVal { +class NamespaceBlockTraversal(val traversal: Iterator[NamespaceBlock]) extends AnyVal: - /** Namespaces for namespace blocks. - */ - def namespace: Iterator[Namespace] = - traversal.flatMap(_.refOut) + /** Namespaces for namespace blocks. 
+ */ + def namespace: Iterator[Namespace] = + traversal.flatMap(_.refOut) - /** The type declarations defined in this namespace - */ - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_._typeDeclViaAstOut) + /** The type declarations defined in this namespace + */ + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_._typeDeclViaAstOut) - def method: Iterator[Method] = - traversal.flatMap(_._methodViaAstOut) -} + def method: Iterator[Method] = + traversal.flatMap(_._methodViaAstOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala index c636047e..d8376151 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala @@ -5,30 +5,28 @@ import io.shiftleft.semanticcpg.language.* /** A namespace, e.g., Java package or C# namespace */ -class NamespaceTraversal(val traversal: Iterator[Namespace]) extends AnyVal { - - /** The type declarations defined in this namespace - */ - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.refIn).flatMap(_._typeDeclViaAstOut) - - /** Methods defined in this namespace - */ - def method: Iterator[Method] = - traversal.flatMap(_.refIn).flatMap(_._methodViaAstOut) - - /** External namespaces - any namespaces which contain one or more external type. 
- */ - def external: Iterator[Namespace] = - traversal.where(_.typeDecl.external) - - /** Internal namespaces - any namespaces which contain one or more internal type - */ - def internal: Iterator[Namespace] = - traversal.where(_.typeDecl.internal) - -} - -object NamespaceTraversal { - val globalNamespaceName = "" -} +class NamespaceTraversal(val traversal: Iterator[Namespace]) extends AnyVal: + + /** The type declarations defined in this namespace + */ + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.refIn).flatMap(_._typeDeclViaAstOut) + + /** Methods defined in this namespace + */ + def method: Iterator[Method] = + traversal.flatMap(_.refIn).flatMap(_._methodViaAstOut) + + /** External namespaces - any namespaces which contain one or more external type. + */ + def external: Iterator[Namespace] = + traversal.where(_.typeDecl.external) + + /** Internal namespaces - any namespaces which contain one or more internal type + */ + def internal: Iterator[Namespace] = + traversal.where(_.typeDecl.internal) +end NamespaceTraversal + +object NamespaceTraversal: + val globalNamespaceName = "" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala index 981fb805..1250b5f4 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala @@ -6,109 +6,108 @@ import io.shiftleft.semanticcpg.language.* /** Type declaration - possibly a template that requires instantiation */ -class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal { - import TypeDeclTraversal.* - - /** Annotations of the type declaration - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) - - /** Types referencing to this type 
declaration. - */ - def referencingType: Iterator[Type] = - traversal.flatMap(_.refIn) - - /** Namespace in which this type declaration is defined - */ - def namespace: Iterator[Namespace] = - traversal.flatMap(_.namespaceBlock).namespace - - /** Methods defined as part of this type - */ - def method: Iterator[Method] = - canonicalType.flatMap(_._methodViaAstOut) - - /** Filter for type declarations contained in the analyzed code. - */ - def internal: Iterator[TypeDecl] = - canonicalType.isExternal(false) - - /** Filter for type declarations not contained in the analyzed code. - */ - def external: Iterator[TypeDecl] = - canonicalType.isExternal(true) - - /** Member variables - */ - def member: Iterator[Member] = - canonicalType.flatMap(_._memberViaAstOut) - - /** Direct base types in the inheritance graph. - */ - def baseType: Iterator[Type] = - canonicalType.flatMap(_._typeViaInheritsFromOut) - - /** Direct base type declaration. - */ - def derivedTypeDecl: Iterator[TypeDecl] = - referencingType.derivedTypeDecl - - /** Direct and transitive base type declaration. - */ - def derivedTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst) - - /** Direct base type declaration. - */ - def baseTypeDecl: Iterator[TypeDecl] = - traversal.baseType.referencedTypeDecl - - /** Direct and transitive base type declaration. - */ - def baseTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst) - - /** Traverse to alias type declarations. - */ - def isAlias: Iterator[TypeDecl] = - traversal.filter(_.aliasTypeFullName.isDefined) - - /** Traverse to canonical type declarations. - */ - def isCanonical: Iterator[TypeDecl] = - traversal.filter(_.aliasTypeFullName.isEmpty) - - /** If this is an alias type declaration, go to its underlying type declaration else unchanged. 
- */ - def unravelAlias: Iterator[TypeDecl] = { - traversal.map { typeDecl => - val alias = for { - tpe <- typeDecl.aliasedType.nextOption() - typeDecl <- tpe.referencedTypeDecl.nextOption() - } yield typeDecl - - alias.getOrElse(typeDecl) - } - } - - /** Traverse to canonical type which means unravel aliases until we find a non alias type declaration. - */ - def canonicalType: Iterator[TypeDecl] = - traversal.repeat(_.unravelAlias)(_.until(_.isCanonical).maxDepth(maxAliasExpansions)) - - /** Direct alias type declarations. - */ - def aliasTypeDecl: Iterator[TypeDecl] = - referencingType.aliasTypeDecl - - /** Direct and transitive alias type declarations. - */ - def aliasTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst) - -} - -object TypeDeclTraversal { - private val maxAliasExpansions = 100 -} +class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal: + import TypeDeclTraversal.* + + /** Annotations of the type declaration + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) + + /** Types referencing to this type declaration. + */ + def referencingType: Iterator[Type] = + traversal.flatMap(_.refIn) + + /** Namespace in which this type declaration is defined + */ + def namespace: Iterator[Namespace] = + traversal.flatMap(_.namespaceBlock).namespace + + /** Methods defined as part of this type + */ + def method: Iterator[Method] = + canonicalType.flatMap(_._methodViaAstOut) + + /** Filter for type declarations contained in the analyzed code. + */ + def internal: Iterator[TypeDecl] = + canonicalType.isExternal(false) + + /** Filter for type declarations not contained in the analyzed code. + */ + def external: Iterator[TypeDecl] = + canonicalType.isExternal(true) + + /** Member variables + */ + def member: Iterator[Member] = + canonicalType.flatMap(_._memberViaAstOut) + + /** Direct base types in the inheritance graph. 
+ */ + def baseType: Iterator[Type] = + canonicalType.flatMap(_._typeViaInheritsFromOut) + + /** Direct derived type declarations. + */ + def derivedTypeDecl: Iterator[TypeDecl] = + referencingType.derivedTypeDecl + + /** Direct and transitive derived type declarations. + */ + def derivedTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst) + + /** Direct base type declaration. + */ + def baseTypeDecl: Iterator[TypeDecl] = + traversal.baseType.referencedTypeDecl + + /** Direct and transitive base type declaration. + */ + def baseTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst) + + /** Traverse to alias type declarations. + */ + def isAlias: Iterator[TypeDecl] = + traversal.filter(_.aliasTypeFullName.isDefined) + + /** Traverse to canonical type declarations. + */ + def isCanonical: Iterator[TypeDecl] = + traversal.filter(_.aliasTypeFullName.isEmpty) + + /** If this is an alias type declaration, go to its underlying type declaration else unchanged. + */ + def unravelAlias: Iterator[TypeDecl] = + traversal.map { typeDecl => + val alias = + for + tpe <- typeDecl.aliasedType.nextOption() + typeDecl <- tpe.referencedTypeDecl.nextOption() + yield typeDecl + + alias.getOrElse(typeDecl) + } + + /** Traverse to canonical type which means unravel aliases until we find a non alias type + * declaration. + */ + def canonicalType: Iterator[TypeDecl] = + traversal.repeat(_.unravelAlias)(_.until(_.isCanonical).maxDepth(maxAliasExpansions)) + + /** Direct alias type declarations. + */ + def aliasTypeDecl: Iterator[TypeDecl] = + referencingType.aliasTypeDecl + + /** Direct and transitive alias type declarations. 
+ */ + def aliasTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst) +end TypeDeclTraversal + +object TypeDeclTraversal: + private val maxAliasExpansions = 100 diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala index ea2bf4e0..eaedbd0c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala @@ -4,91 +4,90 @@ import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* -class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal { - - /** Annotations of the corresponding type declaration. - */ - def annotation: Iterator[nodes.Annotation] = - traversal.referencedTypeDecl.annotation - - /** Namespaces in which the corresponding type declaration is defined. - */ - def namespace: Iterator[Namespace] = - traversal.referencedTypeDecl.namespace - - /** Methods defined on the corresponding type declaration. - */ - def method: Iterator[Method] = - traversal.referencedTypeDecl.method - - /** Filter for types whos corresponding type declaration is in the analyzed jar. - */ - def internal: Iterator[Type] = - traversal.where(_.referencedTypeDecl.internal) - - /** Filter for types whos corresponding type declaration is not in the analyzed jar. - */ - def external: Iterator[Type] = - traversal.where(_.referencedTypeDecl.external) - - /** Member variables of the corresponding type declaration. - */ - def member: Iterator[Member] = - traversal.referencedTypeDecl.member - - /** Direct base types of the corresponding type declaration in the inheritance graph. 
- */ - def baseType: Iterator[Type] = - traversal.referencedTypeDecl.baseType - - /** Direct and transitive base types of the corresponding type declaration. - */ - def baseTypeTransitive: Iterator[Type] = - traversal.repeat(_.baseType)(_.emitAllButFirst) - - /** Direct derived types. - */ - def derivedType: Iterator[Type] = - derivedTypeDecl.referencingType - - /** Direct and transitive derived types. - */ - def derivedTypeTransitive: Iterator[Type] = - traversal.repeat(_.derivedType)(_.emitAllButFirst) - - /** Type declarations which derive from this type. - */ - def derivedTypeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.inheritsFromIn) - - /** Direct alias types. - */ - def aliasType: Iterator[Type] = - traversal.aliasTypeDecl.referencingType - - /** Direct and transitive alias types. - */ - def aliasTypeTransitive: Iterator[Type] = - traversal.repeat(_.aliasType)(_.emitAllButFirst) - - def localOfType: Iterator[Local] = - traversal.flatMap(_._localViaEvalTypeIn) - - def memberOfType: Iterator[Member] = - traversal.flatMap(_.evalTypeIn).collectAll[Member] - - @deprecated("Please use `parameterOfType`") - def parameter: Iterator[MethodParameterIn] = parameterOfType - - def parameterOfType: Iterator[MethodParameterIn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodParameterIn] - - def methodReturnOfType: Iterator[MethodReturn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodReturn] - - def expressionOfType: Iterator[Expression] = expression - - def expression: Iterator[Expression] = - traversal.flatMap(_.evalTypeIn).collectAll[Expression] - -} +class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal: + + /** Annotations of the corresponding type declaration. + */ + def annotation: Iterator[nodes.Annotation] = + traversal.referencedTypeDecl.annotation + + /** Namespaces in which the corresponding type declaration is defined. 
+ */ + def namespace: Iterator[Namespace] = + traversal.referencedTypeDecl.namespace + + /** Methods defined on the corresponding type declaration. + */ + def method: Iterator[Method] = + traversal.referencedTypeDecl.method + + /** Filter for types whose corresponding type declaration is in the analyzed jar. + */ + def internal: Iterator[Type] = + traversal.where(_.referencedTypeDecl.internal) + + /** Filter for types whose corresponding type declaration is not in the analyzed jar. + */ + def external: Iterator[Type] = + traversal.where(_.referencedTypeDecl.external) + + /** Member variables of the corresponding type declaration. + */ + def member: Iterator[Member] = + traversal.referencedTypeDecl.member + + /** Direct base types of the corresponding type declaration in the inheritance graph. + */ + def baseType: Iterator[Type] = + traversal.referencedTypeDecl.baseType + + /** Direct and transitive base types of the corresponding type declaration. + */ + def baseTypeTransitive: Iterator[Type] = + traversal.repeat(_.baseType)(_.emitAllButFirst) + + /** Direct derived types. + */ + def derivedType: Iterator[Type] = + derivedTypeDecl.referencingType + + /** Direct and transitive derived types. + */ + def derivedTypeTransitive: Iterator[Type] = + traversal.repeat(_.derivedType)(_.emitAllButFirst) + + /** Type declarations which derive from this type. + */ + def derivedTypeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.inheritsFromIn) + + /** Direct alias types. + */ + def aliasType: Iterator[Type] = + traversal.aliasTypeDecl.referencingType + + /** Direct and transitive alias types. 
+ */ + def aliasTypeTransitive: Iterator[Type] = + traversal.repeat(_.aliasType)(_.emitAllButFirst) + + def localOfType: Iterator[Local] = + traversal.flatMap(_._localViaEvalTypeIn) + + def memberOfType: Iterator[Member] = + traversal.flatMap(_.evalTypeIn).collectAll[Member] + + @deprecated("Please use `parameterOfType`") + def parameter: Iterator[MethodParameterIn] = parameterOfType + + def parameterOfType: Iterator[MethodParameterIn] = + traversal.flatMap(_.evalTypeIn).collectAll[MethodParameterIn] + + def methodReturnOfType: Iterator[MethodReturn] = + traversal.flatMap(_.evalTypeIn).collectAll[MethodReturn] + + def expressionOfType: Iterator[Expression] = expression + + def expression: Iterator[Expression] = + traversal.flatMap(_.evalTypeIn).collectAll[Expression] +end TypeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala index e9fa1d97..8e1bac2a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala @@ -7,56 +7,54 @@ import io.shiftleft.passes.CpgPassBase import io.shiftleft.semanticcpg.Overlays import org.slf4j.{Logger, LoggerFactory} -abstract class LayerCreator { - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - val overlayName: String - val description: String - val dependsOn: List[String] = List() - - /** If the LayerCreator modifies the CPG, then we store its name in the CPGs metadata and disallow rerunning the - * creator, that is, applying the layer twice. 
- */ - protected val storeOverlayName: Boolean = true - - def run(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit = { - val appliedOverlays = Overlays.appliedOverlays(context.cpg).toSet - if (!dependsOn.toSet.subsetOf(appliedOverlays)) { - logger.warn( - s"${this.getClass.getName} depends on $dependsOn but CPG only has $appliedOverlays - skipping creation" - ) - } else if (appliedOverlays.contains(overlayName)) { - logger.warn(s"The overlay $overlayName already exists - skipping creation") - } else { - create(context, storeUndoInfo) - if (storeOverlayName) { - Overlays.appendOverlayName(context.cpg, overlayName) - } - } - } - - protected def initSerializedCpg(outputDir: Option[String], passName: String, index: Int = 0): SerializedCpg = { - outputDir match { - case Some(dir) => new SerializedCpg((File(dir) / s"${index}_$passName").path.toAbsolutePath.toString) - case None => new SerializedCpg() - } - } - - protected def runPass( - pass: CpgPassBase, - context: LayerCreatorContext, - storeUndoInfo: Boolean, - index: Int = 0 - ): Unit = { - val serializedCpg = initSerializedCpg(context.outputDir, pass.name, index) - pass.createApplySerializeAndStore(serializedCpg, inverse = storeUndoInfo) - serializedCpg.close() - } - - def create(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit - -} +abstract class LayerCreator: + + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + val overlayName: String + val description: String + val dependsOn: List[String] = List() + + /** If the LayerCreator modifies the CPG, then we store its name in the CPGs metadata and + * disallow rerunning the creator, that is, applying the layer twice. 
+ */ + protected val storeOverlayName: Boolean = true + + def run(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit = + val appliedOverlays = Overlays.appliedOverlays(context.cpg).toSet + if !dependsOn.toSet.subsetOf(appliedOverlays) then + logger.warn( + s"${this.getClass.getName} depends on $dependsOn but CPG only has $appliedOverlays - skipping creation" + ) + else if appliedOverlays.contains(overlayName) then + logger.warn(s"The overlay $overlayName already exists - skipping creation") + else + create(context, storeUndoInfo) + if storeOverlayName then + Overlays.appendOverlayName(context.cpg, overlayName) + + protected def initSerializedCpg( + outputDir: Option[String], + passName: String, + index: Int = 0 + ): SerializedCpg = + outputDir match + case Some(dir) => + new SerializedCpg((File(dir) / s"${index}_$passName").path.toAbsolutePath.toString) + case None => new SerializedCpg() + + protected def runPass( + pass: CpgPassBase, + context: LayerCreatorContext, + storeUndoInfo: Boolean, + index: Int = 0 + ): Unit = + val serializedCpg = initSerializedCpg(context.outputDir, pass.name, index) + pass.createApplySerializeAndStore(serializedCpg, inverse = storeUndoInfo) + serializedCpg.close() + + def create(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit +end LayerCreator class LayerCreatorContext(val cpg: Cpg, val outputDir: Option[String] = None) {} class LayerCreatorOptions() diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala index 79ee059b..6ed1e50c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/package.scala @@ -2,14 +2,16 @@ package io.shiftleft /** Domain specific language for querying code property graphs * - * This is the API reference for the CPG query language, a language to mine code for defects and vulnerabilities both - * 
interactively on a code analysis shell (REPL), or using non-interactive scripts. + * This is the API reference for the CPG query language, a language to mine code for defects and + * vulnerabilities both interactively on a code analysis shell (REPL), or using non-interactive + * scripts. * * Queries written in the CPG query language express graph traversals (see - * [[https://en.wikipedia.org/wiki/Graph_traversal]]). Similar to the standard graph traversal language "Gremlin" (see - * [[https://en.wikipedia.org/wiki/Gremlin_(programming_language)]])) these traversals are formulated as sequences of - * primitive language elements referred to as "steps". You can think of a step as a small program, similar to a unix - * shell utility, however, instead of processing lines one by one, the step processes nodes of the graph. + * [[https://en.wikipedia.org/wiki/Graph_traversal]]). Similar to the standard graph traversal + * language "Gremlin" (see [[https://en.wikipedia.org/wiki/Gremlin_(programming_language)]])) these + * traversals are formulated as sequences of primitive language elements referred to as "steps". + * You can think of a step as a small program, similar to a unix shell utility, however, instead of + * processing lines one by one, the step processes nodes of the graph. * * ==Starting a traversal== * All traversals begin by selecting a set of start nodes, e.g., @@ -24,8 +26,8 @@ package io.shiftleft * {{{io.shiftleft.codepropertygraph.Cpg}}} * * ==Lazy evaluation== - * Queries are lazily evaluated, e.g., `cpg.method` creates a traversal which you can add more steps to. You can, for - * example, evaluate the traversal by converting it to a list: + * Queries are lazily evaluated, e.g., `cpg.method` creates a traversal which you can add more + * steps to. You can, for example, evaluate the traversal by converting it to a list: * * {{{cpg.method.toList}}} * @@ -36,8 +38,9 @@ package io.shiftleft * provides the same result as the former query. 
* * ==Properties== - * Nodes have "properties", key-value pairs where keys are strings and values are primitive data types such as strings, - * integers, or Booleans. Properties of nodes can be selected based on their key, e.g., + * Nodes have "properties", key-value pairs where keys are strings and values are primitive data + * types such as strings, integers, or Booleans. Properties of nodes can be selected based on their + * key, e.g., * * {{{cpg.method.name}}} * @@ -45,22 +48,24 @@ package io.shiftleft * * {{{cpg.method.name(".*exec.*")}}} * - * traverse to all methods where `name` matches the regular expression ".*exec.*". You can see a complete list of - * properties by browsing to the API documentation of the corresponding step. For example, you can find the properties - * of method nodes at [[io.shiftleft.semanticcpg.language.types.structure.MethodTraversal]]. + * traverse to all methods where `name` matches the regular expression ".*exec.*". You can see a + * complete list of properties by browsing to the API documentation of the corresponding step. For + * example, you can find the properties of method nodes at + * [[io.shiftleft.semanticcpg.language.types.structure.MethodTraversal]]. * * ==Side effects== - * Useful if you want to mutate something outside the traversal, or simply debug it: This prints all typeDecl names as - * it traverses the graph and increments `i` for each one. + * Useful if you want to mutate something outside the traversal, or simply debug it: This prints + * all typeDecl names as it traverses the graph and increments `i` for each one. * {{{ * var i = 0 * cpg.typeDecl.sideEffect{typeTemplate => println(typeTemplate.name); i = i + 1}.exec * }}} * * ==[advanced] Selecting multiple things from your traversal== - * If you are interested in multiple things along the way of your traversal, you label anything using the `as` - * modulator, and use `select` at the end. 
Note that the compiler automatically derived the correct return type as a - * tuple of the labelled steps, in this case with two elements. + * If you are interested in multiple things along the way of your traversal, you label anything + * using the `as` modulator, and use `select` at the end. Note that the compiler automatically + * derived the correct return type as a tuple of the labelled steps, in this case with two + * elements. * * {{{ * cpg.method.as("method").definingTypeDecl.as("classDef").select.toList @@ -74,8 +79,9 @@ package io.shiftleft * someMethod.start.parameter.toList * }}} * - * You can use this e.g. in a for comprehension, which is (in this context) essentially an alternative way to select - * multiple intermediate things. It is more expressive, but more computationally expensive. + * You can use this e.g. in a for comprehension, which is (in this context) essentially an + * alternative way to select multiple intermediate things. It is more expressive, but more + * computationally expensive. * {{{ * val query = for { * method <- cpg.method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala index a437835d..84c572f9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala @@ -6,47 +6,55 @@ import overflowdb.{Edge, Node, Property, PropertyKey} import java.util /** mixin trait for test nodes */ -trait DummyNodeImpl extends StoredNode { - // Members declared in overflowdb.Element - def graph(): overflowdb.Graph = ??? - def property[A](x$1: overflowdb.PropertyKey[A]): A = ??? - def property(x$1: String): Object = ??? - def propertyKeys(): java.util.Set[String] = ??? - def propertiesMap(): java.util.Map[String, Object] = ??? - def propertyOption(x$1: String): java.util.Optional[Object] = ??? 
- def propertyOption[A](x$1: overflowdb.PropertyKey[A]): java.util.Optional[A] = ??? - override def addEdgeImpl(label: String, inNode: Node, keyValues: Any*): Edge = ??? - override def addEdgeImpl(label: String, inNode: Node, keyValues: util.Map[String, AnyRef]): Edge = ??? - override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: Any*): Unit = ??? - override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: util.Map[String, AnyRef]): Unit = ??? - override def setPropertyImpl(key: String, value: Any): Unit = ??? - override def setPropertyImpl[A](key: PropertyKey[A], value: A): Unit = ??? - override def setPropertyImpl(property: Property[_]): Unit = ??? - override def removePropertyImpl(key: String): Unit = ??? - override def removeImpl(): Unit = ??? +trait DummyNodeImpl extends StoredNode: + // Members declared in overflowdb.Element + def graph(): overflowdb.Graph = ??? + def property[A](x$1: overflowdb.PropertyKey[A]): A = ??? + def property(x$1: String): Object = ??? + def propertyKeys(): java.util.Set[String] = ??? + def propertiesMap(): java.util.Map[String, Object] = ??? + def propertyOption(x$1: String): java.util.Optional[Object] = ??? + def propertyOption[A](x$1: overflowdb.PropertyKey[A]): java.util.Optional[A] = ??? + override def addEdgeImpl(label: String, inNode: Node, keyValues: Any*): Edge = ??? + override def addEdgeImpl( + label: String, + inNode: Node, + keyValues: util.Map[String, AnyRef] + ): Edge = ??? + override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: Any*): Unit = ??? + override def addEdgeSilentImpl( + label: String, + inNode: Node, + keyValues: util.Map[String, AnyRef] + ): Unit = ??? + override def setPropertyImpl(key: String, value: Any): Unit = ??? + override def setPropertyImpl[A](key: PropertyKey[A], value: A): Unit = ??? + override def setPropertyImpl(property: Property[?]): Unit = ??? + override def removePropertyImpl(key: String): Unit = ??? + override def removeImpl(): Unit = ??? 
- // Members declared in scala.Equals - def canEqual(that: Any): Boolean = ??? + // Members declared in scala.Equals + def canEqual(that: Any): Boolean = ??? - def both(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def both(): java.util.Iterator[overflowdb.Node] = ??? - def bothE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def bothE(): java.util.Iterator[overflowdb.Edge] = ??? - def id(): Long = ??? - def in(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def in(): java.util.Iterator[overflowdb.Node] = ??? - def inE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def inE(): java.util.Iterator[overflowdb.Edge] = ??? - def out(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def out(): java.util.Iterator[overflowdb.Node] = ??? - def outE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def outE(): java.util.Iterator[overflowdb.Edge] = ??? + def both(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def both(): java.util.Iterator[overflowdb.Node] = ??? + def bothE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def bothE(): java.util.Iterator[overflowdb.Edge] = ??? + def id(): Long = ??? + def in(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def in(): java.util.Iterator[overflowdb.Node] = ??? + def inE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def inE(): java.util.Iterator[overflowdb.Edge] = ??? + def out(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def out(): java.util.Iterator[overflowdb.Node] = ??? + def outE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def outE(): java.util.Iterator[overflowdb.Edge] = ??? - // Members declared in scala.Product - def productArity: Int = ??? - def productElement(n: Int): Any = ??? + // Members declared in scala.Product + def productArity: Int = ??? + def productElement(n: Int): Any = ??? 
- // Members declared in io.shiftleft.codepropertygraph.generated.nodes.StoredNode - def productElementLabel(n: Int): String = ??? - def valueMap: java.util.Map[String, AnyRef] = ??? -} + // Members declared in io.shiftleft.codepropertygraph.generated.nodes.StoredNode + def productElementLabel(n: Int): String = ??? + def valueMap: java.util.Map[String, AnyRef] = ??? +end DummyNodeImpl diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala index 7e31100f..18af5f03 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala @@ -1,215 +1,215 @@ package io.shiftleft.semanticcpg import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Languages, ModifierTypes} import io.shiftleft.passes.CpgPass -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder -package object testing { - - object MockCpg { - - def apply(): MockCpg = new MockCpg - - def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) - } - - case class MockCpg(cpg: Cpg = Cpg.emptyCpg) { - - def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) - - def withMetaData(language: String, overlays: List[String]): MockCpg = { - withCustom { (diffGraph, _) => - diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) - } - } - - def withFile(filename: String): MockCpg = - withCustom { (graph, _) => - graph.addNode(NewFile().name(filename)) - } - - def withNamespace(name: String, inFile: Option[String] = None): MockCpg = - withCustom { (graph, _) => - { - val 
namespaceBlock = NewNamespaceBlock().name(name) - val namespace = NewNamespace().name(name) - graph.addNode(namespaceBlock) - graph.addNode(namespace) - graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) - if (inFile.isDefined) { - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) - } - } - } - - def withTypeDecl( - name: String, - isExternal: Boolean = false, - inNamespace: Option[String] = None, - inFile: Option[String] = None - ): MockCpg = - withCustom { (graph, _) => - { - val typeNode = NewType().name(name) - val typeDeclNode = NewTypeDecl() - .name(name) - .fullName(name) - .isExternal(isExternal) - - val member = NewMember().name("amember") - val modifier = NewModifier().modifierType(ModifierTypes.STATIC) - - graph.addNode(typeDeclNode) - graph.addNode(typeNode) - graph.addNode(member) - graph.addNode(modifier) - graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) - graph.addEdge(typeDeclNode, member, EdgeTypes.AST) - graph.addEdge(member, modifier, EdgeTypes.AST) - - if (inNamespace.isDefined) { - val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head - graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) - } - if (inFile.isDefined) { - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) - } - } - } - - def withMethod( - name: String, - external: Boolean = false, - inTypeDecl: Option[String] = None, - fileName: String = "" - ): MockCpg = - withCustom { (graph, _) => - val retParam = NewMethodReturn().typeFullName("int").order(10) - val param = NewMethodParameterIn().order(1).index(1).name("param1") - val paramType = NewType().name("paramtype") - val paramOut = NewMethodParameterOut().name("param1").order(1) - val method = - NewMethod().isExternal(external).name(name).fullName(name).signature("asignature").filename(fileName) - val block = NewBlock().typeFullName("int") - val modifier = 
NewModifier().modifierType("modifiertype") - - graph.addNode(method) - graph.addNode(retParam) - graph.addNode(param) - graph.addNode(paramType) - graph.addNode(paramOut) - graph.addNode(block) - graph.addNode(modifier) - graph.addEdge(method, retParam, EdgeTypes.AST) - graph.addEdge(method, param, EdgeTypes.AST) - graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) - graph.addEdge(method, block, EdgeTypes.AST) - graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(method, modifier, EdgeTypes.AST) - - if (inTypeDecl.isDefined) { - val typeDeclNode = cpg.typeDecl.name(inTypeDecl.get).head - graph.addEdge(typeDeclNode, method, EdgeTypes.AST) - } - } - - def withTagsOnMethod( - methodName: String, - methodTags: List[(String, String)] = List(), - paramTags: List[(String, String)] = List() - ): MockCpg = - withCustom { (graph, cpg) => - implicit val diffGraph: DiffGraphBuilder = graph - methodTags.foreach { case (k, v) => - cpg.method.name(methodName).newTagNodePair(k, v).store()(diffGraph) - } - paramTags.foreach { case (k, v) => - cpg.method.name(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) - } - } - - def withCallInMethod(methodName: String, callName: String, code: Option[String] = None): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val callNode = NewCall().name(callName).code(code.getOrElse(callName)) - graph.addNode(callNode) - graph.addEdge(blockNode, callNode, EdgeTypes.AST) - graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) - } - - def withMethodCall(calledMethod: String, callingMethod: String, code: Option[String] = None): MockCpg = - withCustom { (graph, cpg) => - val callingMethodNode = cpg.method.name(callingMethod).head - val calledMethodNode = cpg.method.name(calledMethod).head - val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) - 
graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) - graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) - } - - def withLocalInMethod(methodName: String, localName: String): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val typeNode = NewType().name("alocaltype") - val localNode = NewLocal().name(localName).typeFullName("alocaltype") - graph.addNode(localNode) - graph.addNode(typeNode) - graph.addEdge(blockNode, localNode, EdgeTypes.AST) - graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) - } - - def withLiteralArgument(callName: String, literalCode: String): MockCpg = { - withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val literalNode = NewLiteral().code(literalCode) - val typeDecl = NewTypeDecl() - .name("ATypeDecl") - .fullName("ATypeDecl") - - graph.addNode(typeDecl) - graph.addNode(literalNode) - graph.addEdge(callNode, literalNode, EdgeTypes.AST) - graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) - } - } - - def withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = - withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) - - def withCallArgument(callName: String, callArgName: String, code: String = "", index: Int = 1): MockCpg = - withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) - - def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val typeDecl = NewTypeDecl().name("abc") - graph.addEdge(callNode, newNode, EdgeTypes.AST) - graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) - graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) - graph.addEdge(newNode, typeDecl, EdgeTypes.REF) - graph.addNode(newNode) - } - - def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = { - 
val diffGraph = new DiffGraphBuilder - f(diffGraph, cpg) - class MyPass extends CpgPass(cpg) { - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = { - builder.absorb(diffGraph) +package object testing: + + object MockCpg: + + def apply(): MockCpg = new MockCpg + + def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) + + case class MockCpg(cpg: Cpg = Cpg.emptyCpg): + + def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) + + def withMetaData(language: String, overlays: List[String]): MockCpg = + withCustom { (diffGraph, _) => + diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) + } + + def withFile(filename: String): MockCpg = + withCustom { (graph, _) => + graph.addNode(NewFile().name(filename)) + } + + def withNamespace(name: String, inFile: Option[String] = None): MockCpg = + withCustom { (graph, _) => + val namespaceBlock = NewNamespaceBlock().name(name) + val namespace = NewNamespace().name(name) + graph.addNode(namespaceBlock) + graph.addNode(namespace) + graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) + if inFile.isDefined then + val fileNode = cpg.file.name(inFile.get).head + graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) + } + + def withTypeDecl( + name: String, + isExternal: Boolean = false, + inNamespace: Option[String] = None, + inFile: Option[String] = None + ): MockCpg = + withCustom { (graph, _) => + val typeNode = NewType().name(name) + val typeDeclNode = NewTypeDecl() + .name(name) + .fullName(name) + .isExternal(isExternal) + + val member = NewMember().name("amember") + val modifier = NewModifier().modifierType(ModifierTypes.STATIC) + + graph.addNode(typeDeclNode) + graph.addNode(typeNode) + graph.addNode(member) + graph.addNode(modifier) + graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) + graph.addEdge(typeDeclNode, member, EdgeTypes.AST) + graph.addEdge(member, modifier, EdgeTypes.AST) + + if inNamespace.isDefined 
then + val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head + graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) + if inFile.isDefined then + val fileNode = cpg.file.name(inFile.get).head + graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) + } + + def withMethod( + name: String, + external: Boolean = false, + inTypeDecl: Option[String] = None, + fileName: String = "" + ): MockCpg = + withCustom { (graph, _) => + val retParam = NewMethodReturn().typeFullName("int").order(10) + val param = NewMethodParameterIn().order(1).index(1).name("param1") + val paramType = NewType().name("paramtype") + val paramOut = NewMethodParameterOut().name("param1").order(1) + val method = + NewMethod().isExternal(external).name(name).fullName(name).signature( + "asignature" + ).filename(fileName) + val block = NewBlock().typeFullName("int") + val modifier = NewModifier().modifierType("modifiertype") + + graph.addNode(method) + graph.addNode(retParam) + graph.addNode(param) + graph.addNode(paramType) + graph.addNode(paramOut) + graph.addNode(block) + graph.addNode(modifier) + graph.addEdge(method, retParam, EdgeTypes.AST) + graph.addEdge(method, param, EdgeTypes.AST) + graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) + graph.addEdge(method, block, EdgeTypes.AST) + graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(method, modifier, EdgeTypes.AST) + + if inTypeDecl.isDefined then + val typeDeclNode = cpg.typeDecl.name(inTypeDecl.get).head + graph.addEdge(typeDeclNode, method, EdgeTypes.AST) + } + + def withTagsOnMethod( + methodName: String, + methodTags: List[(String, String)] = List(), + paramTags: List[(String, String)] = List() + ): MockCpg = + withCustom { (graph, cpg) => + implicit val diffGraph: DiffGraphBuilder = graph + methodTags.foreach { case (k, v) => + cpg.method.name(methodName).newTagNodePair(k, v).store()(diffGraph) + } + paramTags.foreach { case (k, v) => + 
cpg.method.name(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) + } + } + + def withCallInMethod( + methodName: String, + callName: String, + code: Option[String] = None + ): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method.name(methodName).head + val blockNode = methodNode.block + val callNode = NewCall().name(callName).code(code.getOrElse(callName)) + graph.addNode(callNode) + graph.addEdge(blockNode, callNode, EdgeTypes.AST) + graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) + } + + def withMethodCall( + calledMethod: String, + callingMethod: String, + code: Option[String] = None + ): MockCpg = + withCustom { (graph, cpg) => + val callingMethodNode = cpg.method.name(callingMethod).head + val calledMethodNode = cpg.method.name(calledMethod).head + val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) + graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) + graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) + } + + def withLocalInMethod(methodName: String, localName: String): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method.name(methodName).head + val blockNode = methodNode.block + val typeNode = NewType().name("alocaltype") + val localNode = NewLocal().name(localName).typeFullName("alocaltype") + graph.addNode(localNode) + graph.addNode(typeNode) + graph.addEdge(blockNode, localNode, EdgeTypes.AST) + graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) + } + + def withLiteralArgument(callName: String, literalCode: String): MockCpg = + withCustom { (graph, cpg) => + val callNode = cpg.call.name(callName).head + val methodNode = callNode.method + val literalNode = NewLiteral().code(literalCode) + val typeDecl = NewTypeDecl() + .name("ATypeDecl") + .fullName("ATypeDecl") + + graph.addNode(typeDecl) + graph.addNode(literalNode) + graph.addEdge(callNode, literalNode, EdgeTypes.AST) + graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) + } + + def 
withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = + withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) + + def withCallArgument( + callName: String, + callArgName: String, + code: String = "", + index: Int = 1 + ): MockCpg = + withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) + + def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => + val callNode = cpg.call.name(callName).head + val methodNode = callNode.method + val typeDecl = NewTypeDecl().name("abc") + graph.addEdge(callNode, newNode, EdgeTypes.AST) + graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) + graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) + graph.addEdge(newNode, typeDecl, EdgeTypes.REF) + graph.addNode(newNode) } - } - new MyPass().createAndApply() - this - } - } -} + def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = + val diffGraph = new DiffGraphBuilder + f(diffGraph, cpg) + class MyPass extends CpgPass(cpg): + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = + builder.absorb(diffGraph) + new MyPass().createAndApply() + this + end MockCpg +end testing diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala index 74f78799..6ea7053c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala @@ -7,9 +7,9 @@ import me.shadaj.scalapy.interpreter.CPythonInterpreter import scala.util.{Failure, Success, Try} -object Fingerprinting { +object Fingerprinting: - CPythonInterpreter.execManyLines(""" + CPythonInterpreter.execManyLines(""" |from hashlib import blake2b | |def calculate_hash(content, digest_size = 16): @@ -21,6 +21,5 @@ object Fingerprinting { | return None |""".stripMargin) - def calculate_hash(content: 
String, digest_size: Int = 16): Option[String] = - Option(py.Dynamic.global.calculate_hash(content, digest_size).as[String]) -} + def calculate_hash(content: String, digest_size: Int = 16): Option[String] = + Option(py.Dynamic.global.calculate_hash(content, digest_size).as[String]) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala index d27f04c8..74bef7a1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala @@ -2,37 +2,34 @@ package io.shiftleft.semanticcpg.utils import io.shiftleft.codepropertygraph.generated.Operators -object MemberAccess { +object MemberAccess: - /** For a given name, determine whether it is the name of a "member access" operation, e.g., - * ".memberAccess". - */ - def isGenericMemberAccessName(name: String): Boolean = { - (name == Operators.memberAccess) || - (name == Operators.indirectComputedMemberAccess) || - (name == Operators.indirectMemberAccess) || - (name == Operators.computedMemberAccess) || - (name == Operators.indirection) || - (name == Operators.addressOf) || - (name == Operators.fieldAccess) || - (name == Operators.indirectFieldAccess) || - (name == Operators.indexAccess) || - (name == Operators.indirectIndexAccess) || - (name == Operators.pointerShift) || - (name == Operators.getElementPtr) - } + /** For a given name, determine whether it is the name of a "member access" operation, e.g., + * ".memberAccess". 
+ */ + def isGenericMemberAccessName(name: String): Boolean = + (name == Operators.memberAccess) || + (name == Operators.indirectComputedMemberAccess) || + (name == Operators.indirectMemberAccess) || + (name == Operators.computedMemberAccess) || + (name == Operators.indirection) || + (name == Operators.addressOf) || + (name == Operators.fieldAccess) || + (name == Operators.indirectFieldAccess) || + (name == Operators.indexAccess) || + (name == Operators.indirectIndexAccess) || + (name == Operators.pointerShift) || + (name == Operators.getElementPtr) - def isFieldAccess(name: String): Boolean = { - (name == Operators.memberAccess) || - (name == Operators.indirectComputedMemberAccess) || - (name == Operators.indirectMemberAccess) || - (name == Operators.computedMemberAccess) || - (name == Operators.indirection) || - (name == Operators.fieldAccess) || - (name == Operators.indirectFieldAccess) || - (name == Operators.indexAccess) || - (name == Operators.indirectIndexAccess) || - (name == Operators.getElementPtr) - } - -} + def isFieldAccess(name: String): Boolean = + (name == Operators.memberAccess) || + (name == Operators.indirectComputedMemberAccess) || + (name == Operators.indirectMemberAccess) || + (name == Operators.computedMemberAccess) || + (name == Operators.indirection) || + (name == Operators.fieldAccess) || + (name == Operators.indirectFieldAccess) || + (name == Operators.indexAccess) || + (name == Operators.indirectIndexAccess) || + (name == Operators.getElementPtr) +end MemberAccess diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala index d8ab2518..8a18ff7e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala @@ -4,22 +4,20 @@ import javax.xml.parsers.SAXParserFactory import scala.util.Try import 
scala.xml.{Elem, XML} -object SecureXmlParsing { - def parseXml(content: String): Option[Elem] = { - Try { - val spf = SAXParserFactory.newInstance() +object SecureXmlParsing: + def parseXml(content: String): Option[Elem] = + Try { + val spf = SAXParserFactory.newInstance() - spf.setValidating(false) - spf.setNamespaceAware(false) - spf.setXIncludeAware(false) - spf.setFeature("http://xml.org/sax/features/validation", false) - spf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", false) - spf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false) - spf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false) - spf.setFeature("http://xml.org/sax/features/external-parameter-entities", false) - spf.setFeature("http://xml.org/sax/features/external-general-entities", false) + spf.setValidating(false) + spf.setNamespaceAware(false) + spf.setXIncludeAware(false) + spf.setFeature("http://xml.org/sax/features/validation", false) + spf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", false) + spf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false) + spf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false) + spf.setFeature("http://xml.org/sax/features/external-parameter-entities", false) + spf.setFeature("http://xml.org/sax/features/external-general-entities", false) - XML.withSAXParser(spf.newSAXParser()).loadString(content) - }.toOption - } -} + XML.withSAXParser(spf.newSAXParser()).loadString(content) + }.toOption diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala index 31e229da..4922152e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala @@ -1,9 +1,8 @@ package io.shiftleft.semanticcpg.utils 
import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* -object Statements { - def countAll(cpg: Cpg): Long = - cpg.method.topLevelExpressions.size -} +object Statements: + def countAll(cpg: Cpg): Long = + cpg.method.topLevelExpressions.size diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala index f09e23d5..aca4c844 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala @@ -8,9 +8,9 @@ import overflowdb.formats.ExportResult import scala.util.{Failure, Success, Try} import java.nio.file.Path -object Torch { +object Torch: - CPythonInterpreter.execManyLines(""" + CPythonInterpreter.execManyLines(""" | |SCIENCE_PACK_AVAILABLE = True |try: @@ -19,76 +19,91 @@ object Torch { | SCIENCE_PACK_AVAILABLE = False |""".stripMargin) - def convert_graphml(gml_file: Path) = py.Dynamic.global.convert_graphml(gml_file.toAbsolutePath.toString) + def convert_graphml(gml_file: Path) = + py.Dynamic.global.convert_graphml(gml_file.toAbsolutePath.toString) - def to_pyg(gml_file: Path) = py.Dynamic.global.to_pyg(convert_graphml(gml_file)) + def to_pyg(gml_file: Path) = py.Dynamic.global.to_pyg(convert_graphml(gml_file)) - def diff_graph( - first_gml_file: Path, - second_gml_file: Path, - include_common: Boolean = false, - as_dict: Boolean = false - ) = { - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.diff_graph(first_graph, second_graph, include_common, as_dict) - } + def diff_graph( + first_gml_file: Path, + second_gml_file: Path, + include_common: Boolean = false, + as_dict: Boolean = false + ) = + val first_graph = 
py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global.diff_graph(first_graph, second_graph, include_common, as_dict) - def is_similar( - first_gml_file: Path, - second_gml_file: Path, - edit_distance: Int = 10, - upper_bound: Int = 500, - timeout: Int = 5 - ): Boolean = { - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global - .is_similar( - first_graph, - second_graph, - edit_distance = edit_distance, - upper_bound = upper_bound, - timeout = timeout - ) - .as[Boolean] - } + def is_similar( + first_gml_file: Path, + second_gml_file: Path, + edit_distance: Int = 10, + upper_bound: Int = 500, + timeout: Int = 5 + ): Boolean = + val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global + .is_similar( + first_graph, + second_graph, + edit_distance = edit_distance, + upper_bound = upper_bound, + timeout = timeout + ) + .as[Boolean] + end is_similar - def is_similar(first_result: ExportResult, second_result: ExportResult, edit_distance: Int): Boolean = { - if (first_result.files.nonEmpty && second_result.files.nonEmpty) { - val first_gml_file = first_result.files.head - val second_gml_file = second_result.files.head - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.is_similar(first_graph, second_graph, edit_distance = edit_distance).as[Boolean] - } else { - false - } - } + def is_similar( + first_result: ExportResult, + second_result: ExportResult, + edit_distance: Int + ): 
Boolean = + if first_result.files.nonEmpty && second_result.files.nonEmpty then + val first_gml_file = first_result.files.head + val second_gml_file = second_result.files.head + val first_graph = + py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global.is_similar( + first_graph, + second_graph, + edit_distance = edit_distance + ).as[Boolean] + else + false - def edit_distance( - first_result: ExportResult, - second_result: ExportResult, - upper_bound: Int = 500, - timeout: Int = 5 - ): Double = { - if (first_result.files.nonEmpty && second_result.files.nonEmpty) { - val first_gml_file = first_result.files.head - val second_gml_file = second_result.files.head - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.ged(first_graph, second_graph, upper_bound = upper_bound, timeout = timeout).as[Double] - } else { - -1.0 - } - } + def edit_distance( + first_result: ExportResult, + second_result: ExportResult, + upper_bound: Int = 500, + timeout: Int = 5 + ): Double = + if first_result.files.nonEmpty && second_result.files.nonEmpty then + val first_gml_file = first_result.files.head + val second_gml_file = second_result.files.head + val first_graph = + py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global.ged( + first_graph, + second_graph, + upper_bound = upper_bound, + timeout = timeout + ).as[Double] + else + -1.0 - def generate_sp_model( - filename: String, - vocab_size: Int = 20000, - model_type: String = "unigram", - model_prefix: String = "m_user" - ) = py.Dynamic.global.generate_sp_model(filename, vocab_size, model_type, model_prefix) + def 
generate_sp_model( + filename: String, + vocab_size: Int = 20000, + model_type: String = "unigram", + model_prefix: String = "m_user" + ) = py.Dynamic.global.generate_sp_model(filename, vocab_size, model_type, model_prefix) - def load_sp_model(filename: String) = py.Dynamic.global.load_sp_model(filename) -} + def load_sp_model(filename: String) = py.Dynamic.global.load_sp_model(filename) +end Torch