diff --git a/src/AzslcBackend.cpp b/src/AzslcBackend.cpp index 07793e7..b1e6e45 100644 --- a/src/AzslcBackend.cpp +++ b/src/AzslcBackend.cpp @@ -449,6 +449,7 @@ namespace AZ::ShaderCompiler // We reserve the right to change it in the future so we make it explicit attribute here shaderOption["order"] = optionOrder; optionOrder++; + shaderOption["costImpact"] = varInfo->m_estimatedCostImpact; bool isUdt = IsUserDefined(varInfo->GetTypeClass()); assert(isUdt || IsPredefinedType(varInfo->GetTypeClass())); diff --git a/src/AzslcEmitter.cpp b/src/AzslcEmitter.cpp index f800da9..605728d 100644 --- a/src/AzslcEmitter.cpp +++ b/src/AzslcEmitter.cpp @@ -1277,7 +1277,6 @@ namespace AZ::ShaderCompiler const ICodeEmissionMutator* codeMutator = m_codeMutator; ssize_t ii = interval.a; - bool wasInPreprocessorDirective = false; // record a state to detect exit of directives, because they need to reside on their own lines while (ii <= interval.b) { auto* token = GetNextToken(ii /*inout*/); diff --git a/src/AzslcIntermediateRepresentation.cpp b/src/AzslcIntermediateRepresentation.cpp index f89a559..327043d 100644 --- a/src/AzslcIntermediateRepresentation.cpp +++ b/src/AzslcIntermediateRepresentation.cpp @@ -332,7 +332,6 @@ namespace AZ::ShaderCompiler cout << " storage: " << sub.m_typeInfoExt.m_qualifiers.GetDisplayName() << "\n"; cout << " array dim: \"" << sub.m_typeInfoExt.m_arrayDims.ToString() << "\"\n"; cout << " has sampler state: " << (sub.m_samplerState ? "yes\n" : "no\n"); - cout << "\n"; if (!holds_alternative(sub.m_constVal)) { cout << " val: " << ExtractValueAsInt64(sub.m_constVal) << "\n"; @@ -519,7 +518,7 @@ namespace AZ::ShaderCompiler if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } nextMemberStartingOffset = Packing::PackNextChunk(layoutPacking, size, startAt); @@ -960,5 +959,4 @@ namespace AZ::ShaderCompiler } return memberList[memberList.size() - 1]; } - } // end of namespace AZ::ShaderCompiler diff --git a/src/AzslcIntermediateRepresentation.h b/src/AzslcIntermediateRepresentation.h index ab7c43b..42fc0fa 100644 --- a/src/AzslcIntermediateRepresentation.h +++ b/src/AzslcIntermediateRepresentation.h @@ -14,7 +14,6 @@ namespace AZ::ShaderCompiler { - //! We limit the maximum number of render targets to 8, with indices in the range [0..7] static const uint32_t kMaxRenderTargets = 8; diff --git a/src/AzslcKindInfo.h b/src/AzslcKindInfo.h index ef50968..bbcec04 100644 --- a/src/AzslcKindInfo.h +++ b/src/AzslcKindInfo.h @@ -248,7 +248,7 @@ namespace AZ::ShaderCompiler //! Get the size of a single element, ignoring array dimensions const uint32_t GetSingleElementSize(Packing::Layout layout, bool defaultRowMajor) const { - auto baseSize = m_coreType.m_arithmeticInfo.GetBaseSize(); + auto baseSize = m_coreType.m_arithmeticInfo.m_baseSize; bool isRowMajor = (m_mtxMajor == Packing::MatrixMajor::RowMajor || (m_mtxMajor == Packing::MatrixMajor::Default && defaultRowMajor)); auto rows = m_coreType.m_arithmeticInfo.m_rows; @@ -399,6 +399,7 @@ namespace AZ::ShaderCompiler ConstNumericVal m_constVal; // (attempted folded) initializer value for simple scalars optional m_samplerState; ExtendedTypeInfo m_typeInfoExt; + int m_estimatedCostImpact = -1; //!< Cached value calculated by AnalyzeOptionRanks }; // VarInfo methods definitions @@ -791,6 +792,7 @@ namespace AZ::ShaderCompiler vector< IdentifierUID > m_overrides; //!< list of implementing functions in child classes optional< IdentifierUID > m_base; //!< points to the overridden function in the base interface, if applies. only supports one base FunctionMultiForwards m_multiFwds = FMF_None; //!< presence of redundant prototype-only declarations + int m_costScore = -1; //!< heuristical static analysis of the amount of instructions contained struct Parameter { IdentifierUID m_varId; diff --git a/src/AzslcMain.cpp b/src/AzslcMain.cpp index 87f1e96..0159d7c 100644 --- a/src/AzslcMain.cpp +++ b/src/AzslcMain.cpp @@ -23,8 +23,7 @@ namespace StdFs = std::filesystem; // For large features or milestones. Minor version allows for breaking changes. Existing tests can change. #define AZSLC_MINOR "8" // last change: introduction of class inheritance // For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change. -#define AZSLC_REVISION "17" // last change: fixup alignment check logic_error because of lack of an inter-scope check limiter. - // "16" change: fixup runtime error with redundant function declarations +#define AZSLC_REVISION "18" // last change: automatic option ranks namespace AZ::ShaderCompiler { diff --git a/src/AzslcReflection.cpp b/src/AzslcReflection.cpp index 53d3e7a..804ea78 100644 --- a/src/AzslcReflection.cpp +++ b/src/AzslcReflection.cpp @@ -589,7 +589,7 @@ namespace AZ::ShaderCompiler else if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = m_ir->GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } offset = Packing::PackNextChunk(layoutPacking, size, startAt); @@ -629,7 +629,9 @@ namespace AZ::ShaderCompiler void CodeReflection::DumpVariantList(const Options& options) const { + AnalyzeOptionRanks(); m_out << GetVariantList(options); + m_out << "\n"; } static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo) @@ -857,11 +859,12 @@ namespace AZ::ShaderCompiler for (auto& seenat : kindInfo->GetSeenats()) { assert(uid == seenat.m_referredDefinition); - // TODO: the assumption that intervals where distinct doesnt hold anymore now that we have unnamed scopes - auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value) - { - return value.first.properlyContains({key, key}); - }); + // careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes) + // ok for now because AZSL/HLSL don't have lambdas + auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value) + { + return value.first.properlyContains({key, key}); + }); if (intervalIter != scopes.cend()) { const IdentifierUID& encloser = intervalIter->second.second; @@ -909,16 +912,9 @@ namespace AZ::ShaderCompiler uint32_t numOf32bitConst = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID); RootSigDesc rootSignature = BuildSignatureDescription(options, numOf32bitConst); - // prepare a lookup acceleration data structure for reverse mapping tokens to scopes. - MapOfBeginToSpanAndUid scopeStartToFunctionIntervals; - for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals) - { - if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. We need a set of disjoint intervals as an invariant for the next algorithm. - { - // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound) - scopeStartToFunctionIntervals[interval.a] = std::make_pair(interval, uid); - } - } + // Prepare a lookup acceleration data structure for reverse mapping tokens to scopes. + // (truth: we need a set of disjoint intervals as an invariant for the following algorithm) + GenerateTokenScopeIntervalToUidReverseMap(); Json::Value srgRoot(Json::objectValue); // Order the reflection by SRG for convenience @@ -968,7 +964,7 @@ namespace AZ::ShaderCompiler else { set dependencyList; - DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals); srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {}); } } @@ -981,7 +977,7 @@ namespace AZ::ShaderCompiler for (auto& srgConstant : srgInfo->m_implicitStruct.GetMemberFields()) { allConstants.append({ srgConstant.GetNameLeaf() }); - DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, m_functionIntervals); } // variant fallback support if (srgInfo->m_shaderVariantFallback) @@ -992,7 +988,7 @@ namespace AZ::ShaderCompiler { if (varSub->CheckHasStorageFlag(StorageFlag::Option)) { - DiscoverTopLevelFunctionDependencies(varUid, dependencyList, scopeStartToFunctionIntervals); + DiscoverTopLevelFunctionDependencies(varUid, dependencyList, m_functionIntervals); } } } @@ -1004,4 +1000,212 @@ namespace AZ::ShaderCompiler m_out << srgRoot; } + + // Helper routine for option rank analysis + static int GuesstimateIntrinsicFunctionCost(string_view funcName) + { + if (IsOneOf(funcName, "CallShader", "TraceRay")) + { // non measurable but assumed high + return 100; + } + else if (IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append")) + { // memory access, locked or not, will have high latency + return 10; + } + else + { // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1. + return 1; + } + } + + // Helper routine for option rank analysis. When picking AN overload is more useful than forfeiting. + // The function GetConcreteFunctionThatMatchesArgumentList forfeits when the overloadset contains + // strictly more than 1 concrete function with the queried arity. In our case, we prefer to just pick any. + static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir, + azslParser::FunctionCallExpressionContext* callNode, + KindInfo& overload) + { + IdentifierUID concrete; + size_t numArgs = NumArgs(callNode); + overload.GetSubAs()->AnyOf( + [&](IdentifierUID const& uid) + { + auto* concreteFcInfo = ir->GetSymbolSubAs(uid.GetName()); + size_t numParams = concreteFcInfo->GetParameters(true).size(); + if (numParams == numArgs) + { + concrete = uid; // we write the result through reference capture (not clean but convenient) + return true; + } + return false; + }); + return concrete; + } + + void CodeReflection::AnalyzeOptionRanks() const + { + // make sure we have the scope lookup cache ready + GenerateTokenScopeIntervalToUidReverseMap(); + // loop over variables + for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3()) + { + // only options + if (varInfo->CheckHasStorageFlag(StorageFlag::Option)) + { + int impactScore = 0; + // loop over appearances over the program + for (Seenat& ref : kindInfo->GetSeenats()) + { + // determine an impact score + impactScore += AnalyzeImpact(ref.m_where) // dependent code that may be skipped depending on the value of that ref + + 1; // by virtue of being mentioned (seenat), we count the reference as an access of cost 1. + } + varInfo->m_estimatedCostImpact = impactScore; + } + } + } + + int CodeReflection::AnalyzeImpact(TokensLocation const& location) const + { + // find the node at `location`: + ParserRuleContext* node = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId); + // go up tree to meet a block node that has visitable depth: + // can be any of if/for/while/switch + // 4 is an arbitrary depth, enough to search up things like `for (a, b<(ref+1), c)` binaryop->braces->cmpexpr->cond->for + if (auto* whileNode = DeepParentAs(node->parent, 3)) + { + node = whileNode->embeddedStatement(); + } + else if (auto* ifNode = DeepParentAs(node->parent, 3)) + { + node = ifNode->embeddedStatement(); + } + else if (auto* forNode = DeepParentAs(node->parent, 4)) + { + node = forNode->embeddedStatement(); + } + else if (auto* switchNode = DeepParentAs(node->parent, 3)) + { + node = switchNode->switchBlock(); + } + int score = 0; + AnalyzeImpact(node, score); + return score; + } + + void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const + { + for (auto& c : astNode->children) + { + if (auto* callNode = As(c)) + { + // branch into an overload specialized for function lookup: + AnalyzeImpact(callNode, scoreAccumulator); + } + else if (auto* node = As(c)) + { + AnalyzeImpact(node, scoreAccumulator); // recurse down to make sure to capture embedded calls, like e.g. "x ? f() : 0;" + } + if (auto* leaf = As(c)) + { + // determine cost by number of full expressions separated by semicolon + scoreAccumulator += leaf->getSymbol()->getType() == azslLexer::Semi; // bool as 0 or 1 trick + } + } + } + + void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const + { + // to access the function symbol info we need the current scope, the function call name and perform a lookup. + + // figure out the scope at this token. + // theoretically should be something in the like of the body of another function, + // or an anonymous block within another function. + auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex()); + if (!interval.IsEmpty()) + { + IdentifierUID encloser = m_intervalToUid[interval]; + + // Because we are past the end of the semantic analysis, + // the scope tracker is registering the last seen scope (surely "/"). + // This is a stateful side-effect system unfortunately, and since we'll call + // some feature of the semantic orchestrator (like TypeofExpr) we need to hack + // the scope tracker: + m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName(); + m_ir->m_sema.m_scope->UpdateCurScopeUID(); + + QualifiedName startupLookupScope = encloser.GetName(); + UnqualifiedName funcName; + if (auto* idExpr = As(callNode->Expr)) + { + funcName = ExtractNameFromIdExpression(idExpr->idExpression()); + } + else if (auto* maeExpr = As(callNode->Expr)) + { + startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr); + funcName = ExtractNameFromIdExpression(maeExpr->Member); + } + IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName); + if (!overload) // in case of function not found, we assume it's an intrinsic. + { + scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName); + } + else + { + azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode); + IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args); + IdentifierUID concrete; + if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet) + { // in case of strict selection failure, run a fuzzy select + concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second); + // if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit + } + else + { + concrete = symbolMeantUnderCallNode->first; + } + + if (auto* funcInfo = m_ir->GetSymbolSubAs(concrete.GetName())) + { + if (funcInfo->m_costScore == -1) // cost not yet discovered for this function + { + funcInfo->m_costScore = 0; + using AstFDef = azslParser::HlslFunctionDefinitionContext; + AnalyzeImpact(polymorphic_downcast(funcInfo->m_defNode->parent)->block(), + funcInfo->m_costScore); // recurse and cache + } + scoreAccumulator += funcInfo->m_costScore; + } + } + // other cases forfeited for now, but that would at least include things like eg braces (f)() + } + else // no interval found + { + // function calls outside of function bodies can appear in an initializer: + // int g_a = MakeA(); // global init + // class C { int m_a = CompA(); // constructor init (invalid AZSL/HLSL) + // class D { void Method(int a_a = DefaultA()); // default parameter value + // in any case, extracting the scope is impossible with this system. + // we forfeit evaluation of a score + } + } + + void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const + { + if (m_functionIntervals.empty()) + { + for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals) + { + if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. + { + // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound) + m_functionIntervals[interval.a] = std::make_pair(interval, uid); + } + auto i = Interval{interval.a, interval.b}; + m_intervals.Add(i); + m_intervalToUid[i] = uid; + } + m_intervals.Seal(); + } + } } diff --git a/src/AzslcReflection.h b/src/AzslcReflection.h index c3c0bc4..17b0255 100644 --- a/src/AzslcReflection.h +++ b/src/AzslcReflection.h @@ -11,6 +11,9 @@ namespace AZ::ShaderCompiler { + using MapOfBeginToSpanAndUid = map >; + using MapOfIntervalToUid = map, IdentifierUID>; + struct CodeReflection : Backend { CodeReflection(IntermediateRepresentation* ir, TokenStream* tokens, std::ostream& out) @@ -45,6 +48,9 @@ namespace AZ::ShaderCompiler //! @param options user configuration parsed from command line void DumpResourceBindingDependencies(const Options& options) const; + //! Determine a heurisitcal global order between options in the program, using "impacted code size" static analysis. + void AnalyzeOptionRanks() const; + private: //! Builds member variable packing information and adds it to the membersContainer @@ -63,7 +69,6 @@ namespace AZ::ShaderCompiler bool BuildOMStruct(const ExtendedTypeInfo& returnTypeRef, string_view semanticOverride, Json::Value& jsonVal, int& semanticIndex) const; - using MapOfBeginToSpanAndUid = map >; //! Populate a list of functions where a symbol appear as potentially used //! @param uid The symbol to start the dependency analysis on //! @param output Any dependency symbol will be appended to this set @@ -75,6 +80,21 @@ namespace AZ::ShaderCompiler bool IsPotentialEntryPoint(const IdentifierUID& uid) const; + // Estimate a score proportional to how much code is "child" to the AST node at `location` + int AnalyzeImpact(TokensLocation const& location) const; + + // Recursive internal detail version + void AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const; + + // Function-call specific + void AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const; + + //! Useful for static analysis on dependencies or option ranks + void GenerateTokenScopeIntervalToUidReverseMap() const; + mutable MapOfBeginToSpanAndUid m_functionIntervals; //< only functions because they are guaranteed to be disjointed (largely simplifies queries) + mutable IntervalCollection m_intervals; //< augmented version with anonymous blocks (slower query) + mutable MapOfIntervalToUid m_intervalToUid; //< side by side data since we don't want to weight the interval struct with a payload + std::ostream& m_out; }; } diff --git a/src/AzslcSemanticOrchestrator.cpp b/src/AzslcSemanticOrchestrator.cpp index b17b061..67ccc13 100644 --- a/src/AzslcSemanticOrchestrator.cpp +++ b/src/AzslcSemanticOrchestrator.cpp @@ -1166,7 +1166,10 @@ namespace AZ::ShaderCompiler As(ctx), As(ctx), As(ctx), - As(ctx)); + As(ctx), + As(ctx), + As(ctx), + As(ctx)); } catch (AllNull&) { @@ -1175,6 +1178,48 @@ namespace AZ::ShaderCompiler } } + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const + { + // among all possibilities Plus|Minus|Not|Tilde|PlusPlus|MinusMinus + // only "Not" returns a bool, the rest is transparent and returns the same type as rhs + return ctx->prefixUnaryOperator()->start->getType() == azslLexer::Not ? MangleScalarType("bool") + : TypeofExpr(ctx->Expr); + } + + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const + { + return TypeofExpr(ctx->Expr); // in case of x++ or x-- the type is the type of x. + } + + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::BinaryExpressionContext* ctx) const + { + using lex = azslLexer; + auto boolResultOperators = {lex::Less, lex::Greater, lex::LessEqual, lex::GreaterEqual, lex::NotEqual, lex::AndAnd, lex::OrOr}; + if (IsIn(ctx->binaryOperator()->start->getType(), boolResultOperators)) + { + return MangleScalarType("bool"); + } + QualifiedName lhs = TypeofExpr(ctx->Left); + QualifiedName rhs = TypeofExpr(ctx->Right); + TypeRefInfo typeInfoLhs = CreateTypeRefInfo(UnqualifiedNameView{lhs}); // We tolerate a cast here because GetTypeRefInfo was designed to lookup types, but TypeofExpr has already looked up the type. + TypeRefInfo typeInfoRhs = CreateTypeRefInfo(UnqualifiedNameView{rhs}); + if (typeInfoLhs.m_arithmeticInfo.IsEmpty() || typeInfoRhs.m_arithmeticInfo.IsEmpty()) + { // Case that shouldn't work in AZSL yet (but may work in HLSL2021) + // -> UDT operator (would need support of operator overloading). + // We arbitrarily assume a result "type of left expression". + // (what we would really need to do is go get the return type of the overloaded operator) + return lhs; + } // After this `if`, both sides are arithmetic types (scalar, vector, matrix). + // "matrix op vector" (or commutated) is a forbidden case, + // e.g it won't do Y=MX (m*v->v), nor dotproduct-ing vectors for that matter (v*v->scalar) + // It will do piecewise `op` and return more or less the same type as the operands. + // In case of dimension differences, it will truncate to the smaller type + // e.g float2 + float3 results in float2 with .z lost in implicit cast + // same for float2x3 * float2x2 (results in float2x2) + // We assume that for all non bool ops: * + - / % ^ | & << >> + return PromoteTruncateWith({typeInfoLhs.m_arithmeticInfo, typeInfoRhs.m_arithmeticInfo}).GenTypeId(); + } + QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::ExpressionExtContext* ctx) const { return VisitFirstNonNull([this](auto* ctx) { return TypeofExpr(ctx); }, @@ -1303,29 +1348,23 @@ namespace AZ::ShaderCompiler QualifiedName SemanticOrchestrator::TypeofExpr(azslParser::LiteralContext* ctx) const { - // verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array. - auto checkExistType = [](string_view scalarName){return std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), scalarName) != AZ::ShaderCompiler::Predefined::Scalar.end();}; // verifies that last or 1-before-last characters are a particular literal suffix. like in "56ul" auto hasSuffix = [](auto node, char s){return tolower(node->getText().back()) == s || tolower(Slice(node->getText(), -3, -2) == s);}; - auto checkAndReturn = [&](string_view typeName) - { - assert(checkExistType(typeName)); - return QualifiedName{"?" + string{typeName}}; - }; + if (ctx->True() || ctx->False()) { - return checkAndReturn("bool"); + return MangleScalarType("bool"); } else if (auto* literal = ctx->IntegerLiteral()) { - return hasSuffix(literal, 'u') ? checkAndReturn("uint") - : checkAndReturn("int"); + return hasSuffix(literal, 'u') ? MangleScalarType("uint") + : MangleScalarType("int"); } else if (auto* literal = ctx->FloatLiteral()) { - return hasSuffix(literal, 'h') ? checkAndReturn("half") - : hasSuffix(literal, 'l') ? checkAndReturn("double") - : checkAndReturn("float"); + return hasSuffix(literal, 'h') ? MangleScalarType("half") + : hasSuffix(literal, 'l') ? MangleScalarType("double") + : MangleScalarType("float"); } return {""}; } diff --git a/src/AzslcSemanticOrchestrator.h b/src/AzslcSemanticOrchestrator.h index a2aba91..6492c5a 100644 --- a/src/AzslcSemanticOrchestrator.h +++ b/src/AzslcSemanticOrchestrator.h @@ -222,6 +222,9 @@ namespace AZ::ShaderCompiler auto TypeofExpr(azslParser::LiteralExpressionContext* ctx) const -> QualifiedName; auto TypeofExpr(azslParser::LiteralContext* ctx) const -> QualifiedName; auto TypeofExpr(azslParser::CommaExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::PostfixUnaryExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::PrefixUnaryExpressionContext* ctx) const -> QualifiedName; + auto TypeofExpr(azslParser::BinaryExpressionContext* ctx) const -> QualifiedName; //! Parse the AST from a variable declaration and attempt to extract array dimensions integer constants [dim1][dim2]... //! Return: on success, otherwise @@ -330,7 +333,7 @@ namespace AZ::ShaderCompiler { auto typeId = LookupType(typeNameOrCtx, policy); auto tClass = GetTypeClass(typeId); - auto arithmetic = IsNonGenericArithmetic(tClass) ? CreateArithmeticTypeInfo(typeId.GetName()) : ArithmeticTypeInfo{}; // TODO: canonicalize generics + auto arithmetic = IsNonGenericArithmetic(tClass) ? CreateArithmeticTraits(typeId.GetName()) : ArithmeticTraits{}; // TODO: canonicalize generics return TypeRefInfo{typeId, arithmetic, tClass}; } diff --git a/src/AzslcSymbolAggregator.cpp b/src/AzslcSymbolAggregator.cpp index 45ac3b1..ec72d84 100644 --- a/src/AzslcSymbolAggregator.cpp +++ b/src/AzslcSymbolAggregator.cpp @@ -20,7 +20,7 @@ namespace AZ::ShaderCompiler auto& [uid, kindInfo] = st.AddIdentifier(azirName, Kind::Type); // the kind is Type because all predefined are stored as such. auto& typeInfo = kindInfo.GetSubAfterInitAs(); auto typeClass = TypeClass::FromStr(bag.m_name); - auto arithmetic = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTypeInfo(azirName) : ArithmeticTypeInfo{}; // TODO: canonicalize generics + auto arithmetic = IsNonGenericArithmetic(typeClass) ? CreateArithmeticTraits(azirName) : ArithmeticTraits{}; // TODO: canonicalize generics typeInfo = TypeRefInfo{uid, arithmetic, typeClass}; } } @@ -140,7 +140,7 @@ namespace AZ::ShaderCompiler { auto attempt = QualifiedName{JoinPath(path, name)}; got = GetIdAndKindInfo(attempt); - exit = path == "/"; + exit = path == "/" || path.empty(); if (!got) { if (auto* scopeAsClass = GetAsSub(IdentifierUID{GetParentName(attempt)})) // get enclosing class diff --git a/src/AzslcTypes.h b/src/AzslcTypes.h index 6ffd6ea..a3503cf 100644 --- a/src/AzslcTypes.h +++ b/src/AzslcTypes.h @@ -260,36 +260,84 @@ namespace AZ::ShaderCompiler return result; } - /// Rows and Cols (this is specific to shader languages to identify vector and matrix types) - struct ArithmeticTypeInfo + /// Verifies that our hardcoded strings don't have typo, by checking against the lexer-extracted keywords stored in the Scalar array. + inline bool CheckExistScalarType(string_view scalarName) { - void ResolveSize() - { - m_size = Packing::PackedSizeof(m_underlyingScalar); - } + return std::find(Predefined::Scalar.begin(), + Predefined::Scalar.end(), + scalarName) != Predefined::Scalar.end(); + }; + + /// Assert validity of the type string, and form a "?" tainted string to host a scalar type + inline QualifiedName MangleScalarType(string_view typeName) + { + assert(CheckExistScalarType(typeName)); + return QualifiedName{"?" + string{typeName}}; + }; - /// Get the size of a single base element - const uint32_t GetBaseSize() const + //! Holds arithmetic-type-class information in small pieces (row, cols, base, size, rank...) + struct ArithmeticTraits + { + void ResolveBaseSizeAndRank() { - return m_size; + m_baseSize = Packing::PackedSizeof(m_underlyingScalar); + // establish the conversion rank: + auto getIndex = [](string_view s) -> int + { + auto const& Scalars = Predefined::Scalar; + return ::std::distance(Scalars.begin(), + ::std::find(Scalars.begin(), Scalars.end(), s)); + }; + // According to https://en.cppreference.com/w/cpp/language/usual_arithmetic_conversions + // - No two signed have the same rank (even if same siezeof) + // - rank of unsigned = rank of corresponding signed + // - "standard" is > "extended" of same sizeof + // - The rank of bool is the smallest + // That said, we will take inspiration from ASTContext::getIntegerRank of clang + static const unordered_map subranks = + { + {getIndex("bool"), 1}, + {getIndex("int16_t"), 2}, + {getIndex("uint16_t"), 3}, // unsigned wins in case of subrank draw, according to arithmetic conversion rules + {getIndex("int"), 4}, + {getIndex("uint"), 5}, + {getIndex("dword"), 6}, + {getIndex("int32_t"), 7}, + {getIndex("uint32_t"), 8}, + {getIndex("int64_t"), 9}, + {getIndex("uint64_t"), 10}, + {getIndex("half"), 11 << 5}, // floats win all conversions, even halfs + {getIndex("float"), 12 << 5}, + {getIndex("double"), 13 << 5}, + }; + // `basesize` getter, but 1 for bool: (physical size of extern bool is considered 32bits in HLSL) + auto getRankSizeof = [&](int scalarId) + { + assert(string_view{"bool"} == Predefined::Scalar[0]); // verify that 0 is the hard index of bool. + bool isBool = scalarId == 0; + return isBool ? 1 : m_baseSize; + }; + // The shift method is taken from clang, I suppose it's a multi-parameter order cramed into bits. + // so because 10 is the largest subrank, shift by 4 should separate sizeof space and subrank space. + m_conversionRank = (getRankSizeof(m_underlyingScalar) << 4) + subranks.at(m_underlyingScalar); } - /// Get the size of a the element with regard to dimensions as well - const uint32_t GetTotalSize() const + //! Get the size of the whole type considering dimensions + uint32_t GetTotalSize() const { - return m_size * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); + return m_baseSize * (m_cols > 0 ? m_cols : 1) * (m_rows > 0 ? m_rows : 1); } - /// True if the type is a vector type. If it's a vector type it cannot be a matrix as well. - const bool IsVector() const + //! True if the type is a vector type. If it's a vector type it cannot be a matrix as well. + bool IsVector() const { // This treats special cases like 2x1, 3x1 and 4x1 as vectors // The behavior is consistent with dxc packing rules return (m_cols == 1 && m_rows > 1) || (m_cols > 1 && m_rows == 0); } - /// True if the type is a matrix type. If it's a matrix type it cannot be a vector as well. - const bool IsMatrix() const + //! True if the type is a matrix type. If it's a matrix type it cannot be a vector as well. + bool IsMatrix() const { // This fails special cases like 2x1, 3x1 and 4x1, // but allows cases like 1x2, 1x3 and 1x4. @@ -297,23 +345,46 @@ namespace AZ::ShaderCompiler return m_rows > 0 && m_cols > 1; } - /// If initialized as a fundamental -> not empty. - const bool IsEmpty() const + bool IsScalar() const + { + return m_rows <= 1 && m_cols <= 1; // float, float1, float1x1 are scalars. + } + + //! Non-created state + bool IsEmpty() const { return m_underlyingScalar == -1; } - // for pretty print + //! For pretty print string UnderlyingScalarToStr() const { return m_underlyingScalar >= 0 && m_underlyingScalar < AZ::ShaderCompiler::Predefined::Scalar.size() ? AZ::ShaderCompiler::Predefined::Scalar[m_underlyingScalar] : ""; } - uint32_t m_size = 0; // In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct - uint32_t m_rows = 0; // 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix - uint32_t m_cols = 0; // 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix - int m_underlyingScalar = -1; // index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf. + //! Create a canonicalized mangled name that should represent the identity of this arithmetic type. + QualifiedName GenTypeId() const + { + if (IsMatrix()) + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + ToString(m_rows) + "x" + ToString(m_cols)}; + } + else if (IsVector()) + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr()) + (m_rows > 0 ? ToString(m_rows) : ToString(m_cols))}; + } + else + { + return QualifiedName{MangleScalarType(UnderlyingScalarToStr())}; + } + } + + uint32_t m_baseSize = 0; //< In bytes. Size of 0 indicates TypeRefInfo which hasn't been resolved or is a struct + uint32_t m_rows = 0; //< 0 means it's not a matrix (effective Rows = 1). 1 or more means a Matrix + uint32_t m_cols = 0; //< 0 means it's not a vector (effective Cols = 1). 1 or more means a Vector or Matrix + int m_conversionRank = 0; //< Used in conversions and promotions + int m_underlyingScalar = -1; //< Index into AZ::ShaderCompiler::Predefined::Scalar, all fundamentals end up in a scalar at its leaf. }; //! TypeRefInfo holds resolved immutable information of a core type (the `matrix2x2` in `column_major matrix2x2 a[3];`) @@ -323,7 +394,7 @@ namespace AZ::ShaderCompiler struct TypeRefInfo { TypeRefInfo() = default; - TypeRefInfo(IdentifierUID typeId, const ArithmeticTypeInfo& fundamentalInfo, TypeClass typeClass) + TypeRefInfo(IdentifierUID typeId, const ArithmeticTraits& fundamentalInfo, TypeClass typeClass) : m_arithmeticInfo{fundamentalInfo}, m_typeClass{typeClass}, m_typeId{typeId} @@ -359,19 +430,19 @@ namespace AZ::ShaderCompiler return !operator==(lhs,rhs); } - IdentifierUID m_typeId; - TypeClass m_typeClass; - ArithmeticTypeInfo m_arithmeticInfo; + IdentifierUID m_typeId; + TypeClass m_typeClass; + ArithmeticTraits m_arithmeticInfo; }; //! Run a syntactic analysis of an arithmetic type name and extract info on its composition - inline ArithmeticTypeInfo CreateArithmeticTypeInfo(QualifiedName a_typeName) + inline ArithmeticTraits CreateArithmeticTraits(QualifiedName a_typeName) { assert(IsArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) )); // no need to call this function if you don't have a fundamental (non void) assert(!IsGenericArithmetic( /*slow*/AnalyzeTypeClass(TentativeName{a_typeName}) )); // ↑ fatal aspect. The input needs to be canonicalized earlier to minimize this function's complexity. - ArithmeticTypeInfo toReturn; + ArithmeticTraits toReturn; string typeName = UnMangle(a_typeName); size_t baseTypeLen = typeName.length(); @@ -422,10 +493,40 @@ namespace AZ::ShaderCompiler auto it = ::std::find(AZ::ShaderCompiler::Predefined::Scalar.begin(), AZ::ShaderCompiler::Predefined::Scalar.end(), baseType); assert(it != AZ::ShaderCompiler::Predefined::Scalar.end()); // baseType must exist in the Scalar bag by program invariant. toReturn.m_underlyingScalar = static_cast( std::distance(AZ::ShaderCompiler::Predefined::Scalar.begin(), it) ); - toReturn.ResolveSize(); + toReturn.ResolveBaseSizeAndRank(); return toReturn; } + //! Create a new arithmetic traits promoted by necessity (through a binary operation usually) + //! of compatibility with a second operand of arithmetic typeclass. + //! For example: type(half{} + int{})->half + //! type(float3x3{} * double{})->double3x3 + //! And with columns & rows truncated to the smallest operand, + //! as part of the implicit necessary cast for operation compatibility. + inline ArithmeticTraits PromoteTruncateWith(Pair operands) + { + auto [_1, _2] = operands; + // put the scalar last (will be useful later if there is only one scalar operand) + SwapIf(_1, _2, _1.IsScalar()); + // now, let's construct the result in _1 + + if (_2.m_conversionRank > _1.m_conversionRank) // the higher ranking underlying wins; independently of full object size + { + _1.m_underlyingScalar = _2.m_underlyingScalar; + } + // cases: scalar-scalar : no dim change + // vecmat-scalar : no dim change, since result is dim(vecmat) which is in _1 + // scalar-vecmat : impossible (sorted by swap above) + // vecmat-vecmat : min(_1,_2) + if (!_1.IsScalar() && !_2.IsScalar()) + { + _1.m_rows = std::min(_1.m_rows, _2.m_rows); + _1.m_cols = std::min(_1.m_cols, _2.m_cols); + } + _1.ResolveBaseSizeAndRank(); + return _1; + } + MAKE_REFLECTABLE_ENUM(RootParamType, SRV, // t UAV, // u diff --git a/src/AzslcUtils.h b/src/AzslcUtils.h index 8886591..e07f8c1 100644 --- a/src/AzslcUtils.h +++ b/src/AzslcUtils.h @@ -958,15 +958,22 @@ namespace AZ::ShaderCompiler return UnqualifiedName{ctx->Name->getText()}; } - template - bool Is3ParentRuleOfType(antlr4::ParserRuleContext* ctx) + //! Get a pointer to the first parent that happens to be of type `SearchType` + //! with a limit depth of `maxDepth` parents to search through + template + SearchType DeepParentAs(tree::ParseTree* ctx, int maxDepth) { - if (ctx == nullptr || ctx->parent == nullptr || ctx->parent->parent == nullptr) // input canonicalization + if (auto* searchTypeNode = As(ctx)) { - return false; + return searchTypeNode; } - auto threeUp = ctx->parent->parent->parent; - return dynamic_cast(threeUp); + return maxDepth <= 0 || !ctx ? nullptr : DeepParentAs(ctx->parent, maxDepth - 1); + } + + template + bool Is3ParentRuleOfType(tree::ParseTree* ctx) + { + return DeepParentAs(ctx, 3); } // is def @@ -1040,6 +1047,13 @@ namespace AZ::ShaderCompiler return found ? found->argumentList() : nullptr; } + //! access the argument count at a function call site (from the AST) + inline size_t NumArgs(azslParser::FunctionCallExpressionContext* callCtx) + { + azslParser::ArgumentsContext* argsNode = callCtx->argumentList()->arguments(); + return argsNode ? argsNode->expression().size() : 0; + } + //! try to find a specific context type that this context would be a child of. template inline LookedUp* ExtractSpecificParent(antlr4::ParserRuleContext* ctx) diff --git a/src/GenericUtils.h b/src/GenericUtils.h index 378a214..627c466 100644 --- a/src/GenericUtils.h +++ b/src/GenericUtils.h @@ -24,9 +24,9 @@ namespace AZ { using runtime_error::runtime_error; }; - // Type-heterogeneity-preserving multi pointer object single visitor. - // Returns whatever the passed functor would. - // Throws if all passed objects are null. + //! Type-heterogeneity-preserving multi pointer object single visitor. + //! Returns whatever the passed functor would. + //! Throws if all passed objects are null. template invoke_result_t VisitFirstNonNull(Lambda functor, T* object) noexcept(false) { @@ -50,9 +50,9 @@ namespace AZ } } - // Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics. - // what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]` - // http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/ + //! Create substring views of views. Works like python slicing operator [n:m] with limited modulo semantics. + //! what I ultimately desire is the range v.3 feature eg `letters[{2,end-2}]` + //! http://ericniebler.com/2014/12/07/a-slice-of-python-in-c/ inline constexpr string_view Slice(const string_view& in, int64_t st, int64_t end) { auto inSSize = (int64_t)in.size(); @@ -107,8 +107,8 @@ namespace AZ //https://developercommunity.visualstudio.com/content/problem/275141/c2131-expression-did-not-evaluate-to-a-constant-fo.html } - // ability to create size_t literals - // waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo + //! ability to create size_t literals + //! waiting for Working Group to get their stuff together https://groups.google.com/a/isocpp.org/forum/#!topic/std-proposals/tGoPjUeHlKo inline constexpr std::size_t operator ""_sz(unsigned long long n) { return n; @@ -145,7 +145,7 @@ namespace AZ return fileName.substr(0, lastIndex); } - // debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library) + //! debug-build asserted dyn_cast, otherwise, release-build static_cast (idea from boost library) template To polymorphic_downcast(From instance) { @@ -157,7 +157,7 @@ namespace AZ return static_cast(instance); } - /// surround a string with a prefix and a suffix + //! surround a string with a prefix and a suffix inline string Decorate(string_view prefix, string_view body, string_view suffix) { std::stringstream ss; @@ -167,13 +167,13 @@ namespace AZ return ss.str(); } - /// 2 arguments version in case both sides are the same + //! 2 arguments version in case both sides are the same inline string Decorate(string_view prefixAndSuffix, string_view body) { return Decorate(prefixAndSuffix, body, prefixAndSuffix); } - /// reverse the effect of a symmetrical decoration + //! reverse the effect of a symmetrical decoration inline string_view Undecorate(string_view decoration, string_view body) { auto indexStart = StartsWith(body, decoration) ? decoration.length() : 0; @@ -181,7 +181,7 @@ namespace AZ return Slice(body, indexStart, indexEnd); } - // Erase-Remove algorithm which removes all whitespaces from a string. + //! Erase-Remove algorithm which removes all whitespaces from a string. inline string RemoveWhitespaces(string haystack) { haystack.erase(std::remove_if(haystack.begin(), haystack.end(), [](unsigned char c) {return std::isspace(c); }), haystack.end()); @@ -193,14 +193,14 @@ namespace AZ return std::all_of(s.begin(), s.end(), [&](char c) { return std::isspace(c); }); } - /// tells whether a position in a string is surrounded by round braces - /// e.g. true for arguments {"a(b)", 2} - /// e.g. true for arguments {"a()", 1} by convention - /// e.g. false for arguments {"a()", 2} by convention - /// e.g. false for arguments {"a(b)", 0} - /// e.g. false for arguments {"a(b)c", 4} - /// e.g. false for arguments {"a(b)c(d)", 4} - /// e.g. true for arguments {"a((b)c(d))", 5} + //! tells whether a position in a string is surrounded by round braces + //! e.g. true for arguments {"a(b)", 2} + //! e.g. true for arguments {"a()", 1} by convention + //! e.g. false for arguments {"a()", 2} by convention + //! e.g. false for arguments {"a(b)", 0} + //! e.g. false for arguments {"a(b)c", 4} + //! e.g. false for arguments {"a(b)c(d)", 4} + //! e.g. true for arguments {"a((b)c(d))", 5} inline bool WithinMatchedParentheses(string_view haystack, size_t charPosition) { const auto hayLen = haystack.length(); @@ -215,8 +215,8 @@ namespace AZ return nesting > 0; } - /// replace all occurrences of substring `sub` with substring `to` within haystack. - /// e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa" + //! replace all occurrences of substring `sub` with substring `to` within haystack. + //! e.g: Replace("aaa#aaa", "#", "_") gives-> "aaa_aaa" inline string Replace(string haystack, string_view sub, string_view to) { decltype(sub.length()) pos = 0; @@ -230,7 +230,7 @@ namespace AZ return haystack; } - // this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes. + //! this one is inspired by the docopt utilities. trims whitespace by default, but can be used to trim quotes. constexpr inline string_view Trim(string_view haystack, string_view toTrim = " \t\n") { const auto strEnd = haystack.find_last_not_of(toTrim); @@ -268,34 +268,56 @@ namespace AZ return std::find_if(begin, end, p) != end; } - /// argument in rangeV3-style version: + //! argument in rangeV3-style version: template< typename Container > string Join(const Container& c, string_view separator = "") { return Join(c.begin(), c.end(), separator); } - /// argument in rangeV3-style version: + //! argument in rangeV3-style version: template< typename Container, typename Predicate > bool Contains(const Container& c, Predicate p) { return Contains(c.begin(), c.end(), p); } - /// closest possible form of python's `in` keyword + //! closest possible form of python's `in` keyword template< typename Element, typename Container > bool IsIn(const Element& element, const Container& container) { return std::find(container.begin(), container.end(), element) != container.end(); } - /// generate a new container with copy-and-mutated elements + //! generate a new container with copy-and-mutated elements template< typename Container, typename ContainerOut, typename Functor > void TransformCopy(const Container& in, ContainerOut& out, Functor mutator) { std::transform(in.begin(), in.end(), std::back_inserter(out), mutator); } + enum class CopyIfPolicy { ForAll, InterruptAtFirstFalse }; + + //! inserts elements into the output iterator if they pass a predicate + template< typename InputIterator, typename Predicate, typename OutputIterator > + void CopyIf(InputIterator begin, InputIterator end, + Predicate pred, + OutputIterator out, + CopyIfPolicy policy) + { + for (auto it = begin; it != end; ++it) + { + if (pred(*it)) + { + *out = *it; + } + else if (policy == CopyIfPolicy::InterruptAtFirstFalse) + { + break; + } + } + } + inline string Unescape(string_view escapedText) { std::stringstream out; @@ -379,14 +401,14 @@ namespace AZ // Is-One-Of will check if a variable is equal to any of the values listed on the other parameters. // Example: IsOneOf(variableKind, Function, Enumeration) is short for: variableKind == Function || variableKind == Enumeration. // 2 arguments count: recursion terminal overload. - template - bool IsOneOf(T value, T tocheck) + template + bool IsOneOf(T value, U tocheck) { return value == tocheck; } // Any argument count version - template - bool IsOneOf(T value, T khead, Args... tail) + template + bool IsOneOf(T value, U khead, Args... tail) { return value == khead || IsOneOf(value, tail...); } @@ -531,7 +553,7 @@ namespace AZ #endif } - /// Log(N) find immediate lower element query in map-keys + //! Log(N) query to find the first immediately lower or equal element in a map's keys template< typename T, typename U > auto Infimum(map const& ctr, T query) { @@ -539,35 +561,114 @@ namespace AZ return it == ctr.begin() ? ctr.cend() : --it; } - /// Log(N) disjointed segments belong query - /// You can represent segments as you wish, as long as: - /// you provide the predicate to determine belonging. - /// map-keys are segment start points. - /// segments don't overlap. - /// returns: iterator to found interval key, or cend() + //! Log(N) disjointed segments belong query + //! You can represent segments as you wish, as long as: + //! you provide the predicate to determine belonging. + //! map-keys are segment start points. + //! segments don't overlap. + //! returns: iterator to found interval key, or cend() template< typename T, typename U, typename IntervalCheckPredicate > - auto FindInterval(const map& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate) + auto FindIntervalInDisjointSet(const map& ctr, const T& query, IntervalCheckPredicate&& isInIntervalPredicate) { auto inf = Infimum(ctr, query); - return inf == ctr.end() ? ctr.cend() : (isInIntervalPredicate(query, inf->second) ? inf : ctr.cend()); + bool isInInterval = inf != ctr.cend() && isInIntervalPredicate(query, inf->second); + return isInInterval ? inf : ctr.cend(); } - /// Log(N) disjointed segments belong query - /// segments are represented by their start points in the key, and last point in values. segments can't overlap. - /// returns: iterator to found interval key, or cend() + //! Log(N) disjointed segments belong query + //! segments are represented by their start points in the key, and last point in values. segments can't overlap. + //! returns: iterator to found interval key, or cend() template< typename T, typename U> - auto FindInterval(const map& ctr, const T& query) + auto FindIntervalInDisjointSet(const map& ctr, const T& query) { - return FindInterval(ctr, query, [](T q, U last) {return q <= last; }); + return FindIntervalInDisjointSet(ctr, query, [](T q, U last) {return q <= last; }); } + template< typename T > + struct Interval + { + bool IsEmpty() const { return b < a; } + bool operator== (Interval const& rhs) const { return a == rhs.a && b == rhs.b; } + bool operator< (Interval const& rhs) const { return a < rhs.a || (a == rhs.a && b < rhs.b); } + T a = (T)0; + T b = (T)-1; + }; + + //! In case of potential overlaps (not disjointed), this structure can support "is in" queries + template< typename T > + struct IntervalCollection + { + using IntervalT = Interval; + + void Add(IntervalT i) + { + m_obfirsts.emplace_back(i); + } + + //! doesn't respect RAII but for the sake of performance and convenience this is easier this way + void Seal() + { + m_oblasts = m_obfirsts; + std::sort(m_obfirsts.begin(), m_obfirsts.end(), [](auto i1, auto i2) + { + return i1.a < i2.a; + }); + std::sort(m_oblasts.begin(), m_oblasts.end(), [](auto i1, auto i2) + { + return i1.b < i2.b; + }); + m_sealed = true; + } + + //! Retrieve the subset of intervals activated by a point (query) + set GetIntervalsSurrounding(T query) const + { + assert(m_sealed); + // find the "set" of intervals starting before: + auto firstsSubEnd = std::lower_bound(m_obfirsts.begin(), m_obfirsts.end(), + query, + [=](auto interv, T q) { return interv.a <= q; }); + + // find the "set" of intervals ending after: + static vector endAfter; + endAfter.clear(); + CopyIf(m_oblasts.rbegin(), m_oblasts.rend(), // reverse iteration + [=](auto interv) { return interv.b >= query; }, + std::back_inserter(endAfter), + CopyIfPolicy::InterruptAtFirstFalse); + // for set_intersection to work, the less<> predicate has to work for both ranges + std::sort(endAfter.begin(), endAfter.end()); + + set result; + std::set_intersection(m_obfirsts.begin(), firstsSubEnd, + endAfter.begin(), endAfter.end(), + std::inserter(result, result.end())); + return result; + } + + //! Get the interval surrounding query that has the closest start point to query. + //! In case of an interval collection representing a tree, that is, + //! each overlapping interval is fully contained in the bigger one, + //! the closest start is guaranteed to be the most "leaf" interval. + //! This is useful for scopes. + IntervalT GetClosestIntervalSurrounding(T query) const + { + auto bag = GetIntervalsSurrounding(query); + return bag.empty() ? IntervalT{-1, -2} : *bag.rbegin(); + } + + vector m_obfirsts; // ordered by "firsts" + vector m_oblasts; // ordered by "lasts" + bool m_sealed = false; + }; + template< typename Deduced > decltype(auto) CastToRValueReference(Deduced&& value) { return static_cast&&>(value); } - /// add a missing operator for convenience and shortness of code + //! add a missing operator for convenience and shortness of code inline bool operator == (string_view lhs, char rhs) { return lhs.length() == 1 && lhs[0] == rhs; @@ -624,6 +725,15 @@ namespace AZ RemoveDuplicatesKeepOrder(lhs); } + //! Conditional swap algorithm + template + void SwapIf(T&& a, T&& b, bool condition) + { + if (condition) + { + std::swap(std::forward(a), std::forward(b)); + } + } } #ifndef NDEBUG @@ -704,10 +814,10 @@ namespace AZ::Tests assert(yellow == intervals.cend()); auto larger_than_all = Infimum(intervals, 15); assert(larger_than_all->first == 8); - assert(FindInterval(intervals, 4) != intervals.cend()); - assert(FindInterval(intervals, 6) == intervals.cend()); - assert(FindInterval(intervals, 8) != intervals.cend()); - assert(FindInterval(intervals, 1) == intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 4) != intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 6) == intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 8) != intervals.cend()); + assert(FindIntervalInDisjointSet(intervals, 1) == intervals.cend()); auto high = Infimum(intervals, 20); assert(high->first == 8); @@ -725,6 +835,23 @@ namespace AZ::Tests assert(IsIn("hibou", std::initializer_list{ "chouette", "hibou", "jay" })); assert(!IsIn("hibou", std::initializer_list{ "chouette", "jay" })); + + Interval intvs[] = {{0,10}, {1,5}, {3,3}, {7,9}, {12,15}}; + IntervalCollection ic; + std::for_each(std::begin(intvs), std::end(intvs), [&](auto i) {ic.Add(i); }); + ic.Seal(); + assert(ic.GetClosestIntervalSurrounding(-3).IsEmpty()); + assert((ic.GetClosestIntervalSurrounding(0) == Interval{0,10})); + assert((ic.GetClosestIntervalSurrounding(1) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(3) == Interval{3,3})); + assert((ic.GetClosestIntervalSurrounding(4) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(6) == Interval{0,10})); + assert((ic.GetClosestIntervalSurrounding(5) == Interval{1,5})); + assert((ic.GetClosestIntervalSurrounding(7) == Interval{7,9})); + assert((ic.GetClosestIntervalSurrounding(9) == Interval{7,9})); + assert(ic.GetClosestIntervalSurrounding(11).IsEmpty()); + assert((ic.GetClosestIntervalSurrounding(13) == Interval{12,15})); + assert(ic.GetClosestIntervalSurrounding(16).IsEmpty()); } } #endif diff --git a/src/MetaUtils.h b/src/MetaUtils.h index f946f65..3a7a69b 100644 --- a/src/MetaUtils.h +++ b/src/MetaUtils.h @@ -366,6 +366,10 @@ namespace AZ template class Op, class... Args> using DetectedOr_t = typename DetectedOr::type; + + //! define a Pair typealias that has two same T, without the need to repeat yourself + template + using Pair = std::pair; } #ifndef NDEBUG diff --git a/src/PadToAttributeMutator.cpp b/src/PadToAttributeMutator.cpp index 1697803..4aad68a 100644 --- a/src/PadToAttributeMutator.cpp +++ b/src/PadToAttributeMutator.cpp @@ -354,7 +354,7 @@ namespace AZ::ShaderCompiler else if (varInfo.GetTypeClass() == TypeClass::Enum) { auto* asClassInfo = m_ir.GetSymbolSubAs(varInfo.GetTypeId().GetName()); - size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.GetBaseSize(); + size = asClassInfo->Get()->m_underlyingType.m_arithmeticInfo.m_baseSize; } offset = Packing::PackNextChunk(layoutPacking, size, startAt); diff --git a/tests/Advanced/mae-methodcall.azsl b/tests/Advanced/mae-methodcall.azsl new file mode 100644 index 0000000..e11e09a --- /dev/null +++ b/tests/Advanced/mae-methodcall.azsl @@ -0,0 +1,27 @@ +ShaderResourceGroupSemantic slot0 +{ + FrequencyId = 1; + ShaderVariantFallback = 128; +}; +ShaderResourceGroup srg0 : slot0{} + +class C +{ + void f(double) { f(0) + f(0) * f(0) / f(0); f(0) - f(0) % f(0); } // cost 7*7+2 + void f(int) { ;;;;;;; } // cost 7 +}; + +option bool o; + +float4 main() +{ + if (o) // unnamed block $bk0 + { + C c; + { // unnamed block $bk1 to verify lookup capability to find `/main/$bk0/c` from `/main/$bk0/$bk1/` + // understand that `c`'s type is `/C`, and use /C scope to lookup the f() method. + // also deep expression on LHS of MAE to give no break to the typeof system + (c).f(2 * 5.0l); // double promotion in binary expression that resolves to double overload method call + } + } +} \ No newline at end of file diff --git a/tests/Advanced/mae-methodcall.py b/tests/Advanced/mae-methodcall.py new file mode 100644 index 0000000..4ee002f --- /dev/null +++ b/tests/Advanced/mae-methodcall.py @@ -0,0 +1,46 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +""" +Copyright (c) Contributors to the Open 3D Engine Project. +For complete copyright and license terms please see the LICENSE at the root of this distribution. + +SPDX-License-Identifier: Apache-2.0 OR MIT +""" +import sys +import os +sys.path.append("..") +sys.path.append("../..") +from clr import * +import testfuncs + + +def verifyOptionCosts(thefile, compilerPath, silent): + j, ok = testfuncs.buildAndGetJson(thefile, compilerPath, silent, ["--options"]) + if ok: + predicates = [] + # check all references of func() + predicates.append(lambda: j["ShaderOptions"][0]["name"] == "o") + predicates.append(lambda: j["ShaderOptions"][0]["costImpact"] == 54) + + if not silent: print (fg.CYAN+ style.BRIGHT+ "option expected cost check..."+ style.RESET_ALL) + ok = testfuncs.verifyAllPredicates(predicates, j) + return ok + +result = 0 # to define for sub-tests +resultFailed = 0 + +def doTests(compiler, silent, azdxcpath): + global result + global resultFailed + + # Working directory should have been set to this script's directory by the calling parent + # You can get it once doTests() is called, but not during initialization of the module, + # because at that time it will still be set to the working directory of the calling script + workDir = os.getcwd() + + if verifyOptionCosts(os.path.join(workDir, "mae-methodcall.azsl"), compiler, silent): result += 1 + else: resultFailed += 1 + + +if __name__ == "__main__": + print ("please call from testapp.py") diff --git a/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl b/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl index f9fc391..ea7770c 100644 --- a/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl +++ b/tests/Semantic/AsError/overload-resolution-impossible-and-heteroreturn.azsl @@ -2,12 +2,13 @@ struct A {}; struct B {}; A make(int); -B make(uint); +B make(float); void main() { float x = 0.5; - A a = make((int)floor(x) + 1); // #EC 41 + A a = make(floor(x)); // #EC 41 + // make((int)floor(x)); // help azslc knowing about unregistered functions by casting to force the type. } /*Semantic Error 41: line 10::14 '(10): unable to match arguments () to a registered overload. candidates are: /make(?int) diff --git a/tests/Semantic/typeof-keyword.azsl b/tests/Semantic/typeof-keyword.azsl index 47f34ac..bfdbd74 100644 --- a/tests/Semantic/typeof-keyword.azsl +++ b/tests/Semantic/typeof-keyword.azsl @@ -19,7 +19,7 @@ top gettop(); class A { - int a; + int a; }; class B : A @@ -83,16 +83,14 @@ void h() // NumericConstructorExpression float2(0,0) // LiteralExpression 42 // CommaExpression X, Y - // not supported: // PostfixUnaryExpression i++ // PrefixUnaryExpression ++i // BinaryExpression i + j - // e.g. typeof(1 + 3) = - // mathematics + // mathematics __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(1 + 3), __azslc_prtsym_least_qualified); - __azslc_print_message(" == ''\n"); + __azslc_print_message(" == 'int'\n"); // literals __azslc_print_message("@check predicate "); @@ -186,7 +184,7 @@ void h() __azslc_print_symbol(typeof(top::inner), __azslc_prtsym_fully_qualified); __azslc_print_message(" == '/top/inner'\n"); - // class inheritance: parent member access using scope resolution operator + // class inheritance: parent member access using scope resolution operator __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(B::a), __azslc_prtsym_least_qualified); __azslc_print_message(" == 'int'\n"); @@ -445,4 +443,110 @@ void h() __azslc_print_message("@check predicate "); __azslc_print_symbol(typeof(INTVAR, INTVAR, DOUBLEVAR), __azslc_prtsym_least_qualified); __azslc_print_message(" == 'double'\n"); + + // binary arithmetic expression + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'int'\n"); + + // with float promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1.f), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float'\n"); + + // with half promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1 + 1.h), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'half'\n"); + + // half and float to float promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.f + 1.h), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float'\n"); + + // int16_t to double promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.l + int16_t(1)), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double'\n"); + + // bool binary + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(1.l && int16_t(1)), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'bool'\n"); + + // lookedup + double d; + int64_t i64; + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d || i64), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'bool'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d - i64), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double'\n"); + + // vector scalar + float4 f4; + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(2.f * f4), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(f4 * 2.f), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(d * f4), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double4'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(f4 * d), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double4'\n"); + + // matrix scalar + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float() * float3x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float3x2'\n"); + + // matrix scalar with base type promotion + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(half3x2() * double()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x2'\n"); + + // truncations + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float3() * float2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float2'\n"); + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float4x4() * float2x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'float2x2'\n"); + + // truncate & upcast + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(float4x4() * double2x2()), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double2x2'\n"); + + // through alias + typealias d34m = double3x4; + typealias real = half; + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof((d34m)0 * (real)0), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x4'\n"); + // note: the parser takes d32m() as a function call, this is the "most verxing parse" problem + // float() is understood by the parser as a NumericConstructorExpression because it + // has a list of the tokens representing all fundamental types. + // but, with user defined identifiers, it can't branch into the "intended" context, + // because ALL(*) Antlr4 parsers recognizes context free grammar only. + // We could adopt a universal construction syntax Obj{} a la C++ for AZSL but it's not + // compliant with the philosophy of not deviating from HLSL, which on its side doesn't really + // accept constructor-type constructs. DXC just does it better here because clang parser is Turing complete. + + d34m d34var; + real rVar; + + __azslc_print_message("@check predicate "); + __azslc_print_symbol(typeof(rVar - d34var ), __azslc_prtsym_least_qualified); + __azslc_print_message(" == 'double3x4'\n"); }