From f45fa443b13820baf060dfa67e8310cd5422c03f Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Fri, 24 Feb 2023 23:09:39 +0900 Subject: [PATCH 01/13] auto option rank: initial commit. possible top-down implementation startup point. Signed-off-by: Vivien Oddou --- src/AzslcIntermediateRepresentation.cpp | 25 +++++++++++++++++++++++++ src/AzslcIntermediateRepresentation.h | 6 ++++++ 2 files changed, 31 insertions(+) diff --git a/src/AzslcIntermediateRepresentation.cpp b/src/AzslcIntermediateRepresentation.cpp index 9cb5716..7e42ea5 100644 --- a/src/AzslcIntermediateRepresentation.cpp +++ b/src/AzslcIntermediateRepresentation.cpp @@ -951,4 +951,29 @@ namespace AZ::ShaderCompiler return memberList[memberList.size() - 1]; } + void IntermediateRepresentation::AnalyzeOptionRanks() + { + // loop over variables + for (auto& [uid, varInfo, kindInfo] : m_symbols.GetOrderedSymbolsOfSubType_3()) + { + // only options + if (varInfo->CheckHasStorageFlag(StorageFlag::Option)) + { + int impactScore = 0; + // loop over appearances over the program + for (Seenat& ref : kindInfo->GetSeenats()) + { + // determine an impact score + impactScore += AnalyzeImpact(ref.m_where); + } + } + } + } + + int IntermediateRepresentation::AnalyzeImpact(TokensLocation const& location) + { + // find the node at location: + return 0; + } + } // end of namespace AZ::ShaderCompiler diff --git a/src/AzslcIntermediateRepresentation.h b/src/AzslcIntermediateRepresentation.h index ab7c43b..3957cb2 100644 --- a/src/AzslcIntermediateRepresentation.h +++ b/src/AzslcIntermediateRepresentation.h @@ -305,6 +305,12 @@ namespace AZ::ShaderCompiler // If @structUid is not struct or class, then it returns nullptr. IdentifierUID GetLastMemberVariable(const IdentifierUID& structUid); + //! Determine a heurisitcal global order between options in the program, using "impacted code size" static analysis. + void AnalyzeOptionRanks(); + + //! Estimate a score proportional to how much code is "child" to the AST node at `location` + int AnalyzeImpact(TokensLocation const& location); + // the maps of all variables, functions, etc, from the source code (things with declarations and a name). SymbolAggregator m_symbols; // stateful helper during parsing From d52819aea5462ec704c58ea1e1acfcc7fe8a67a0 Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Fri, 10 Mar 2023 13:30:25 +0900 Subject: [PATCH 02/13] WIP: option rank static analysis. half working draft. function follow still not working. Signed-off-by: Vivien Oddou --- src/AzslcEmitter.cpp | 1 - src/AzslcIntermediateRepresentation.cpp | 27 ----- src/AzslcIntermediateRepresentation.h | 7 -- src/AzslcKindInfo.h | 1 + src/AzslcMain.cpp | 4 +- src/AzslcReflection.cpp | 146 +++++++++++++++++++++--- src/AzslcReflection.h | 16 ++- src/AzslcUtils.h | 19 ++- 8 files changed, 163 insertions(+), 58 deletions(-) diff --git a/src/AzslcEmitter.cpp b/src/AzslcEmitter.cpp index f800da9..605728d 100644 --- a/src/AzslcEmitter.cpp +++ b/src/AzslcEmitter.cpp @@ -1277,7 +1277,6 @@ namespace AZ::ShaderCompiler const ICodeEmissionMutator* codeMutator = m_codeMutator; ssize_t ii = interval.a; - bool wasInPreprocessorDirective = false; // record a state to detect exit of directives, because they need to reside on their own lines while (ii <= interval.b) { auto* token = GetNextToken(ii /*inout*/); diff --git a/src/AzslcIntermediateRepresentation.cpp b/src/AzslcIntermediateRepresentation.cpp index 7e42ea5..89a6c7f 100644 --- a/src/AzslcIntermediateRepresentation.cpp +++ b/src/AzslcIntermediateRepresentation.cpp @@ -332,7 +332,6 @@ namespace AZ::ShaderCompiler cout << " storage: " << sub.m_typeInfoExt.m_qualifiers.GetDisplayName() << "\n"; cout << " array dim: \"" << sub.m_typeInfoExt.m_arrayDims.ToString() << "\"\n"; cout << " has sampler state: " << (sub.m_samplerState ? "yes\n" : "no\n"); - cout << "\n"; if (!holds_alternative(sub.m_constVal)) { cout << " val: " << ExtractValueAsInt64(sub.m_constVal) << "\n"; @@ -950,30 +949,4 @@ namespace AZ::ShaderCompiler } return memberList[memberList.size() - 1]; } - - void IntermediateRepresentation::AnalyzeOptionRanks() - { - // loop over variables - for (auto& [uid, varInfo, kindInfo] : m_symbols.GetOrderedSymbolsOfSubType_3()) - { - // only options - if (varInfo->CheckHasStorageFlag(StorageFlag::Option)) - { - int impactScore = 0; - // loop over appearances over the program - for (Seenat& ref : kindInfo->GetSeenats()) - { - // determine an impact score - impactScore += AnalyzeImpact(ref.m_where); - } - } - } - } - - int IntermediateRepresentation::AnalyzeImpact(TokensLocation const& location) - { - // find the node at location: - return 0; - } - } // end of namespace AZ::ShaderCompiler diff --git a/src/AzslcIntermediateRepresentation.h b/src/AzslcIntermediateRepresentation.h index 3957cb2..42fc0fa 100644 --- a/src/AzslcIntermediateRepresentation.h +++ b/src/AzslcIntermediateRepresentation.h @@ -14,7 +14,6 @@ namespace AZ::ShaderCompiler { - //! We limit the maximum number of render targets to 8, with indices in the range [0..7] static const uint32_t kMaxRenderTargets = 8; @@ -305,12 +304,6 @@ namespace AZ::ShaderCompiler // If @structUid is not struct or class, then it returns nullptr. IdentifierUID GetLastMemberVariable(const IdentifierUID& structUid); - //! Determine a heurisitcal global order between options in the program, using "impacted code size" static analysis. - void AnalyzeOptionRanks(); - - //! Estimate a score proportional to how much code is "child" to the AST node at `location` - int AnalyzeImpact(TokensLocation const& location); - // the maps of all variables, functions, etc, from the source code (things with declarations and a name). SymbolAggregator m_symbols; // stateful helper during parsing diff --git a/src/AzslcKindInfo.h b/src/AzslcKindInfo.h index ef50968..35ae46a 100644 --- a/src/AzslcKindInfo.h +++ b/src/AzslcKindInfo.h @@ -791,6 +791,7 @@ namespace AZ::ShaderCompiler vector< IdentifierUID > m_overrides; //!< list of implementing functions in child classes optional< IdentifierUID > m_base; //!< points to the overridden function in the base interface, if applies. only supports one base FunctionMultiForwards m_multiFwds = FMF_None; //!< presence of redundant prototype-only declarations + int m_costScore = -1; //!< heuristical static analysis of the amount of instructions contained struct Parameter { IdentifierUID m_varId; diff --git a/src/AzslcMain.cpp b/src/AzslcMain.cpp index 973c6b5..79ec6f9 100644 --- a/src/AzslcMain.cpp +++ b/src/AzslcMain.cpp @@ -23,8 +23,8 @@ namespace StdFs = std::filesystem; // For large features or milestones. Minor version allows for breaking changes. Existing tests can change. #define AZSLC_MINOR "8" // last change: introduction of class inheritance // For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change. -#define AZSLC_REVISION "16" // last change: fixup runtime error with redundant function declarations - // "15" change: add min in option value key-extracter's function for range & enum +#define AZSLC_REVISION "17" // last change: automatic option ranks + // "16" change: fixup runtime error with redundant function declarations namespace AZ::ShaderCompiler { diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index 53d3e7a..b62f100 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -630,6 +630,8 @@ namespace AZ::ShaderCompiler void CodeReflection::DumpVariantList(const Options& options) const { m_out << GetVariantList(options); + m_out << "\n"; + AnalyzeOptionRanks(); } static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo) @@ -857,7 +859,8 @@ namespace AZ::ShaderCompiler for (auto& seenat : kindInfo->GetSeenats()) { assert(uid == seenat.m_referredDefinition); - // TODO: the assumption that intervals where distinct doesnt hold anymore now that we have unnamed scopes + // careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes) + // ok for now because AZSL/HLSL don't have lambdas auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value) { return value.first.properlyContains({key, key}); @@ -909,16 +912,9 @@ namespace AZ::ShaderCompiler uint32_t numOf32bitConst = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID); RootSigDesc rootSignature = BuildSignatureDescription(options, numOf32bitConst); - // prepare a lookup acceleration data structure for reverse mapping tokens to scopes. - MapOfBeginToSpanAndUid scopeStartToFunctionIntervals; - for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals) - { - if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. We need a set of disjoint intervals as an invariant for the next algorithm. - { - // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound) - scopeStartToFunctionIntervals[interval.a] = std::make_pair(interval, uid); - } - } + // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes. + // (truth: we need a set of disjoint intervals as an invariant for the following algorithm) + GenerateScopeStartToFunctionIntervalsReverseMap(); Json::Value srgRoot(Json::objectValue); // Order the reflection by SRG for convenience @@ -968,7 +964,7 @@ namespace AZ::ShaderCompiler else { set dependencyList; - DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals); srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {}); } } @@ -981,7 +977,7 @@ namespace AZ::ShaderCompiler for (auto& srgConstant : srgInfo->m_implicitStruct.GetMemberFields()) { allConstants.append({ srgConstant.GetNameLeaf() }); - DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, m_functionIntervals); } // variant fallback support if (srgInfo->m_shaderVariantFallback) @@ -992,7 +988,7 @@ namespace AZ::ShaderCompiler { if (varSub->CheckHasStorageFlag(StorageFlag::Option)) { - DiscoverTopLevelFunctionDependencies(varUid, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(varUid, dependencyList, m_functionIntervals); } } } @@ -1004,4 +1000,126 @@ namespace AZ::ShaderCompiler m_out << srgRoot; } + + void CodeReflection::AnalyzeOptionRanks() const + { + // make sure we have the function scope lookup cache ready + GenerateScopeStartToFunctionIntervalsReverseMap(); + // loop over variables + for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3()) + { + // only options + if (varInfo->CheckHasStorageFlag(StorageFlag::Option)) + { + int impactScore = 0; + // loop over appearances over the program + for (Seenat& ref : kindInfo->GetSeenats()) + { + // determine an impact score + impactScore += AnalyzeImpact(ref.m_where) // dependent code that may be skipped depending on the value of that ref + + 1; // by virtue of being mentioned (seenat), we count the reference as an access of cost 1. + } + m_out << "Option " << uid.GetName() << " has impact " << impactScore << "\n"; + } + } + } + + int CodeReflection::AnalyzeImpact(TokensLocation const& location) const + { + // find the node at `location`: + ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId); + // go up tree to meet a block node that has visitable depth: + // can be any of if/for/while/switch + // 4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binary op->braces->cmp expr->cond->for + if (auto* whileNode = DeepParentAs(node->parent, 3)) + { + node = whileNode->embeddedStatement(); + } + else if (auto* ifNode = DeepParentAs(node->parent, 3)) + { + node = ifNode->embeddedStatement(); + } + else if (auto* forNode = DeepParentAs(node->parent, 4)) + { + node = forNode->embeddedStatement(); + } + else if (auto* switchNode = DeepParentAs(node->parent, 3)) + { + node = switchNode->switchBlock(); + } + int score = 0; + AnalyzeImpact(node, score); + return score; + } + + void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const + { + for (auto& c : astNode->children) + { + if (auto* callNode = As(c)) + { + // get function score in FunctionInfo if cached, compute it and store if not. + // to access the function symbol info we need the current scope, the function call name and perform a lookup. + auto intervalIter = FindInterval(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(), + [](ssize_t key, auto& value) + { + return value.first.properlyContains({key, key}); + }); + if (intervalIter != m_functionIntervals.cend()) + { + const IdentifierUID& encloser = intervalIter->second.second; + // lookup function at AST node `callNode` from scope `encloser` + if (auto* idExpr = As(callNode->Expr)) + { + UnqualifiedName funcName = ExtractNameFromIdExpression(idExpr->idExpression()); + m_ir->m_sema.ResolveOverload( + IdAndKind* symbolMeantUnderCallNode = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName); + auto* funcInfo = symbolMeantUnderCallNode->second.GetSubAs(); + if (funcInfo->m_costScore == -1) + { + funcInfo->m_costScore = 0; + using AstFDef = azslParser::HlslFunctionDefinitionContext; + AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), + funcInfo->m_costScore); // recurse and cache if not already done + } + scoreAccumulator += funcInfo->m_costScore; + } + // other cases forfeited for now, but that would at least be braces (f)() or MAE x.m() + } + else // no interval found + { + // function calls outside of function bodies can appear in an initializer: + // int g_a = MakeA(); // global init + // class C { int m_a = CompA(); // constructor init (invalid AZSL/HLSL) + // class D { void Method(int a_a = DefaultA()); // default parameter value + // in any case, extracting the scope is impossible with this system. + // we forfeit evaluation of a score + } + } + else if (auto* node = As(c)) + { + AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;" + } + if (auto* leaf = As(c)) + { + // determine cost by number of full expressions separated by semicolon + scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi; + } + } + } + + void CodeReflection::GenerateScopeStartToFunctionIntervalsReverseMap() const + { + if (m_functionIntervals.empty()) + { + for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals) + { + if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. + { + // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound) + m_functionIntervals[interval.a] = std::make_pair(interval, uid); + } + } + } + } } diff --git a/src/AzslcReflection.h b/src/AzslcReflection.h index c3c0bc4..f8b8e38 100644 --- a/src/AzslcReflection.h +++ b/src/AzslcReflection.h @@ -11,6 +11,8 @@ namespace AZ::ShaderCompiler { + using MapOfBeginToSpanAndUid = map >; + struct CodeReflection : Backend { CodeReflection(IntermediateRepresentation* ir, TokenStream* tokens, std::ostream& out) @@ -45,6 +47,9 @@ namespace AZ::ShaderCompiler //! @param options user configuration parsed from command line void DumpResourceBindingDependencies(const Options& options) const; + //! Determine a heurisitcal global order between options in the program, using "impacted code size" static analysis. + void AnalyzeOptionRanks() const; + private: //! Builds member variable packing information and adds it to the membersContainer @@ -63,7 +68,6 @@ namespace AZ::ShaderCompiler bool BuildOMStruct(const ExtendedTypeInfo& returnTypeRef, string_view semanticOverride, Json::Value& jsonVal, int& semanticIndex) const; - using MapOfBeginToSpanAndUid = map >; //! Populate a list of functions where a symbol appear as potentially used //! @param uid The symbol to start the dependency analysis on //! @param output Any dependency symbol will be appended to this set @@ -75,6 +79,16 @@ namespace AZ::ShaderCompiler bool IsPotentialEntryPoint(const IdentifierUID& uid) const; + // Estimate a score proportional to how much code is "child" to the AST node at `location` + int AnalyzeImpact(TokensLocation const& location) const; + + // Recursive internal detail version + void AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const; + + //! Useful for static analysis on dependencies or option ranks + void GenerateScopeStartToFunctionIntervalsReverseMap() const; + mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call + std::ostream& m_out; }; } diff --git a/src/AzslcUtils.h b/src/AzslcUtils.h index 8886591..2b64bc6 100644 --- a/src/AzslcUtils.h +++ b/src/AzslcUtils.h @@ -958,15 +958,22 @@ namespace AZ::ShaderCompiler return UnqualifiedName{ctx->Name->getText()}; } - template - bool Is3ParentRuleOfType(antlr4::ParserRuleContext* ctx) + //! Get a pointer to the first parent that happens to be of type `SearchType` + //! with a limit depth of `maxDepth` parents to search through + template + SearchType DeepParentAs(tree::ParseTree* ctx, int maxDepth) { - if (ctx == nullptr || ctx->parent == nullptr || ctx->parent->parent == nullptr) // input canonicalization + if (auto* searchTypeNode = As(ctx)) { - return false; + return searchTypeNode; } - auto threeUp = ctx->parent->parent->parent; - return dynamic_cast(threeUp); + return maxDepth <= 0 || !ctx ? nullptr : DeepParentAs(ctx->parent, maxDepth - 1); + } + + template + bool Is3ParentRuleOfType(tree::ParseTree* ctx) + { + return DeepParentAs(ctx, 3); } // is def From 59ca8f14d3d569b92a748bc8370d318b225d8394 Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Sat, 25 Mar 2023 19:24:30 +0900 Subject: [PATCH 03/13] First working proof of concept of option rank cost static analyzer. Signed-off-by: Vivien Oddou --- src/AzslcReflection.cpp | 64 ++++++++++++++++++++++++++----- src/AzslcSemanticOrchestrator.cpp | 42 +++++++++++++------- src/AzslcSemanticOrchestrator.h | 5 ++- src/AzslcSymbolAggregator.cpp | 2 +- src/AzslcTypes.h | 55 ++++++++++++++++---------- src/AzslcUtils.h | 7 ++++ src/GenericUtils.h | 8 ++-- 7 files changed, 133 insertions(+), 50 deletions(-) diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index b62f100..c58e265 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -1072,17 +1072,63 @@ namespace AZ::ShaderCompiler if (auto* idExpr = As(callNode->Expr)) { UnqualifiedName funcName = ExtractNameFromIdExpression(idExpr->idExpression()); - m_ir->m_sema.ResolveOverload( - IdAndKind* symbolMeantUnderCallNode = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName); - auto* funcInfo = symbolMeantUnderCallNode->second.GetSubAs(); - if (funcInfo->m_costScore == -1) + IdAndKind* overload = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName); + if (!overload) // in case of function not found, we assume it's an intrinsic. { - funcInfo->m_costScore = 0; - using AstFDef = azslParser::HlslFunctionDefinitionContext; - AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), - funcInfo->m_costScore); // recurse and cache if not already done + if (IsOneOf(funcName, "CallShader", "TraceRay")) + { // non measurable but assumed high + scoreAccumulator += 50; + } + else if (IsOneOf(funcName, "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append")) + { // hardware locked memory ops, high weight + scoreAccumulator += 10; + } + else if (IsOneOf(funcName, "Sample", "Load")) + { // memory access is weighted in between + scoreAccumulator += 5; + } + else + { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1. + scoreAccumulator += 1; + } + } + else + { + azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode); + IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args); + IdentifierUID concrete; + if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet) + { // in case of strict selection failure, run a fuzzy select + size_t numArgs = NumArgs(callNode); + overload->second.GetSubAs()->AnyOf( + [&](IdentifierUID const& uid) + { + auto* concreteFcInfo = m_ir->GetSymbolSubAs(uid.GetName()); + size_t numParams = concreteFcInfo->GetParameters(true).size(); + if (numParams == numArgs) + { + concrete = uid; + return true; + } + return false; + } + ); + // if still not enough to get a fix, it might be an ill-formed input. prefer to forfeit + } + else + { + concrete = symbolMeantUnderCallNode->first; + } + auto* funcInfo = m_ir->GetSymbolSubAs(concrete.GetName()); + if (funcInfo->m_costScore == -1) + { + funcInfo->m_costScore = 0; + using AstFDef = azslParser::HlslFunctionDefinitionContext; + AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), + funcInfo->m_costScore); // recurse and cache if not already done + } + scoreAccumulator += funcInfo->m_costScore; } - scoreAccumulator += funcInfo->m_costScore; } // other cases forfeited for now, but that would at least be braces (f)() or MAE x.m() } diff --git a/src/AzslcSemanticOrchestrator.cpp b/src/AzslcSemanticOrchestrator.cpp index b17b061..4f7c08b 100644 --- a/src/AzslcSemanticOrchestrator.cpp +++ b/src/AzslcSemanticOrchestrator.cpp @@ -1166,7 +1166,9 @@ namespace AZ::ShaderCompiler As(ctx), As(ctx), As(ctx), - As(ctx)); + As(ctx), + As(ctx), + As(ctx)); } catch (AllNull&) { @@ -1175,6 +1177,24 @@ namespace AZ::ShaderCompiler } } + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const + { + // among all possibilities Plus|Minus|Not|Tilde|PlusPlus|MinusMinus + // only "Not" returns a bool, the rest is transparent and returns the same type as rhs + return ctx->prefixUnaryOperator()->start->getType() == azslLexer::Not ? MangleScalarType("bool") + : TypeofExpr(ctx->Expr); + } + + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const + { + return TypeofExpr(ctx->Expr); // in case of x++ or x-- the type is the type of x. + } + + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::BinaryExpressionContext* ctx) const + { + return TypeofExpr(ctx->Expr); + } + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const { return VisitFirstNonNull([this](auto* ctx) { return TypeofExpr(ctx); }, @@ -1303,29 +1323,23 @@ namespace AZ::ShaderCompiler QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::LiteralContext* ctx) const { - // verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array. - auto checkExistType = [](string_view scalarName){return std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), scalarName) != AZ::ShaderCompiler::Predefined::Scalar.end();}; // verifies that last or 1-before-last characters are a particular literal suffix. like in "56ul" auto hasSuffix = [](auto node, char s){return tolower(node->getText().back()) == s || tolower(Slice(node->getText(), -3, -2) == s);}; - auto checkAndReturn = [&](string_view typeName) - { - assert(checkExistType(typeName)); - return QualifiedName{"?" + string{typeName}}; - }; + if (ctx->True() || ctx->False()) { - return checkAndReturn("bool"); + return MangleScalarType("bool"); } else if (auto* literal = ctx->IntegerLiteral()) { - return hasSuffix(literal, 'u') ? checkAndReturn("uint") - : checkAndReturn("int"); + return hasSuffix(literal, 'u') ? MangleScalarType("uint") + : MangleScalarType("int"); } else if (auto* literal = ctx->FloatLiteral()) { - return hasSuffix(literal, 'h') ? checkAndReturn("half") - : hasSuffix(literal, 'l') ? checkAndReturn("double") - : checkAndReturn("float"); + return hasSuffix(literal, 'h') ? MangleScalarType("half") + : hasSuffix(literal, 'l') ? MangleScalarType("double") + : MangleScalarType("float"); } return {""}; } diff --git a/src/AzslcSemanticOrchestrator.h b/src/AzslcSemanticOrchestrator.h index a2aba91..6492c5a 100644 --- a/src/AzslcSemanticOrchestrator.h +++ b/src/AzslcSemanticOrchestrator.h @@ -222,6 +222,9 @@ namespace AZ::ShaderCompiler auto TypeofExpr(azslParser::LiteralExpressionContext* ctx) const -> QualifiedName; auto TypeofExpr(azslParser::LiteralContext* ctx) const -> QualifiedName; auto TypeofExpr(azslParser::CommaExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::BinaryExpressionContext* ctx) const -> QualifiedName; //! Parse the AST from a variable declaration and attempt to extract array dimensions integer constants [dim1][dim2]... //! Return: on success, otherwise @@ -330,7 +333,7 @@ namespace AZ::ShaderCompiler { auto typeId = LookupType(typeNameOrCtx, policy); auto tClass = GetTypeClass(typeId); - auto arithmetic = IsNonGenericArithmetic(tClass) ? CreateArithmeticTypeInfo(typeId.GetName()) : ArithmeticTypeInfo{}; // TODO: canonicalize generics + auto arithmetic = IsNonGenericArithmetic(tClass) ? CreateArithmeticTraits(typeId.GetName()) : ArithmeticTraits{}; // TODO: canonicalize generics return TypeRefInfo{typeId, arithmetic, tClass}; } diff --git a/src/AzslcSymbolAggregator.cpp b/src/AzslcSymbolAggregator.cpp index 45ac3b1..d1f78ef 100644 --- a/src/AzslcSymbolAggregator.cpp +++ b/src/AzslcSymbolAggregator.cpp @@ -20,7 +20,7 @@ namespace AZ::ShaderCompiler auto& [uid, kindInfo] = st.AddIdentifier(azirName, Kind::Type); // the kind is Type because all predefined are stored as such. auto& typeInfo = kindInfo.GetSubAfterInitAs(); auto typeClass = TypeClass::FromStr(bag.m_name); - auto arithmetic = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTypeInfo(azirName) : ArithmeticTypeInfo{}; // TODO: canonicalize generics + auto arithmetic = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTraits(azirName) : ArithmeticTraits{}; // TODO: canonicalize generics typeInfo = TypeRefInfo{uid, arithmetic, typeClass}; } } diff --git a/src/AzslcTypes.h b/src/AzslcTypes.h index 6ffd6ea..826b091 100644 --- a/src/AzslcTypes.h +++ b/src/AzslcTypes.h @@ -260,24 +260,36 @@ namespace AZ::ShaderCompiler return result; } + /// Verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array. + inline bool CheckExistScalarType(string_view scalarName) + { + return std::find(Predefined::Scalar.begin(), + Predefined::Scalar.end(), + scalarName) != Predefined::Scalar.end(); + }; + + /// Assert validity of the type string, and form a "?" tainted string to host a scalar type + inline QualifiedName MangleScalarType(string_view typeName) + { + assert(CheckExistScalarType(typeName)); + return QualifiedName{"?" + string{typeName}}; + }; + /// Rows and Cols (this is specific to shader languages to identify vector and matrix types) - struct ArithmeticTypeInfo + struct ArithmeticTraits { - void ResolveSize() + void ResolveBaseSizeAndRank() { - m_size = Packing::PackedSizeof(m_underlyingScalar); - } + m_baseSize = Packing::PackedSizeof(m_underlyingScalar); - /// Get the size of a single base element - const uint32_t GetBaseSize() const - { - return m_size; + + } - /// Get the size of a the element with regard to dimensions as well + /// Get the size of the whole type considering dimensions const uint32_t GetTotalSize() const { - return m_size * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); + return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); } /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well. @@ -310,10 +322,11 @@ namespace AZ::ShaderCompiler AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : ""; } - uint32_t m_size = 0; // In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct - uint32_t m_rows = 0; // 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix - uint32_t m_cols = 0; // 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix - int m_underlyingScalar = -1; // index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf. + uint32_t m_baseSize = 0; //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct + uint32_t m_rows = 0; //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix + uint32_t m_cols = 0; //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix + int m_conversionRank = 0; //< Used in conversions and promotions + int m_underlyingScalar = -1; //< Index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf. }; //! TypeRefInfo holds resolved immutable information of a core type (the `matrix2x2` in `column_major matrix2x2 a[3];`) @@ -323,7 +336,7 @@ namespace AZ::ShaderCompiler struct TypeRefInfo { TypeRefInfo() = default; - TypeRefInfo(IdentifierUID typeId, const ArithmeticTypeInfo& fundamentalInfo, TypeClass typeClass) + TypeRefInfo(IdentifierUID typeId, const ArithmeticTraits& fundamentalInfo, TypeClass typeClass) : m_arithmeticInfo{fundamentalInfo}, m_typeClass{typeClass}, m_typeId{typeId} @@ -359,19 +372,19 @@ namespace AZ::ShaderCompiler return !operator==(lhs,rhs); } - IdentifierUID m_typeId; - TypeClass m_typeClass; - ArithmeticTypeInfo m_arithmeticInfo; + IdentifierUID m_typeId; + TypeClass m_typeClass; + ArithmeticTraits m_arithmeticInfo; }; //! Run a syntactic analysis of an arithmetic type name and extract info on its composition - inline ArithmeticTypeInfo CreateArithmeticTypeInfo(QualifiedName a_typeName) + inline ArithmeticTraits CreateArithmeticTraits(QualifiedName a_typeName) { assert(IsArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) )); // no need to call this function if you don't have a fundamental (non void) assert(!IsGenericArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) )); // ↑ fatal aspect. The input needs to be canonicalized earlier to minimize this function's complexity. - ArithmeticTypeInfo toReturn; + ArithmeticTraits toReturn; string typeName = UnMangle(a_typeName); size_t baseTypeLen = typeName.length(); @@ -422,7 +435,7 @@ namespace AZ::ShaderCompiler auto it = ::std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), baseType); assert(it != AZ::ShaderCompiler::Predefined::Scalar.end()); // baseType must exist in the Scalar bag by program invariant. toReturn.m_underlyingScalar = static_cast( std::distance(AZ::ShaderCompiler::Predefined::Scalar.begin(), it) ); - toReturn.ResolveSize(); + toReturn.ResolveBaseSizeAndRank(); return toReturn; } diff --git a/src/AzslcUtils.h b/src/AzslcUtils.h index 2b64bc6..e07f8c1 100644 --- a/src/AzslcUtils.h +++ b/src/AzslcUtils.h @@ -1047,6 +1047,13 @@ namespace AZ::ShaderCompiler return found ? found->argumentList() : nullptr; } + //! access the argument count at a function call site (from the AST) + inline size_t NumArgs(azslParser::FunctionCallExpressionContext* callCtx) + { + azslParser::ArgumentsContext* argsNode = callCtx->argumentList()->arguments(); + return argsNode ? argsNode->expression().size() : 0; + } + //! try to find a specific context type that this context would be a child of. template inline LookedUp* ExtractSpecificParent(antlr4::ParserRuleContext* ctx) diff --git a/src/GenericUtils.h b/src/GenericUtils.h index 378a214..5d379ee 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -379,14 +379,14 @@ namespace AZ // Is-One-Of will check if a variable is equal to any of the values listed on the other parameters. // Example: IsOneOf(variableKind, Function, Enumeration) is short for: variableKind == Function || variableKind == Enumeration. // 2 arguments count: recursion terminal overload. - template - bool IsOneOf(T value, T tocheck) + template + bool IsOneOf(T value, U tocheck) { return value == tocheck; } // Any argument count version - template - bool IsOneOf(T value, T khead, Args... tail) + template + bool IsOneOf(T value, U khead, Args... tail) { return value == khead || IsOneOf(value, tail...); } From d3d742f7ffdba6b52f8b5a4e0ee3e27ed05ddbbf Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 27 Mar 2023 22:54:50 +0900 Subject: [PATCH 04/13] First working order of binary expression type detection support. Still incomplete because of vector/matrix combination cases. Signed-off-by: Vivien Oddou --- src/AzslcIntermediateRepresentation.cpp | 2 +- src/AzslcKindInfo.h | 2 +- src/AzslcMain.cpp | 3 +- src/AzslcReflection.cpp | 2 +- src/AzslcSemanticOrchestrator.cpp | 22 +++++++++++-- src/AzslcTypes.h | 43 +++++++++++++++++++++++-- src/PadToAttributeMutator.cpp | 2 +- 7 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/AzslcIntermediateRepresentation.cpp b/src/AzslcIntermediateRepresentation.cpp index 89a6c7f..936bc1d 100644 --- a/src/AzslcIntermediateRepresentation.cpp +++ b/src/AzslcIntermediateRepresentation.cpp @@ -518,7 +518,7 @@ namespace AZ::ShaderCompiler if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } nextMemberStartingOffset = Packing::PackNextChunk(layoutPacking, size, startAt); diff --git a/src/AzslcKindInfo.h b/src/AzslcKindInfo.h index 35ae46a..ae09224 100644 --- a/src/AzslcKindInfo.h +++ b/src/AzslcKindInfo.h @@ -248,7 +248,7 @@ namespace AZ::ShaderCompiler //! Get the size of a single element, ignoring array dimensions const uint32_t GetSingleElementSize(Packing::Layout layout, bool defaultRowMajor) const { - auto baseSize = m_coreType.m_arithmeticInfo.GetBaseSize(); + auto baseSize = m_coreType.m_arithmeticInfo.m_baseSize; bool isRowMajor = (m_mtxMajor == Packing::MatrixMajor::RowMajor || (m_mtxMajor == Packing::MatrixMajor::Default && defaultRowMajor)); auto rows = m_coreType.m_arithmeticInfo.m_rows; diff --git a/src/AzslcMain.cpp b/src/AzslcMain.cpp index 79ec6f9..0159d7c 100644 --- a/src/AzslcMain.cpp +++ b/src/AzslcMain.cpp @@ -23,8 +23,7 @@ namespace StdFs = std::filesystem; // For large features or milestones. Minor version allows for breaking changes. Existing tests can change. #define AZSLC_MINOR "8" // last change: introduction of class inheritance // For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change. -#define AZSLC_REVISION "17" // last change: automatic option ranks - // "16" change: fixup runtime error with redundant function declarations +#define AZSLC_REVISION "18" // last change: automatic option ranks namespace AZ::ShaderCompiler { diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index c58e265..302e16c 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -589,7 +589,7 @@ namespace AZ::ShaderCompiler else if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = m_ir->GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } offset = Packing::PackNextChunk(layoutPacking, size, startAt); diff --git a/src/AzslcSemanticOrchestrator.cpp b/src/AzslcSemanticOrchestrator.cpp index 4f7c08b..70e0881 100644 --- a/src/AzslcSemanticOrchestrator.cpp +++ b/src/AzslcSemanticOrchestrator.cpp @@ -1168,7 +1168,8 @@ namespace AZ::ShaderCompiler As(ctx), As(ctx), As(ctx), - As(ctx)); + As(ctx), + As(ctx)); } catch (AllNull&) { @@ -1192,7 +1193,24 @@ namespace AZ::ShaderCompiler QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::BinaryExpressionContext* ctx) const { - return TypeofExpr(ctx->Expr); + using lex = azslLexer; + auto boolResultOperators = {lex::Less, lex::Greater, lex::LessEqual, lex::GreaterEqual, lex::NotEqual, lex::AndAnd, lex::OrOr}; + if (IsIn(ctx->binaryOperator()->start->getType(), boolResultOperators)) + { + return MangleScalarType("bool"); + } + QualifiedName lhs = TypeofExpr(ctx->Left); + QualifiedName rhs = TypeofExpr(ctx->Right); + TypeRefInfo typeInfoLhs = CreateTypeRefInfo(UnqualifiedNameView{lhs}); // We tolerate a cast here because GetTypeRefInfo was designed to lookup types, but TypeofExpr has already looked up the type. + TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs}); + if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty()) + { // Case that shouldn't work in AZSL yet (but may work in HLSL2021) + // -> UDT operator. need operator overloading. We assume type is type of left expression. + return lhs; + } + // final logic in case of arithmetic type class: integer/float promotion. + return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ? + lhs : rhs; } QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const diff --git a/src/AzslcTypes.h b/src/AzslcTypes.h index 826b091..450f311 100644 --- a/src/AzslcTypes.h +++ b/src/AzslcTypes.h @@ -281,9 +281,46 @@ namespace AZ::ShaderCompiler void ResolveBaseSizeAndRank() { m_baseSize = Packing::PackedSizeof(m_underlyingScalar); - - - + // establish the conversion rank: + auto getIndex = [](string_view s) -> int + { + auto const& Scalars = Predefined::Scalar; + return ::std::distance(Scalars.begin(), + ::std::find(Scalars.begin(), Scalars.end(), s)); + }; + // According to https://en.cppreference.com/w/cpp/language/usual_arithmetic_conversions + // - No two signed have the same rank (even if same siezeof) + // - rank of unsigned = rank of corresponding signed + // - "standard" is > "extended" of same sizeof + // - The rank of bool is the smallest + // That said, we will take inspiration from ASTContext::getIntegerRank of clang + // (which does not respect C++ visibly, since it takes bool's size into account, or has many equivalent ranks, in violation of rule 1) + static const unordered_map subranks = + { + {getIndex("bool"), 1}, + {getIndex("int16_t"), 2}, + {getIndex("uint16_t"), 3}, // unsigned wins in case of subrank draw, according to arithmetic conversion rules + {getIndex("int"), 4}, + {getIndex("uint"), 5}, + {getIndex("dword"), 6}, + {getIndex("int32_t"), 7}, + {getIndex("uint32_t"), 8}, + {getIndex("int64_t"), 9}, + {getIndex("uint64_t"), 10}, + {getIndex("half"), 11 << 5}, // floats win all conversions, even halfs + {getIndex("float"), 12 << 5}, + {getIndex("double"), 13 << 5}, + }; + // `basesize` getter, but 1 for bool: (physical size of extern bool is considered 32bits in HLSL) + auto getRankSizeof = [&](int scalarId) + { + assert(string_view{"bool"} == Predefined::Scalar[0]); // verify that 0 is the hard index of bool. + bool isBool = scalarId == 0; + return isBool ? 1 : m_baseSize; + }; + // The shift method is taken from clang, I suppose it's a multi-parameter order cramed into bits. + // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subank space. + m_conversionRank = (getRankSizeof(m_underlyingScalar) << 4) + subranks.at(m_underlyingScalar); } /// Get the size of the whole type considering dimensions diff --git a/src/PadToAttributeMutator.cpp b/src/PadToAttributeMutator.cpp index 1697803..4aad68a 100644 --- a/src/PadToAttributeMutator.cpp +++ b/src/PadToAttributeMutator.cpp @@ -354,7 +354,7 @@ namespace AZ::ShaderCompiler else if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = m_ir.GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } offset = Packing::PackNextChunk(layoutPacking, size, startAt); From 31b44baad27a5e2410bdf46b370d710e6f60072e Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Wed, 5 Apr 2023 01:06:00 +0900 Subject: [PATCH 05/13] WIP: satisfactory working order for vector and matrix arithmetic type deduction in binary operations Signed-off-by: Vivien Oddou --- src/AzslcSemanticOrchestrator.cpp | 31 +++++++++++++++-- src/AzslcTypes.h | 58 ++++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/src/AzslcSemanticOrchestrator.cpp b/src/AzslcSemanticOrchestrator.cpp index 70e0881..f4568e7 100644 --- a/src/AzslcSemanticOrchestrator.cpp +++ b/src/AzslcSemanticOrchestrator.cpp @@ -1205,12 +1205,37 @@ namespace AZ::ShaderCompiler TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs}); if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty()) { // Case that shouldn't work in AZSL yet (but may work in HLSL2021) - // -> UDT operator. need operator overloading. We assume type is type of left expression. + // -> UDT operator (would need support of operator overloading). + // We assume type is type of left expression. + // (what we really need to do is go get the return type of the overloaded operator) return lhs; } + // After this is, both sides are arithmetic types (scalar, vector, matrix). + // matrix op vector is a forbidden case, + // e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar) + // It will do component to component op and return more or less the same type. + // In case of dimension differences, it will truncate to smaller type + // e.g float2 + float3 results in float2 with .z lost in implicit cast + // same for float2x3 * float2x2 (results in float2x2) + // We assume that for * + - / % ^ | & << >> + bool lhsIsVecMat = typeInfoLhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix); + bool rhsIsVecMat = typeInfoRhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix); + if (lhsIsVecMat !=/*xor*/ rhsIsVecMat) + { + auto scalarOperand = lhsIsVecMat ? typeInfoRhs : typeInfoLhs; + auto vecmatOperand = lhsIsVecMat ? typeInfoLhs : typeInfoRhs; + // typeof(vecmat op scalar)->promoted(vecmat) + return vecmatOperand.m_arithmeticInfo.PromoteTruncateWith(scalarOperand.m_arithmeticInfo).GenTypeId(); + } + //else if (lhsIsVecMat && rhsIsVecMat) + { + // typeof(vecmat op vecmat)->promoted(truncated(vecmat)) + return typeInfoLhs.m_arithmeticInfo.PromoteTruncateWith(typeInfoRhs.m_arithmeticInfo).GenTypeId(); + } + // case left: both sides are scalar. // final logic in case of arithmetic type class: integer/float promotion. - return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ? - lhs : rhs; + //return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ? + // lhs : rhs; } QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const diff --git a/src/AzslcTypes.h b/src/AzslcTypes.h index 450f311..890f50f 100644 --- a/src/AzslcTypes.h +++ b/src/AzslcTypes.h @@ -275,7 +275,7 @@ namespace AZ::ShaderCompiler return QualifiedName{"?" + string{typeName}}; }; - /// Rows and Cols (this is specific to shader languages to identify vector and matrix types) + //! Holds arithmetic-type-class information in small pieces (row, cols, base, size, rank...) struct ArithmeticTraits { void ResolveBaseSizeAndRank() @@ -294,7 +294,6 @@ namespace AZ::ShaderCompiler // - "standard" is > "extended" of same sizeof // - The rank of bool is the smallest // That said, we will take inspiration from ASTContext::getIntegerRank of clang - // (which does not respect C++ visibly, since it takes bool's size into account, or has many equivalent ranks, in violation of rule 1) static const unordered_map subranks = { {getIndex("bool"), 1}, @@ -319,17 +318,17 @@ namespace AZ::ShaderCompiler return isBool ? 1 : m_baseSize; }; // The shift method is taken from clang, I suppose it's a multi-parameter order cramed into bits. - // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subank space. + // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subrank space. m_conversionRank = (getRankSizeof(m_underlyingScalar) << 4) + subranks.at(m_underlyingScalar); } - /// Get the size of the whole type considering dimensions + //! Get the size of the whole type considering dimensions const uint32_t GetTotalSize() const { return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); } - /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well. + //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well. const bool IsVector() const { // This treats special cases like 2x1, 3x1 and 4x1 as vectors @@ -337,7 +336,7 @@ namespace AZ::ShaderCompiler return (m_cols == 1 && m_rows > 1) || (m_cols > 1 && m_rows == 0); } - /// True if the type is a matrix type. If it's a matrix type it cannot be a vector as well. + //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well. const bool IsMatrix() const { // This fails special cases like 2x1, 3x1 and 4x1, @@ -346,19 +345,62 @@ namespace AZ::ShaderCompiler return m_rows > 0 && m_cols > 1; } - /// If initialized as a fundamental -> not empty. + //! If initialized as a fundamental -> not empty. const bool IsEmpty() const { return m_underlyingScalar == -1; } - // for pretty print + //! For pretty print string UnderlyingScalarToStr() const { return m_underlyingScalar >= 0 && m_underlyingScalar < AZ::ShaderCompiler::Predefined::Scalar.size() ? AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : ""; } + //! Create a canonicalized mangled name that should represent the identity of this arithmetic type. + QualifiedName GenTypeId() const + { + if (IsMatrix()) + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + ToString(m_rows) + "x" + ToString(m_cols)}; + } + else if (IsVector()) + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + (m_rows > 0 ? ToString(m_rows) : ToString(m_cols))}; + } + else + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr())}; + } + } + + //! Create a new arithmetic traits promoted by necessity (through a binary operation usually) + //! of compatibility with a second operand of arithmetic typeclass. + //! For example: type(half{} + int{})->half + //! type(float3x3{} * double{})->double3x3 + //! And with columns & rows truncated to the smallest operand, + //! as part of the implicit necessary cast for operation compatibility. + ArithmeticTraits PromoteTruncateWith(const ArithmeticTraits& secondOperand) const + { + ArithmeticTraits copy{*this}; + // The higher ranking underlying wins independently of global object size + if (secondOperand.m_conversionRank > m_conversionRank) + { + copy.m_underlyingScalar = secondOperand.m_underlyingScalar; + } + if (secondOperand.m_rows > 0 && m_rows > 0) + { + copy.m_rows = std::min(m_rows, secondOperand.m_rows); + } + if (secondOperand.m_cols > 0 && m_cols > 0) + { + copy.m_cols = std::min(m_cols, secondOperand.m_cols); + } + copy.ResolveBaseSizeAndRank(); + return copy; + } + uint32_t m_baseSize = 0; //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct uint32_t m_rows = 0; //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix uint32_t m_cols = 0; //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix From 878c89c9af4b402bd7149239835c294c441383b5 Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Wed, 5 Apr 2023 20:30:22 +0900 Subject: [PATCH 06/13] Cleaner implementation of binary operation type detection. Principle: diminish code paths at maximum by reducing the conditions checked to an absolute minimalist level. Signed-off-by: Vivien Oddou --- src/AzslcSemanticOrchestrator.cpp | 40 +++++------------ src/AzslcTypes.h | 71 +++++++++++++++++-------------- src/GenericUtils.h | 9 ++++ src/MetaUtils.h | 4 ++ 4 files changed, 64 insertions(+), 60 deletions(-) diff --git a/src/AzslcSemanticOrchestrator.cpp b/src/AzslcSemanticOrchestrator.cpp index f4568e7..67ccc13 100644 --- a/src/AzslcSemanticOrchestrator.cpp +++ b/src/AzslcSemanticOrchestrator.cpp @@ -1205,37 +1205,19 @@ namespace AZ::ShaderCompiler TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs}); if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty()) { // Case that shouldn't work in AZSL yet (but may work in HLSL2021) - // -> UDT operator (would need support of operator overloading). - // We assume type is type of left expression. - // (what we really need to do is go get the return type of the overloaded operator) + // -> UDT operator (would need support of operator overloading). + // We arbitrarily assume a result "type of left expression". + // (what we would really need to do is go get the return type of the overloaded operator) return lhs; - } - // After this is, both sides are arithmetic types (scalar, vector, matrix). - // matrix op vector is a forbidden case, - // e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar) - // It will do component to component op and return more or less the same type. - // In case of dimension differences, it will truncate to smaller type - // e.g float2 + float3 results in float2 with .z lost in implicit cast + } // After this `if`, both sides are arithmetic types (scalar, vector, matrix). + // "matrix op vector" (or commutated) is a forbidden case, + // e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar) + // It will do piecewise `op` and return more or less the same type as the operands. + // In case of dimension differences, it will truncate to the smaller type + // e.g float2 + float3 results in float2 with .z lost in implicit cast // same for float2x3 * float2x2 (results in float2x2) - // We assume that for * + - / % ^ | & << >> - bool lhsIsVecMat = typeInfoLhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix); - bool rhsIsVecMat = typeInfoRhs.m_typeClass.IsOneOf(TypeClass::Vector, TypeClass::Matrix); - if (lhsIsVecMat !=/*xor*/ rhsIsVecMat) - { - auto scalarOperand = lhsIsVecMat ? typeInfoRhs : typeInfoLhs; - auto vecmatOperand = lhsIsVecMat ? typeInfoLhs : typeInfoRhs; - // typeof(vecmat op scalar)->promoted(vecmat) - return vecmatOperand.m_arithmeticInfo.PromoteTruncateWith(scalarOperand.m_arithmeticInfo).GenTypeId(); - } - //else if (lhsIsVecMat && rhsIsVecMat) - { - // typeof(vecmat op vecmat)->promoted(truncated(vecmat)) - return typeInfoLhs.m_arithmeticInfo.PromoteTruncateWith(typeInfoRhs.m_arithmeticInfo).GenTypeId(); - } - // case left: both sides are scalar. - // final logic in case of arithmetic type class: integer/float promotion. - //return typeInfoLhs.m_arithmeticInfo.m_conversionRank > typeInfoRhs.m_arithmeticInfo.m_conversionRank ? - // lhs : rhs; + // We assume that for all non bool ops: * + - / % ^ | & << >> + return PromoteTruncateWith({typeInfoLhs.m_arithmeticInfo, typeInfoRhs.m_arithmeticInfo}).GenTypeId(); } QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const diff --git a/src/AzslcTypes.h b/src/AzslcTypes.h index 890f50f..a3503cf 100644 --- a/src/AzslcTypes.h +++ b/src/AzslcTypes.h @@ -323,13 +323,13 @@ namespace AZ::ShaderCompiler } //! Get the size of the whole type considering dimensions - const uint32_t GetTotalSize() const + uint32_t GetTotalSize() const { return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); } //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well. - const bool IsVector() const + bool IsVector() const { // This treats special cases like 2x1, 3x1 and 4x1 as vectors // The behavior is consistent with dxc packing rules @@ -337,7 +337,7 @@ namespace AZ::ShaderCompiler } //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well. - const bool IsMatrix() const + bool IsMatrix() const { // This fails special cases like 2x1, 3x1 and 4x1, // but allows cases like 1x2, 1x3 and 1x4. @@ -345,8 +345,13 @@ namespace AZ::ShaderCompiler return m_rows > 0 && m_cols > 1; } - //! If initialized as a fundamental -> not empty. - const bool IsEmpty() const + bool IsScalar() const + { + return m_rows <= 1 && m_cols <= 1; // float, float1, float1x1 are scalars. + } + + //! Non-created state + bool IsEmpty() const { return m_underlyingScalar == -1; } @@ -375,32 +380,6 @@ namespace AZ::ShaderCompiler } } - //! Create a new arithmetic traits promoted by necessity (through a binary operation usually) - //! of compatibility with a second operand of arithmetic typeclass. - //! For example: type(half{} + int{})->half - //! type(float3x3{} * double{})->double3x3 - //! And with columns & rows truncated to the smallest operand, - //! as part of the implicit necessary cast for operation compatibility. - ArithmeticTraits PromoteTruncateWith(const ArithmeticTraits& secondOperand) const - { - ArithmeticTraits copy{*this}; - // The higher ranking underlying wins independently of global object size - if (secondOperand.m_conversionRank > m_conversionRank) - { - copy.m_underlyingScalar = secondOperand.m_underlyingScalar; - } - if (secondOperand.m_rows > 0 && m_rows > 0) - { - copy.m_rows = std::min(m_rows, secondOperand.m_rows); - } - if (secondOperand.m_cols > 0 && m_cols > 0) - { - copy.m_cols = std::min(m_cols, secondOperand.m_cols); - } - copy.ResolveBaseSizeAndRank(); - return copy; - } - uint32_t m_baseSize = 0; //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct uint32_t m_rows = 0; //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix uint32_t m_cols = 0; //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix @@ -518,6 +497,36 @@ namespace AZ::ShaderCompiler return toReturn; } + //! Create a new arithmetic traits promoted by necessity (through a binary operation usually) + //! of compatibility with a second operand of arithmetic typeclass. + //! For example: type(half{} + int{})->half + //! type(float3x3{} * double{})->double3x3 + //! And with columns & rows truncated to the smallest operand, + //! as part of the implicit necessary cast for operation compatibility. + inline ArithmeticTraits PromoteTruncateWith(Pair operands) + { + auto [_1, _2] = operands; + // put the scalar last (will be useful later if there is only one scalar operand) + SwapIf(_1, _2, _1.IsScalar()); + // now, let's construct the result in _1 + + if (_2.m_conversionRank > _1.m_conversionRank) // the higher ranking underlying wins; independently of full object size + { + _1.m_underlyingScalar = _2.m_underlyingScalar; + } + // cases: scalar-scalar : no dim change + // vecmat-scalar : no dim change, since result is dim(vecmat) which is in _1 + // scalar-vecmat : impossible (sorted by swap above) + // vecmat-vecmat : min(_1,_2) + if (!_1.IsScalar() && !_2.IsScalar()) + { + _1.m_rows = std::min(_1.m_rows, _2.m_rows); + _1.m_cols = std::min(_1.m_cols, _2.m_cols); + } + _1.ResolveBaseSizeAndRank(); + return _1; + } + MAKE_REFLECTABLE_ENUM(RootParamType, SRV, // t UAV, // u diff --git a/src/GenericUtils.h b/src/GenericUtils.h index 5d379ee..9589563 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -624,6 +624,15 @@ namespace AZ RemoveDuplicatesKeepOrder(lhs); } + //! Conditional swap algorithm + template + void SwapIf(T&& a, T&& b, bool condition) + { + if (condition) + { + std::swap(std::forward(a), std::forward(b)); + } + } } #ifndef NDEBUG diff --git a/src/MetaUtils.h b/src/MetaUtils.h index f946f65..3a7a69b 100644 --- a/src/MetaUtils.h +++ b/src/MetaUtils.h @@ -366,6 +366,10 @@ namespace AZ template class Op, class... Args> using DetectedOr_t = typename DetectedOr::type; + + //! define a Pair typealias that has two same T, without the need to repeat yourself + template + using Pair = std::pair; } #ifndef NDEBUG From 661d01ffb6d9e4fcb9b77ba6c0008f3840d45fdf Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 10 Apr 2023 01:57:46 +0900 Subject: [PATCH 07/13] Refactor the option rank analysis to make it easier to tread with named function for each action. + prepare a fix for the type analysis of the "method call solver". This fail is due to the fact that we are past the semantic analysis, so we don't have proper scope tracking. The lookup cannot work if we don't provide a starting scope. That starting scope is reconstructed artificially using the scopes/token map collection. Unfortunately, the fast lookup is now using an intermediate map that filters by function only. If we don't include the unnamed blocks, Lookup will systematically fail for any object within curly brace or an if block, for block etc. To solve that problem, I prepared a "non disjointed" interval query system. It's an unfortunate change from Log(n) by query to O(N) by query though. Also we have a memory fest since these are node containers. We might want to consider Howard Hinnant stack allocator soon after. Signed-off-by: Vivien Oddou --- src/AzslcReflection.cpp | 217 +++++++++++++++++------------ src/AzslcReflection.h | 3 + src/GenericUtils.h | 202 ++++++++++++++++++++------- tests/Semantic/typeof-keyword.azsl | 116 ++++++++++++++- 4 files changed, 396 insertions(+), 142 deletions(-) diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index 302e16c..c95a0e0 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -861,10 +861,10 @@ namespace AZ::ShaderCompiler assert(uid == seenat.m_referredDefinition); // careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes) // ok for now because AZSL/HLSL don't have lambdas - auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value) - { - return value.first.properlyContains({key, key}); - }); + auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value) + { + return value.first.properlyContains({key, key}); + }); if (intervalIter != scopes.cend()) { const IdentifierUID& encloser = intervalIter->second.second; @@ -1001,6 +1001,47 @@ namespace AZ::ShaderCompiler m_out << srgRoot; } + // Helper routine for option rank analysis + static int GuesstimateIntrinsicFunctionCost(string_view funcName) + { + if (IsOneOf(funcName, "CallShader", "TraceRay")) + { // non measurable but assumed high + return 100; + } + else if (IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append")) + { // memory access, locked or not, will have high latency + return 10; + } + else + { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1. + return 1; + } + } + + // Helper routine for option rank analysis. When picking AN overload is more useful than forfeiting. + // The function GetConcreteFunctionThatMatchesArgumentList forfeits when the overloadset contains + // strictly more than 1 concrete function with the queried arity. In our case, we prefer to just pick any. + static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir, + azslParser::FunctionCallExpressionContext* callNode, + KindInfo& overload) + { + IdentifierUID concrete; + size_t numArgs = NumArgs(callNode); + overload.GetSubAs()->AnyOf( + [&](IdentifierUID const& uid) + { + auto* concreteFcInfo = ir->GetSymbolSubAs(uid.GetName()); + size_t numParams = concreteFcInfo->GetParameters(true).size(); + if (numParams == numArgs) + { + concrete = uid; // we write the result through reference capture (not clean but convenient) + return true; + } + return false; + }); + return concrete; + } + void CodeReflection::AnalyzeOptionRanks() const { // make sure we have the function scope lookup cache ready @@ -1030,7 +1071,7 @@ namespace AZ::ShaderCompiler ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId); // go up tree to meet a block node that has visitable depth: // can be any of if/for/while/switch - // 4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binary op->braces->cmp expr->cond->for + // 4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for if (auto* whileNode = DeepParentAs(node->parent, 3)) { node = whileNode->embeddedStatement(); @@ -1058,89 +1099,8 @@ namespace AZ::ShaderCompiler { if (auto* callNode = As(c)) { - // get function score in FunctionInfo if cached, compute it and store if not. - // to access the function symbol info we need the current scope, the function call name and perform a lookup. - auto intervalIter = FindInterval(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(), - [](ssize_t key, auto& value) - { - return value.first.properlyContains({key, key}); - }); - if (intervalIter != m_functionIntervals.cend()) - { - const IdentifierUID& encloser = intervalIter->second.second; - // lookup function at AST node `callNode` from scope `encloser` - if (auto* idExpr = As(callNode->Expr)) - { - UnqualifiedName funcName = ExtractNameFromIdExpression(idExpr->idExpression()); - IdAndKind* overload = m_ir->m_symbols.LookupSymbol(encloser.GetName(), funcName); - if (!overload) // in case of function not found, we assume it's an intrinsic. - { - if (IsOneOf(funcName, "CallShader", "TraceRay")) - { // non measurable but assumed high - scoreAccumulator += 50; - } - else if (IsOneOf(funcName, "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append")) - { // hardware locked memory ops, high weight - scoreAccumulator += 10; - } - else if (IsOneOf(funcName, "Sample", "Load")) - { // memory access is weighted in between - scoreAccumulator += 5; - } - else - { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1. - scoreAccumulator += 1; - } - } - else - { - azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode); - IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args); - IdentifierUID concrete; - if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet) - { // in case of strict selection failure, run a fuzzy select - size_t numArgs = NumArgs(callNode); - overload->second.GetSubAs()->AnyOf( - [&](IdentifierUID const& uid) - { - auto* concreteFcInfo = m_ir->GetSymbolSubAs(uid.GetName()); - size_t numParams = concreteFcInfo->GetParameters(true).size(); - if (numParams == numArgs) - { - concrete = uid; - return true; - } - return false; - } - ); - // if still not enough to get a fix, it might be an ill-formed input. prefer to forfeit - } - else - { - concrete = symbolMeantUnderCallNode->first; - } - auto* funcInfo = m_ir->GetSymbolSubAs(concrete.GetName()); - if (funcInfo->m_costScore == -1) - { - funcInfo->m_costScore = 0; - using AstFDef = azslParser::HlslFunctionDefinitionContext; - AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), - funcInfo->m_costScore); // recurse and cache if not already done - } - scoreAccumulator += funcInfo->m_costScore; - } - } - // other cases forfeited for now, but that would at least be braces (f)() or MAE x.m() - } - else // no interval found - { - // function calls outside of function bodies can appear in an initializer: - // int g_a = MakeA(); // global init - // class C { int m_a = CompA(); // constructor init (invalid AZSL/HLSL) - // class D { void Method(int a_a = DefaultA()); // default parameter value - // in any case, extracting the scope is impossible with this system. - // we forfeit evaluation of a score - } + // branch into an overload specialized for function lookup: + AnalyzeImpact(callNode, scoreAccumulator); } else if (auto* node = As(c)) { @@ -1149,8 +1109,87 @@ namespace AZ::ShaderCompiler if (auto* leaf = As(c)) { // determine cost by number of full expressions separated by semicolon - scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi; + scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi; // bool as 0 or 1 trick + } + } + } + + void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const + { + // to access the function symbol info we need the current scope, the function call name and perform a lookup. + + // figure out the scope at this token. + // theoretically should be something in the like of the body of another function, + // or an anonymous block within another function. + auto intervalIter = FindIntervalInDisjointSet(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(), + [](ssize_t key, auto& value) + { + return value.first.properlyContains({key, key}); + }); + if (intervalIter != m_functionIntervals.cend()) + { + IdentifierUID encloser = intervalIter->second.second; + + // Because we are past the end of the semantic analysis, + // the scope tracker is registering the last seen scope (surely "/"). + // This is a stateful side-effect system unfortunately, and since we'll call + // some feature of the semantic orchestrator (like TypeofExpr) we need to hack + // the scope tracker: + m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName(); + m_ir->m_sema.m_scope->UpdateCurScopeUID(); + + QualifiedName startupLookupScope = encloser.GetName(); + UnqualifiedName funcName; + if (auto* idExpr = As(callNode->Expr)) + { + funcName = ExtractNameFromIdExpression(idExpr->idExpression()); + } + else if (auto* maeExpr = As(callNode->Expr)) + { + startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr); + } + IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName); + if (!overload) // in case of function not found, we assume it's an intrinsic. + { + scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName); } + else + { + azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode); + IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args); + IdentifierUID concrete; + if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet) + { // in case of strict selection failure, run a fuzzy select + concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second); + // if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit + } + else + { + concrete = symbolMeantUnderCallNode->first; + } + + if (auto* funcInfo = m_ir->GetSymbolSubAs(concrete.GetName())) + { + if (funcInfo->m_costScore == -1) // cost not yet discovered for this function + { + funcInfo->m_costScore = 0; + using AstFDef = azslParser::HlslFunctionDefinitionContext; + AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), + funcInfo->m_costScore); // recurse and cache + } + scoreAccumulator += funcInfo->m_costScore; + } + } + // other cases forfeited for now, but that would at least include things like eg braces (f)() + } + else // no interval found + { + // function calls outside of function bodies can appear in an initializer: + // int g_a = MakeA(); // global init + // class C { int m_a = CompA(); // constructor init (invalid AZSL/HLSL) + // class D { void Method(int a_a = DefaultA()); // default parameter value + // in any case, extracting the scope is impossible with this system. + // we forfeit evaluation of a score } } diff --git a/src/AzslcReflection.h b/src/AzslcReflection.h index f8b8e38..c9c3ddf 100644 --- a/src/AzslcReflection.h +++ b/src/AzslcReflection.h @@ -85,6 +85,9 @@ namespace AZ::ShaderCompiler // Recursive internal detail version void AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const; + // Function-call specific + void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const; + //! Useful for static analysis on dependencies or option ranks void GenerateScopeStartToFunctionIntervalsReverseMap() const; mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call diff --git a/src/GenericUtils.h b/src/GenericUtils.h index 9589563..a63c656 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -24,9 +24,9 @@ namespace AZ { using runtime_error::runtime_error; }; - // Type-heterogeneity-preserving multi pointer object single visitor. - // Returns whatever the passed functor would. - // Throws if all passed objects are null. + //! Type-heterogeneity-preserving multi pointer object single visitor. + //! Returns whatever the passed functor would. + //! Throws if all passed objects are null. template invoke_result_t VisitFirstNonNull(Lambda functor, T* object) noexcept(false) { @@ -50,9 +50,9 @@ namespace AZ } } - // Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics. - // what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]` - // http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/ + //! Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics. + //! what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]` + //! http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/ inline constexpr string_view Slice(const string_view& in, int64_t st, int64_t end) { auto inSSize = (int64_t)in.size(); @@ -107,8 +107,8 @@ namespace AZ //https://developercommunity.visualstudio.com/content/problem/275141/c2131-expression-did-not-evaluate-to-a-constant-fo.html } - // ability to create size_t literals - // waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo + //! ability to create size_t literals + //! waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo inline constexpr std::size_t operator ""_sz(unsigned long long n) { return n; @@ -145,7 +145,7 @@ namespace AZ return fileName.substr(0, lastIndex); } - // debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library) + //! debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library) template To polymorphic_downcast(From instance) { @@ -157,7 +157,7 @@ namespace AZ return static_cast(instance); } - /// surround a string with a prefix and a suffix + //! surround a string with a prefix and a suffix inline string Decorate(string_view prefix, string_view body, string_view suffix) { std::stringstream ss; @@ -167,13 +167,13 @@ namespace AZ return ss.str(); } - /// 2 arguments version in case both sides are the same + //! 2 arguments version in case both sides are the same inline string Decorate(string_view prefixAndSuffix, string_view body) { return Decorate(prefixAndSuffix, body, prefixAndSuffix); } - /// reverse the effect of a symmetrical decoration + //! reverse the effect of a symmetrical decoration inline string_view Undecorate(string_view decoration, string_view body) { auto indexStart = StartsWith(body, decoration) ? decoration.length() : 0; @@ -181,7 +181,7 @@ namespace AZ return Slice(body, indexStart, indexEnd); } - // Erase-Remove algorithm which removes all whitespaces from a string. + //! Erase-Remove algorithm which removes all whitespaces from a string. inline string RemoveWhitespaces(string haystack) { haystack.erase(std::remove_if(haystack.begin(), haystack.end(), [](unsigned char c) {return std::isspace(c); }), haystack.end()); @@ -193,14 +193,14 @@ namespace AZ return std::all_of(s.begin(), s.end(), [&](char c) { return std::isspace(c); }); } - /// tells whether a position in a string is surrounded by round braces - /// e.g. true for arguments {"a(b)", 2} - /// e.g. true for arguments {"a()", 1} by convention - /// e.g. false for arguments {"a()", 2} by convention - /// e.g. false for arguments {"a(b)", 0} - /// e.g. false for arguments {"a(b)c", 4} - /// e.g. false for arguments {"a(b)c(d)", 4} - /// e.g. true for arguments {"a((b)c(d))", 5} + //! tells whether a position in a string is surrounded by round braces + //! e.g. true for arguments {"a(b)", 2} + //! e.g. true for arguments {"a()", 1} by convention + //! e.g. false for arguments {"a()", 2} by convention + //! e.g. false for arguments {"a(b)", 0} + //! e.g. false for arguments {"a(b)c", 4} + //! e.g. false for arguments {"a(b)c(d)", 4} + //! e.g. true for arguments {"a((b)c(d))", 5} inline bool WithinMatchedParentheses(string_view haystack, size_t charPosition) { const auto hayLen = haystack.length(); @@ -215,8 +215,8 @@ namespace AZ return nesting > 0; } - /// replace all occurrences of substring `sub` with substring `to` within haystack. - /// e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa" + //! replace all occurrences of substring `sub` with substring `to` within haystack. + //! e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa" inline string Replace(string haystack, string_view sub, string_view to) { decltype(sub.length()) pos = 0; @@ -230,7 +230,7 @@ namespace AZ return haystack; } - // this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes. + //! this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes. constexpr inline string_view Trim(string_view haystack, string_view toTrim = " \t\n") { const auto strEnd = haystack.find_last_not_of(toTrim); @@ -268,34 +268,56 @@ namespace AZ return std::find_if(begin, end, p) != end; } - /// argument in rangeV3-style version: + //! argument in rangeV3-style version: template< typename Container > string Join(const Container& c, string_view separator = "") { return Join(c.begin(), c.end(), separator); } - /// argument in rangeV3-style version: + //! argument in rangeV3-style version: template< typename Container, typename Predicate > bool Contains(const Container& c, Predicate p) { return Contains(c.begin(), c.end(), p); } - /// closest possible form of python's `in` keyword + //! closest possible form of python's `in` keyword template< typename Element, typename Container > bool IsIn(const Element& element, const Container& container) { return std::find(container.begin(), container.end(), element) != container.end(); } - /// generate a new container with copy-and-mutated elements + //! generate a new container with copy-and-mutated elements template< typename Container, typename ContainerOut, typename Functor > void TransformCopy(const Container& in, ContainerOut& out, Functor mutator) { std::transform(in.begin(), in.end(), std::back_inserter(out), mutator); } + enum class CopyIfPolicy { ForAll, InterruptAtFirstFalse }; + + //! inserts elements into the output iterator if they pass a predicate + template< typename InputIterator, typename Predicate, typename OutputIterator > + void CopyIf(InputIterator begin, InputIterator end, + Predicate pred, + OutputIterator out, + CopyIfPolicy policy) + { + for (auto it = begin; it != end; ++it) + { + if (pred(*it)) + { + *out = *it; + } + else if (policy == CopyIfPolicy::InterruptAtFirstFalse) + { + break; + } + } + } + inline string Unescape(string_view escapedText) { std::stringstream out; @@ -531,7 +553,7 @@ namespace AZ #endif } - /// Log(N) find immediate lower element query in map-keys + //! Log(N) query to find the first immediately lower or equal element in a map's keys template< typename T, typename U > auto Infimum(map const& ctr, T query) { @@ -539,35 +561,106 @@ namespace AZ return it == ctr.begin() ? ctr.cend() : --it; } - /// Log(N) disjointed segments belong query - /// You can represent segments as you wish, as long as: - /// you provide the predicate to determine belonging. - /// map-keys are segment start points. - /// segments don't overlap. - /// returns: iterator to found interval key, or cend() + //! Log(N) disjointed segments belong query + //! You can represent segments as you wish, as long as: + //! you provide the predicate to determine belonging. + //! map-keys are segment start points. + //! segments don't overlap. + //! returns: iterator to found interval key, or cend() template< typename T, typename U, typename IntervalCheckPredicate > - auto FindInterval(const map& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate) + auto FindIntervalInDisjointSet(const map& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate) { auto inf = Infimum(ctr, query); - return inf == ctr.end() ? ctr.cend() : (isInIntervalPredicate(query, inf->second) ? inf : ctr.cend()); + bool isInInterval = inf != ctr.cend() && isInIntervalPredicate(query, inf->second); + return isInInterval ? inf : ctr.cend(); } - /// Log(N) disjointed segments belong query - /// segments are represented by their start points in the key, and last point in values. segments can't overlap. - /// returns: iterator to found interval key, or cend() + //! Log(N) disjointed segments belong query + //! segments are represented by their start points in the key, and last point in values. segments can't overlap. + //! returns: iterator to found interval key, or cend() template< typename T, typename U> - auto FindInterval(const map& ctr, const T& query) + auto FindIntervalInDisjointSet(const map& ctr, const T& query) { - return FindInterval(ctr, query, [](T q, U last) {return q <= last; }); + return FindIntervalInDisjointSet(ctr, query, [](T q, U last) {return q <= last; }); } + template< typename T > + struct Interval + { + bool IsEmpty() const { return b < a; } + bool operator== (Interval const& rhs) const { return a == rhs.a && b == rhs.b; } + bool operator< (Interval const& rhs) const { return a < rhs.a || (a == rhs.a && b < rhs.b); } + T a = (T)0; + T b = (T)-1; + }; + + //! In case of potential overlaps (not disjointed), this structure can support "is in" queries + template< typename T > + struct IntervalCollection + { + using Interval = Interval; + + //! Construction from an iterable collection of Interval typed elements + template< typename Iterator > + IntervalCollection(Iterator&& begin, Iterator&& end) + : m_obfirsts(begin, end), m_oblasts(begin, end) + { + std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2) + { + return i1.a < i2.a; + }); + std::sort(m_oblasts.begin(), m_oblasts.end(), [](auto i1, auto i2) + { + return i1.b < i2.b; + }); + } + + //! Retrieve the subset of intervals activated by a point (query) + set GetIntervalsSurrounding(T query) + { + // construct the set of intervals starting before: + set startBefore; + CopyIf(m_obfirsts.begin(), m_obfirsts.end(), + [=](auto interv) { return interv.a <= query; }, + std::inserter(startBefore, startBefore.end()), + CopyIfPolicy::InterruptAtFirstFalse); // because the obfirsts vector is sorted + + // construct the set of intervals ending after: + set endAfter; + CopyIf(m_oblasts.rbegin(), m_oblasts.rend(), // reverse iteration + [=](auto interv) { return interv.b >= query; }, + std::inserter(endAfter, endAfter.end()), + CopyIfPolicy::InterruptAtFirstFalse); + + set result; + std::set_intersection(startBefore.begin(), startBefore.end(), + endAfter.begin(), endAfter.end(), + std::inserter(result, result.end())); + return result; + } + + //! Get the interval surrounding query that has the closest start point to query. + //! In case of an interval collection representing a tree, that is, + //! each overlapping interval is fully contained in the bigger one, + //! the closest start is guaranteed to be the most "leaf" interval. + //! This is useful for scopes. + Interval GetClosestIntervalSurrounding(T query) + { + auto bag = GetIntervalsSurrounding(query); + return bag.empty() ? Interval{-1, -2} : *bag.rbegin(); + } + + vector m_obfirsts; // ordered by "firsts" + vector m_oblasts; // ordered by "lasts" + }; + template< typename Deduced > decltype(auto) CastToRValueReference(Deduced&& value) { return static_cast&&>(value); } - /// add a missing operator for convenience and shortness of code + //! add a missing operator for convenience and shortness of code inline bool operator == (string_view lhs, char rhs) { return lhs.length() == 1 && lhs[0] == rhs; @@ -713,10 +806,10 @@ namespace AZ::Tests assert(yellow == intervals.cend()); auto larger_than_all = Infimum(intervals, 15); assert(larger_than_all->first == 8); - assert(FindInterval(intervals, 4) != intervals.cend()); - assert(FindInterval(intervals, 6) == intervals.cend()); - assert(FindInterval(intervals, 8) != intervals.cend()); - assert(FindInterval(intervals, 1) == intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 4) != intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 6) == intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 8) != intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 1) == intervals.cend()); auto high = Infimum(intervals, 20); assert(high->first == 8); @@ -734,6 +827,21 @@ namespace AZ::Tests assert(IsIn("hibou", std::initializer_list{ "chouette", "hibou", "jay" })); assert(!IsIn("hibou", std::initializer_list{ "chouette", "jay" })); + + Interval intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}}; + IntervalCollection ic{std::begin(intvs), std::end(intvs)}; + assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty()); + assert((ic.GetClosestIntervalSurrounding(0) == Interval{0,10})); + assert((ic.GetClosestIntervalSurrounding(1) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(3) == Interval{3,3})); + assert((ic.GetClosestIntervalSurrounding(4) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(6) == Interval{0,10})); + assert((ic.GetClosestIntervalSurrounding(5) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(7) == Interval{7,9})); + assert((ic.GetClosestIntervalSurrounding(9) == Interval{7,9})); + assert(ic.GetClosestIntervalSurrounding(11).IsEmpty()); + assert((ic.GetClosestIntervalSurrounding(13) == Interval{12,15})); + assert(ic.GetClosestIntervalSurrounding(16).IsEmpty()); } } #endif diff --git a/tests/Semantic/typeof-keyword.azsl b/tests/Semantic/typeof-keyword.azsl index 47f34ac..bfdbd74 100644 --- a/tests/Semantic/typeof-keyword.azsl +++ b/tests/Semantic/typeof-keyword.azsl @@ -19,7 +19,7 @@ top gettop(); class A { - int a; + int a; }; class B : A @@ -83,16 +83,14 @@ void h() // NumericConstructorExpression float2(0,0) // LiteralExpression 42 // CommaExpression X, Y - // not supported: // PostfixUnaryExpression i++ // PrefixUnaryExpression ++i // BinaryExpression i + j - // e.g. typeof(1 + 3) = - // mathematics + // mathematics __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(1 + 3), __azslc_prtsym_least_qualified); - __azslc_print_message(" == ''\n"); + __azslc_print_message(" == 'int'\n"); // literals __azslc_print_message("@check predicate "); @@ -186,7 +184,7 @@ void h() __azslc_print_symbol(typeof(top::inner), __azslc_prtsym_fully_qualified); __azslc_print_message(" == '/top/inner'\n"); - // class inheritance: parent member access using scope resolution operator + // class inheritance: parent member access using scope resolution operator __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(B::a), __azslc_prtsym_least_qualified); __azslc_print_message(" == 'int'\n"); @@ -445,4 +443,110 @@ void h() __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(INTVAR, INTVAR, DOUBLEVAR), __azslc_prtsym_least_qualified); __azslc_print_message(" == 'double'\n"); + + // binary arithmetic expression + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'int'\n"); + + // with float promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1.f), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float'\n"); + + // with half promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1.h), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'half'\n"); + + // half and float to float promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.f + 1.h), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float'\n"); + + // int16_t to double promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.l + int16_t(1)), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double'\n"); + + // bool binary + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.l && int16_t(1)), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'bool'\n"); + + // lookedup + double d; + int64_t i64; + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d || i64), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'bool'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d - i64), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double'\n"); + + // vector scalar + float4 f4; + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(2.f * f4), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(f4 * 2.f), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d * f4), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(f4 * d), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double4'\n"); + + // matrix scalar + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float() * float3x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float3x2'\n"); + + // matrix scalar with base type promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(half3x2() * double()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x2'\n"); + + // truncations + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float3() * float2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float2'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float4x4() * float2x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float2x2'\n"); + + // truncate & upcast + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float4x4() * double2x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double2x2'\n"); + + // through alias + typealias d34m = double3x4; + typealias real = half; + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof((d34m)0 * (real)0), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x4'\n"); + // note: the parser takes d32m() as a function call, this is the "most verxing parse" problem + // float() is understood by the parser as a NumericConstructorExpression because it + // has a list of the tokens representing all fundamental types. + // but, with user defined identifiers, it can't branch into the "intended" context, + // because ALL(*) Antlr4 parsers recognizes context free grammar only. + // We could adopt a universal construction syntax Obj{} a la C++ for AZSL but it's not + // compliant with the philosophy of not deviating from HLSL, which on its side doesn't really + // accept constructor-type constructs. DXC just does it better here because clang parser is Turing complete. + + d34m d34var; + real rVar; + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(rVar - d34var ), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x4'\n"); } From 7042ebf23ae3b4122676fc7199ff224922b1ca5c Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 10 Apr 2023 19:18:16 +0900 Subject: [PATCH 08/13] First working order of code that can track method calls cost. Now able to locate the symbols because the starting scope is correctly reconstructed, using the new IntervalCollection class which is able to support query for non-disjointed intervals, which is a more difficult case than what we had up to now. We still keep the previous map to functions because it's faster to query. Signed-off-by: Vivien Oddou --- src/AzslcReflection.cpp | 23 ++++++++++----------- src/AzslcReflection.h | 7 +++++-- src/GenericUtils.h | 44 +++++++++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index c95a0e0..f49f746 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -914,7 +914,7 @@ namespace AZ::ShaderCompiler // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes. // (truth: we need a set of disjoint intervals as an invariant for the following algorithm) - GenerateScopeStartToFunctionIntervalsReverseMap(); + GenerateTokenScopeIntervalToUidReverseMap(); Json::Value srgRoot(Json::objectValue); // Order the reflection by SRG for convenience @@ -1044,8 +1044,8 @@ namespace AZ::ShaderCompiler void CodeReflection::AnalyzeOptionRanks() const { - // make sure we have the function scope lookup cache ready - GenerateScopeStartToFunctionIntervalsReverseMap(); + // make sure we have the scope lookup cache ready + GenerateTokenScopeIntervalToUidReverseMap(); // loop over variables for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3()) { @@ -1121,14 +1121,10 @@ namespace AZ::ShaderCompiler // figure out the scope at this token. // theoretically should be something in the like of the body of another function, // or an anonymous block within another function. - auto intervalIter = FindIntervalInDisjointSet(m_functionIntervals, (ssize_t)callNode->start->getTokenIndex(), - [](ssize_t key, auto& value) - { - return value.first.properlyContains({key, key}); - }); - if (intervalIter != m_functionIntervals.cend()) + auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex()); + if (!interval.IsEmpty()) { - IdentifierUID encloser = intervalIter->second.second; + IdentifierUID encloser = m_intervalToUid[interval]; // Because we are past the end of the semantic analysis, // the scope tracker is registering the last seen scope (surely "/"). @@ -1147,6 +1143,7 @@ namespace AZ::ShaderCompiler else if (auto* maeExpr = As(callNode->Expr)) { startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr); + funcName = ExtractNameFromIdExpression(maeExpr->Member); } IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName); if (!overload) // in case of function not found, we assume it's an intrinsic. @@ -1193,7 +1190,7 @@ namespace AZ::ShaderCompiler } } - void CodeReflection::GenerateScopeStartToFunctionIntervalsReverseMap() const + void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const { if (m_functionIntervals.empty()) { @@ -1204,7 +1201,11 @@ namespace AZ::ShaderCompiler // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound) m_functionIntervals[interval.a] = std::make_pair(interval, uid); } + auto i = Interval{interval.a, interval.b}; + m_intervals.Add(i); + m_intervalToUid[i] = uid; } + m_intervals.Seal(); } } } diff --git a/src/AzslcReflection.h b/src/AzslcReflection.h index c9c3ddf..17b0255 100644 --- a/src/AzslcReflection.h +++ b/src/AzslcReflection.h @@ -12,6 +12,7 @@ namespace AZ::ShaderCompiler { using MapOfBeginToSpanAndUid = map >; + using MapOfIntervalToUid = map, IdentifierUID>; struct CodeReflection : Backend { @@ -89,8 +90,10 @@ namespace AZ::ShaderCompiler void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const; //! Useful for static analysis on dependencies or option ranks - void GenerateScopeStartToFunctionIntervalsReverseMap() const; - mutable MapOfBeginToSpanAndUid m_functionIntervals; //< cache for the result of above function call + void GenerateTokenScopeIntervalToUidReverseMap() const; + mutable MapOfBeginToSpanAndUid m_functionIntervals; //< only functions because they are guaranteed to be disjointed (largely simplifies queries) + mutable IntervalCollection m_intervals; //< augmented version with anonymous blocks (slower query) + mutable MapOfIntervalToUid m_intervalToUid; //< side by side data since we don't want to weight the interval struct with a payload std::ostream& m_out; }; diff --git a/src/GenericUtils.h b/src/GenericUtils.h index a63c656..d2d45fc 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -600,11 +600,15 @@ namespace AZ { using Interval = Interval; - //! Construction from an iterable collection of Interval typed elements - template< typename Iterator > - IntervalCollection(Iterator&& begin, Iterator&& end) - : m_obfirsts(begin, end), m_oblasts(begin, end) + void Add(Interval i) { + m_obfirsts.emplace_back(i); + } + + //! doesn't respect RAII but for the sake of performance and convenience this is easier this way + void Seal() + { + m_oblasts = m_obfirsts; std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2) { return i1.a < i2.a; @@ -613,27 +617,30 @@ namespace AZ { return i1.b < i2.b; }); + m_sealed = true; } //! Retrieve the subset of intervals activated by a point (query) - set GetIntervalsSurrounding(T query) + set GetIntervalsSurrounding(T query) const { - // construct the set of intervals starting before: - set startBefore; - CopyIf(m_obfirsts.begin(), m_obfirsts.end(), - [=](auto interv) { return interv.a <= query; }, - std::inserter(startBefore, startBefore.end()), - CopyIfPolicy::InterruptAtFirstFalse); // because the obfirsts vector is sorted + assert(m_sealed); + // find the "set" of intervals starting before: + auto firstsSubEnd = std::lower_bound(m_obfirsts.begin(), m_obfirsts.end(), + query, + [=](auto interv, T q) { return interv.a <= q; }); - // construct the set of intervals ending after: - set endAfter; + // find the "set" of intervals ending after: + static vector endAfter; + endAfter.clear(); CopyIf(m_oblasts.rbegin(), m_oblasts.rend(), // reverse iteration [=](auto interv) { return interv.b >= query; }, - std::inserter(endAfter, endAfter.end()), + std::back_inserter(endAfter), CopyIfPolicy::InterruptAtFirstFalse); + // for set_intersection to work, the less<> predicate has to work for both ranges + std::sort(endAfter.begin(), endAfter.end()); set result; - std::set_intersection(startBefore.begin(), startBefore.end(), + std::set_intersection(m_obfirsts.begin(), firstsSubEnd, endAfter.begin(), endAfter.end(), std::inserter(result, result.end())); return result; @@ -644,7 +651,7 @@ namespace AZ //! each overlapping interval is fully contained in the bigger one, //! the closest start is guaranteed to be the most "leaf" interval. //! This is useful for scopes. - Interval GetClosestIntervalSurrounding(T query) + Interval GetClosestIntervalSurrounding(T query) const { auto bag = GetIntervalsSurrounding(query); return bag.empty() ? Interval{-1, -2} : *bag.rbegin(); @@ -652,6 +659,7 @@ namespace AZ vector m_obfirsts; // ordered by "firsts" vector m_oblasts; // ordered by "lasts" + bool m_sealed = false; }; template< typename Deduced > @@ -829,7 +837,9 @@ namespace AZ::Tests assert(!IsIn("hibou", std::initializer_list{ "chouette", "jay" })); Interval intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}}; - IntervalCollection ic{std::begin(intvs), std::end(intvs)}; + IntervalCollection ic; + std::for_each(std::begin(intvs), std::end(intvs), [&](auto i) {ic.Add(i); }); + ic.Seal(); assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty()); assert((ic.GetClosestIntervalSurrounding(0) == Interval{0,10})); assert((ic.GetClosestIntervalSurrounding(1) == Interval{1,5})); From d87a18e4a2f2c7e33057faf47f42303ca63e41e4 Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 10 Apr 2023 23:00:56 +0900 Subject: [PATCH 09/13] fix a bug that creates an infinite loop if we kick a lookup from an intrinsic type. For instance "?Texture2D" as scope, and "Load" as method will end up in a "LevelUp" that isn't "/" but is "". The empty path was never never meant to be a possible output of LevelUp function, but it does happen in case of levelup from non rooted symbols. Signed-off-by: Vivien Oddou --- src/AzslcSymbolAggregator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AzslcSymbolAggregator.cpp b/src/AzslcSymbolAggregator.cpp index d1f78ef..ec72d84 100644 --- a/src/AzslcSymbolAggregator.cpp +++ b/src/AzslcSymbolAggregator.cpp @@ -140,7 +140,7 @@ namespace AZ::ShaderCompiler { auto attempt = QualifiedName{JoinPath(path, name)}; got = GetIdAndKindInfo(attempt); - exit = path == "/"; + exit = path == "/" || path.empty(); if (!got) { if (auto* scopeAsClass = GetAsSub(IdentifierUID{GetParentName(attempt)})) // get enclosing class From 8d2ed81a7a2dcc5daa934bbf47c91598a9a39a2a Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 10 Apr 2023 23:53:23 +0900 Subject: [PATCH 10/13] Output the cost to JSON Signed-off-by: Vivien Oddou --- src/AzslcBackend.cpp | 1 + src/AzslcKindInfo.h | 1 + src/AzslcReflection.cpp | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/AzslcBackend.cpp b/src/AzslcBackend.cpp index 07793e7..b1e6e45 100644 --- a/src/AzslcBackend.cpp +++ b/src/AzslcBackend.cpp @@ -449,6 +449,7 @@ namespace AZ::ShaderCompiler // We reserve the right to change it in the future so we make it explicit attribute here shaderOption["order"] = optionOrder; optionOrder++; + shaderOption["costImpact"] = varInfo->m_estimatedCostImpact; bool isUdt = IsUserDefined(varInfo->GetTypeClass()); assert(isUdt || IsPredefinedType(varInfo->GetTypeClass())); diff --git a/src/AzslcKindInfo.h b/src/AzslcKindInfo.h index ae09224..bbcec04 100644 --- a/src/AzslcKindInfo.h +++ b/src/AzslcKindInfo.h @@ -399,6 +399,7 @@ namespace AZ::ShaderCompiler ConstNumericVal m_constVal; // (attempted folded) initializer value for simple scalars optional m_samplerState; ExtendedTypeInfo m_typeInfoExt; + int m_estimatedCostImpact = -1; //!< Cached value calculated by AnalyzeOptionRanks }; // VarInfo methods definitions diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index f49f746..804ea78 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -629,9 +629,9 @@ namespace AZ::ShaderCompiler void CodeReflection::DumpVariantList(const Options& options) const { + AnalyzeOptionRanks(); m_out << GetVariantList(options); m_out << "\n"; - AnalyzeOptionRanks(); } static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo) @@ -1060,7 +1060,7 @@ namespace AZ::ShaderCompiler impactScore += AnalyzeImpact(ref.m_where) // dependent code that may be skipped depending on the value of that ref + 1; // by virtue of being mentioned (seenat), we count the reference as an access of cost 1. } - m_out << "Option " << uid.GetName() << " has impact " << impactScore << "\n"; + varInfo->m_estimatedCostImpact = impactScore; } } } From c274dff1a2959b689b888c6fb14c368d731a8f4f Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Mon, 10 Apr 2023 23:53:41 +0900 Subject: [PATCH 11/13] Integrate a test for option rank cost Signed-off-by: Vivien Oddou --- tests/Advanced/mae-methodcall.azsl | 27 ++++++++++++++++++ tests/Advanced/mae-methodcall.py | 46 ++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/Advanced/mae-methodcall.azsl create mode 100644 tests/Advanced/mae-methodcall.py diff --git a/tests/Advanced/mae-methodcall.azsl b/tests/Advanced/mae-methodcall.azsl new file mode 100644 index 0000000..e11e09a --- /dev/null +++ b/tests/Advanced/mae-methodcall.azsl @@ -0,0 +1,27 @@ +ShaderResourceGroupSemantic slot0 +{ + FrequencyId = 1; + ShaderVariantFallback = 128; +}; +ShaderResourceGroup srg0 : slot0{} + +class C +{ + void f(double) { f(0) + f(0) * f(0) / f(0); f(0) - f(0) % f(0); } // cost 7*7+2 + void f(int) { ;;;;;;; } // cost 7 +}; + +option bool o; + +float4 main() +{ + if (o) // unnamed block $bk0 + { + C c; + { // unnamed block $bk1 to verify lookup capability to find `/main/$bk0/c` from `/main/$bk0/$bk1/` + // understand that `c`'s type is `/C`, and use /C scope to lookup the f() method. + // also deep expression on LHS of MAE to give no break to the typeof system + (c).f(2 * 5.0l); // double promotion in binary expression that resolves to double overload method call + } + } +} \ No newline at end of file diff --git a/tests/Advanced/mae-methodcall.py b/tests/Advanced/mae-methodcall.py new file mode 100644 index 0000000..4ee002f --- /dev/null +++ b/tests/Advanced/mae-methodcall.py @@ -0,0 +1,46 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +""" +Copyright (c) Contributors to the Open 3D Engine Project. +For complete copyright and license terms please see the LICENSE at the root of this distribution. + +SPDX-License-Identifier: Apache-2.0 OR MIT +""" +import sys +import os +sys.path.append("..") +sys.path.append("../..") +from clr import * +import testfuncs + + +def verifyOptionCosts(thefile, compilerPath, silent): + j, ok = testfuncs.buildAndGetJson(thefile, compilerPath, silent, ["--options"]) + if ok: + predicates = [] + # check all references of func() + predicates.append(lambda: j["ShaderOptions"][0]["name"] == "o") + predicates.append(lambda: j["ShaderOptions"][0]["costImpact"] == 54) + + if not silent: print (fg.CYAN+ style.BRIGHT+ "option expected cost check..."+ style.RESET_ALL) + ok = testfuncs.verifyAllPredicates(predicates, j) + return ok + +result = 0 # to define for sub-tests +resultFailed = 0 + +def doTests(compiler, silent, azdxcpath): + global result + global resultFailed + + # Working directory should have been set to this script's directory by the calling parent + # You can get it once doTests() is called, but not during initialization of the module, + # because at that time it will still be set to the working directory of the calling script + workDir = os.getcwd() + + if verifyOptionCosts(os.path.join(workDir, "mae-methodcall.azsl"), compiler, silent): result += 1 + else: resultFailed += 1 + + +if __name__ == "__main__": + print ("please call from testapp.py") From a02b39983f1f7d09193449b80d05f82e43399b38 Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Tue, 11 Apr 2023 00:02:10 +0900 Subject: [PATCH 12/13] Actually this test now no longer fails since the type resolution has gained in power. To fallback on an exhibition of the problem again it's enough to just mention the call to floor which is unregistered as long as azslc is concerned. Signed-off-by: Vivien Oddou --- .../overload-resolution-impossible-and-heteroreturn.azsl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl b/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl index f9fc391..ea7770c 100644 --- a/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl +++ b/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl @@ -2,12 +2,13 @@ struct A {}; struct B {}; A make(int); -B make(uint); +B make(float); void main() { float x = 0.5; - A a = make((int)floor(x) + 1); // #EC 41 + A a = make(floor(x)); // #EC 41 + // make((int)floor(x)); // help azslc knowing about unregistered functions by casting to force the type. } /*Semantic Error 41: line 10::14 '(10): unable to match arguments () to a registered overload. candidates are: /make(?int) From f7d55d4606d2d504304d837ddd041b20d3e6699f Mon Sep 17 00:00:00 2001 From: Vivien Oddou Date: Wed, 19 Apr 2023 20:08:16 +0900 Subject: [PATCH 13/13] =?UTF-8?q?fix=20clang=20complaint=20about=20somethi?= =?UTF-8?q?ng=20that=20visual=20studio=20tolerated.=20(declaration=20of'?= =?UTF-8?q?=20x=E2=80=99changes=20meaning=20of=20'x')?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Vivien Oddou --- src/GenericUtils.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/GenericUtils.h b/src/GenericUtils.h index d2d45fc..627c466 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -598,9 +598,9 @@ namespace AZ template< typename T > struct IntervalCollection { - using Interval = Interval; + using IntervalT = Interval; - void Add(Interval i) + void Add(IntervalT i) { m_obfirsts.emplace_back(i); } @@ -621,7 +621,7 @@ namespace AZ } //! Retrieve the subset of intervals activated by a point (query) - set GetIntervalsSurrounding(T query) const + set GetIntervalsSurrounding(T query) const { assert(m_sealed); // find the "set" of intervals starting before: @@ -630,7 +630,7 @@ namespace AZ [=](auto interv, T q) { return interv.a <= q; }); // find the "set" of intervals ending after: - static vector endAfter; + static vector endAfter; endAfter.clear(); CopyIf(m_oblasts.rbegin(), m_oblasts.rend(), // reverse iteration [=](auto interv) { return interv.b >= query; }, @@ -639,7 +639,7 @@ namespace AZ // for set_intersection to work, the less<> predicate has to work for both ranges std::sort(endAfter.begin(), endAfter.end()); - set result; + set result; std::set_intersection(m_obfirsts.begin(), firstsSubEnd, endAfter.begin(), endAfter.end(), std::inserter(result, result.end())); @@ -651,14 +651,14 @@ namespace AZ //! each overlapping interval is fully contained in the bigger one, //! the closest start is guaranteed to be the most "leaf" interval. //! This is useful for scopes. - Interval GetClosestIntervalSurrounding(T query) const + IntervalT GetClosestIntervalSurrounding(T query) const { auto bag = GetIntervalsSurrounding(query); - return bag.empty() ? Interval{-1, -2} : *bag.rbegin(); + return bag.empty() ? IntervalT{-1, -2} : *bag.rbegin(); } - vector m_obfirsts; // ordered by "firsts" - vector m_oblasts; // ordered by "lasts" + vector m_obfirsts; // ordered by "firsts" + vector m_oblasts; // ordered by "lasts" bool m_sealed = false; };