Skip to content

Commit

Permalink
Merge pull request #84 from o3de/auto-option-ranks
Browse files Browse the repository at this point in the history
Auto option ranks
  • Loading branch information
siliconvoodoo authored Apr 22, 2023
2 parents 8e399e0 + f7d55d4 commit 84e910f
Show file tree
Hide file tree
Showing 20 changed files with 829 additions and 141 deletions.
1 change: 1 addition & 0 deletions src/AzslcBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ namespace AZ::ShaderCompiler
// We reserve the right to change it in the future so we make it explicit attribute here
shaderOption["order"] = optionOrder;
optionOrder++;
shaderOption["costImpact"] = varInfo->m_estimatedCostImpact;

bool isUdt = IsUserDefined(varInfo->GetTypeClass());
assert(isUdt || IsPredefinedType(varInfo->GetTypeClass()));
Expand Down
1 change: 0 additions & 1 deletion src/AzslcEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,6 @@ namespace AZ::ShaderCompiler
const ICodeEmissionMutator* codeMutator = m_codeMutator;

ssize_t ii = interval.a;
bool wasInPreprocessorDirective = false; // record a state to detect exit of directives, because they need to reside on their own lines
while (ii <= interval.b)
{
auto* token = GetNextToken(ii /*inout*/);
Expand Down
4 changes: 1 addition & 3 deletions src/AzslcIntermediateRepresentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ namespace AZ::ShaderCompiler
cout << " storage: " << sub.m_typeInfoExt.m_qualifiers.GetDisplayName() << "\n";
cout << " array dim: \"" << sub.m_typeInfoExt.m_arrayDims.ToString() << "\"\n";
cout << " has sampler state: " << (sub.m_samplerState ? "yes\n" : "no\n");
cout << "\n";
if (!holds_alternative<monostate>(sub.m_constVal))
{
cout << " val: " << ExtractValueAsInt64(sub.m_constVal) << "\n";
Expand Down Expand Up @@ -519,7 +518,7 @@ namespace AZ::ShaderCompiler
if (varInfo.GetTypeClass() == TypeClass::Enum)
{
auto* asClassInfo = GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
}

nextMemberStartingOffset = Packing::PackNextChunk(layoutPacking, size, startAt);
Expand Down Expand Up @@ -960,5 +959,4 @@ namespace AZ::ShaderCompiler
}
return memberList[memberList.size() - 1];
}

} // end of namespace AZ::ShaderCompiler
1 change: 0 additions & 1 deletion src/AzslcIntermediateRepresentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

namespace AZ::ShaderCompiler
{

//! We limit the maximum number of render targets to 8, with indices in the range [0..7]
static const uint32_t kMaxRenderTargets = 8;

Expand Down
4 changes: 3 additions & 1 deletion src/AzslcKindInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ namespace AZ::ShaderCompiler
//! Get the size of a single element, ignoring array dimensions
const uint32_t GetSingleElementSize(Packing::Layout layout, bool defaultRowMajor) const
{
auto baseSize = m_coreType.m_arithmeticInfo.GetBaseSize();
auto baseSize = m_coreType.m_arithmeticInfo.m_baseSize;
bool isRowMajor = (m_mtxMajor == Packing::MatrixMajor::RowMajor ||
(m_mtxMajor == Packing::MatrixMajor::Default && defaultRowMajor));
auto rows = m_coreType.m_arithmeticInfo.m_rows;
Expand Down Expand Up @@ -399,6 +399,7 @@ namespace AZ::ShaderCompiler
ConstNumericVal m_constVal; // (attempted folded) initializer value for simple scalars
optional<SamplerStateDesc> m_samplerState;
ExtendedTypeInfo m_typeInfoExt;
int m_estimatedCostImpact = -1; //!< Cached value calculated by AnalyzeOptionRanks
};

// VarInfo methods definitions
Expand Down Expand Up @@ -791,6 +792,7 @@ namespace AZ::ShaderCompiler
vector< IdentifierUID > m_overrides; //!< list of implementing functions in child classes
optional< IdentifierUID > m_base; //!< points to the overridden function in the base interface, if applies. only supports one base
FunctionMultiForwards m_multiFwds = FMF_None; //!< presence of redundant prototype-only declarations
int m_costScore = -1; //!< heuristical static analysis of the amount of instructions contained
struct Parameter
{
IdentifierUID m_varId;
Expand Down
3 changes: 1 addition & 2 deletions src/AzslcMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ namespace StdFs = std::filesystem;
// For large features or milestones. Minor version allows for breaking changes. Existing tests can change.
#define AZSLC_MINOR "8" // last change: introduction of class inheritance
// For small features or bug fixes. They cannot introduce breaking changes. Existing tests shouldn't change.
#define AZSLC_REVISION "17" // last change: fixup alignment check logic_error because of lack of an inter-scope check limiter.
// "16" change: fixup runtime error with redundant function declarations
#define AZSLC_REVISION "18" // last change: automatic option ranks

namespace AZ::ShaderCompiler
{
Expand Down
242 changes: 223 additions & 19 deletions src/AzslcReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ namespace AZ::ShaderCompiler
else if (varInfo.GetTypeClass() == TypeClass::Enum)
{
auto* asClassInfo = m_ir->GetSymbolSubAs<ClassInfo>(varInfo.GetTypeId().GetName());
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.GetBaseSize();
size = asClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.m_baseSize;
}

offset = Packing::PackNextChunk(layoutPacking, size, startAt);
Expand Down Expand Up @@ -629,7 +629,9 @@ namespace AZ::ShaderCompiler

// Writes the variant list for `options` to m_out, followed by a newline.
// AnalyzeOptionRanks() runs first, i.e. before GetVariantList() builds the data that gets printed.
void CodeReflection::DumpVariantList(const Options& options) const
{
    AnalyzeOptionRanks();
    m_out << GetVariantList(options) << "\n";
}

static void ReflectBinding(Json::Value& output, const RootSigDesc::SrgParamDesc& bindInfo)
Expand Down Expand Up @@ -857,11 +859,12 @@ namespace AZ::ShaderCompiler
for (auto& seenat : kindInfo->GetSeenats())
{
assert(uid == seenat.m_referredDefinition);
// TODO: the assumption that intervals where distinct doesnt hold anymore now that we have unnamed scopes
auto intervalIter = FindInterval(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
{
return value.first.properlyContains({key, key});
});
// careful of the invariant: distinct intervals. (can't support functions nested in functions nor imbricated block scopes)
// ok for now because AZSL/HLSL don't have lambdas
auto intervalIter = FindIntervalInDisjointSet(scopes, seenat.m_where.m_focusedTokenId, [](ssize_t key, auto& value)
{
return value.first.properlyContains({key, key});
});
if (intervalIter != scopes.cend())
{
const IdentifierUID& encloser = intervalIter->second.second;
Expand Down Expand Up @@ -909,16 +912,9 @@ namespace AZ::ShaderCompiler
uint32_t numOf32bitConst = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID);
RootSigDesc rootSignature = BuildSignatureDescription(options, numOf32bitConst);

// prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
MapOfBeginToSpanAndUid scopeStartToFunctionIntervals;
for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
{
if (m_ir->GetKind(uid) == Kind::Function) // Filter out unnamed blocs and types. We need a set of disjoint intervals as an invariant for the next algorithm.
{
// the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
scopeStartToFunctionIntervals[interval.a] = std::make_pair(interval, uid);
}
}
// Prepare a lookup acceleration data structure for reverse mapping tokens to scopes.
// (truth: we need a set of disjoint intervals as an invariant for the following algorithm)
GenerateTokenScopeIntervalToUidReverseMap();

Json::Value srgRoot(Json::objectValue);
// Order the reflection by SRG for convenience
Expand Down Expand Up @@ -968,7 +964,7 @@ namespace AZ::ShaderCompiler
else
{
set<IdentifierUID> dependencyList;
DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(srgParam.m_uid, dependencyList, m_functionIntervals);
srgMember[srgParam.m_uid.GetNameLeaf()] = makeJsonNodeForOneResource(dependencyList, srgParam, {});
}
}
Expand All @@ -981,7 +977,7 @@ namespace AZ::ShaderCompiler
for (auto& srgConstant : srgInfo->m_implicitStruct.GetMemberFields())
{
allConstants.append({ srgConstant.GetNameLeaf() });
DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(srgConstant, dependencyList, m_functionIntervals);
}
// variant fallback support
if (srgInfo->m_shaderVariantFallback)
Expand All @@ -992,7 +988,7 @@ namespace AZ::ShaderCompiler
{
if (varSub->CheckHasStorageFlag(StorageFlag::Option))
{
DiscoverTopLevelFunctionDependencies(varUid, dependencyList, scopeStartToFunctionIntervals);
DiscoverTopLevelFunctionDependencies(varUid, dependencyList, m_functionIntervals);
}
}
}
Expand All @@ -1004,4 +1000,212 @@ namespace AZ::ShaderCompiler

m_out << srgRoot;
}

// Helper routine for option rank analysis.
// Maps the name of a function that is assumed to be an intrinsic to a coarse, unitless cost.
// The three tiers (100 / 10 / 1) are heuristics, not measurements.
static int GuesstimateIntrinsicFunctionCost(string_view funcName)
{
    // non measurable but assumed high
    const bool isAssumedVeryExpensive = IsOneOf(funcName, "CallShader", "TraceRay");
    if (isAssumedVeryExpensive)
    {
        return 100;
    }

    // memory access, locked or not, will have high latency
    if (IsOneOf(funcName, "Sample", "Load", "InterlockedCompareStore", "InterlockedCompareExchange", "InterlockedExchange", "Append"))
    {
        return 10;
    }

    // unlisted intrinsics like lerp, log2, cos, distance.. will default to a cost of 1.
    return 1;
}

// Helper routine for option rank analysis, for when picking AN overload is more useful than forfeiting.
// GetConcreteFunctionThatMatchesArgumentList forfeits when the overload set contains strictly more
// than 1 concrete function with the queried arity. For cost estimation we prefer to just pick any.
//
// Returns the UID of an overload whose parameter count equals the number of arguments of `callNode`,
// or a default-constructed IdentifierUID when no overload of the set has that arity.
// NOTE(review): `overload` is expected to hold an OverloadSetInfo; the result of GetSubAs is not checked here.
static IdentifierUID PickAnyOverloadThatMatchesArgCount(IntermediateRepresentation* ir,
                                                        azslParser::FunctionCallExpressionContext* callNode,
                                                        KindInfo& overload)
{
    IdentifierUID picked;  // stays default-constructed if nothing matches
    const size_t wantedArity = NumArgs(callNode);

    // Predicate handed to AnyOf; the selection is written out through the captured reference.
    auto acceptIfSameArity = [&picked, ir, wantedArity](IdentifierUID const& candidate)
    {
        auto* candidateInfo = ir->GetSymbolSubAs<FunctionInfo>(candidate.GetName());
        const size_t paramCount = candidateInfo->GetParameters(true).size();
        const bool sameArity = (paramCount == wantedArity);
        if (sameArity)
        {
            picked = candidate;
        }
        return sameArity;
    };

    overload.GetSubAs<OverloadSetInfo>()->AnyOf(acceptIfSameArity);
    return picked;
}

void CodeReflection::AnalyzeOptionRanks() const
{
// make sure we have the scope lookup cache ready
GenerateTokenScopeIntervalToUidReverseMap();
// loop over variables
for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
{
// only options
if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
{
int impactScore = 0;
// loop over appearances over the program
for (Seenat& ref : kindInfo->GetSeenats())
{
// determine an impact score
impactScore += AnalyzeImpact(ref.m_where) // dependent code that may be skipped depending on the value of that ref
+ 1; // by virtue of being mentioned (seenat), we count the reference as an access of cost 1.
}
varInfo->m_estimatedCostImpact = impactScore;
}
}
}

// Scores the code that depends on the reference found at `location`.
// Starting from the AST node at the focused token, the parents are searched for an enclosing
// while / if / for / switch statement, tested in that order. When one is found, its embedded
// statement (or switch block) is scored; otherwise the node at `location` itself is scored.
// The upward search is bounded: depth 3 for while/if/switch, and depth 4 for `for`, which is
// enough for things like `for (a, b<(ref+1), c)`: binaryop->braces->cmpexpr->cond->for
int CodeReflection::AnalyzeImpact(TokensLocation const& location) const
{
    // runs the tree-walking overload on one subtree and returns the resulting score
    auto scoreSubtree = [this](ParserRuleContext* root)
    {
        int subtreeScore = 0;
        AnalyzeImpact(root, subtreeScore);
        return subtreeScore;
    };

    ParserRuleContext* referenceNode = m_ir->m_tokenMap.GetNode(location.m_focusedTokenId);

    if (auto* whileStatement = DeepParentAs<azslParser::WhileStatementContext*>(referenceNode->parent, 3))
    {
        return scoreSubtree(whileStatement->embeddedStatement());
    }
    if (auto* ifStatement = DeepParentAs<azslParser::IfStatementContext*>(referenceNode->parent, 3))
    {
        return scoreSubtree(ifStatement->embeddedStatement());
    }
    if (auto* forStatement = DeepParentAs<azslParser::ForStatementContext*>(referenceNode->parent, 4))
    {
        return scoreSubtree(forStatement->embeddedStatement());
    }
    if (auto* switchStatement = DeepParentAs<azslParser::SwitchStatementContext*>(referenceNode->parent, 3))
    {
        return scoreSubtree(switchStatement->switchBlock());
    }
    // no enclosing control statement within reach: score the reference node itself
    return scoreSubtree(referenceNode);
}

// Walks the children of `astNode` and adds their estimated cost to `scoreAccumulator`:
//  - a function call expression is delegated to the overload specialized for function lookup,
//  - any other rule context is traversed recursively (to capture embedded calls, like e.g. "x ? f() : 0;"),
//  - a terminal that is a semicolon adds 1 (cost is counted as the number of full expressions).
void CodeReflection::AnalyzeImpact(ParserRuleContext* astNode, int& scoreAccumulator) const
{
    for (auto& child : astNode->children)
    {
        auto* asCall = As<azslParser::FunctionCallExpressionContext*>(child);
        if (asCall)
        {
            AnalyzeImpact(asCall, scoreAccumulator);
        }
        else if (auto* asRule = As<ParserRuleContext*>(child))
        {
            AnalyzeImpact(asRule, scoreAccumulator);
        }

        auto* asLeaf = As<tree::TerminalNode*>(child);
        if (asLeaf && asLeaf->getSymbol()->getType() == azslLexer::Semi)
        {
            ++scoreAccumulator;
        }
    }
}

// Overload of AnalyzeImpact specialized for function call expressions.
// Resolves the callee of `callNode` and adds its estimated cost to `scoreAccumulator`:
//  - a callee that symbol lookup does not find is assumed to be an intrinsic and costs
//    GuesstimateIntrinsicFunctionCost(name),
//  - a callee that resolves to a function symbol costs that function's m_costScore, which is
//    computed from its body on first use and cached in the FunctionInfo,
//  - a call for which no enclosing scope interval is found contributes nothing.
// Side effect: overwrites the scope tracker's current scope path (see comment in the body).
void CodeReflection::AnalyzeImpact(azslParser::FunctionCallExpressionContext* callNode, int& scoreAccumulator) const
{
    // to access the function symbol info we need the current scope, the function call name and perform a lookup.

    // figure out the scope at this token.
    // theoretically should be something in the like of the body of another function,
    // or an anonymous block within another function.
    auto interval = m_intervals.GetClosestIntervalSurrounding(callNode->start->getTokenIndex());
    if (interval.IsEmpty())
    {
        // function calls outside of function bodies can appear in an initializer:
        //   int g_a = MakeA();  // global init
        //   class C { int m_a = CompA();  // constructor init (invalid AZSL/HLSL)
        //   class D { void Method(int a_a = DefaultA());  // default parameter value
        // in any case, extracting the scope is impossible with this system.
        // we forfeit evaluation of a score
        return;
    }

    IdentifierUID encloser = m_intervalToUid[interval];

    // Because we are past the end of the semantic analysis,
    // the scope tracker is registering the last seen scope (surely "/").
    // This is a stateful side-effect system unfortunately, and since we'll call
    // some feature of the semantic orchestrator (like TypeofExpr) we need to hack
    // the scope tracker:
    m_ir->m_sema.m_scope->m_currentScopePath = encloser.GetName();
    m_ir->m_sema.m_scope->UpdateCurScopeUID();

    QualifiedName startupLookupScope = encloser.GetName();
    UnqualifiedName funcName;
    if (auto* idExpr = As<azslParser::IdentifierExpressionContext*>(callNode->Expr))
    {
        // plain call `f(...)`: look the name up starting from the enclosing scope
        funcName = ExtractNameFromIdExpression(idExpr->idExpression());
    }
    else if (auto* maeExpr = As<AstMemberAccess*>(callNode->Expr))
    {
        // member call `lhs.f(...)`: look the name up starting from the type of the left hand side
        startupLookupScope = m_ir->m_sema.TypeofExpr(maeExpr->LHSExpr);
        funcName = ExtractNameFromIdExpression(maeExpr->Member);
    }
    // other callee shapes (that would at least include things like eg braces (f)()) get no
    // special treatment: funcName stays default-constructed for them.

    IdAndKind* overload = m_ir->m_symbols.LookupSymbol(startupLookupScope, funcName);
    if (!overload)  // in case of function not found, we assume it's an intrinsic.
    {
        scoreAccumulator += GuesstimateIntrinsicFunctionCost(funcName);
        return;
    }

    azslParser::ArgumentListContext* args = GetArgumentListIfBelongsToFunctionCall(callNode);
    IdAndKind* symbolMeantUnderCallNode = m_ir->m_sema.ResolveOverload(overload, args);
    IdentifierUID concrete;
    if (!symbolMeantUnderCallNode || m_ir->GetKind(symbolMeantUnderCallNode->first) == Kind::OverloadSet)
    {   // in case of strict selection failure, run a fuzzy select
        concrete = PickAnyOverloadThatMatchesArgCount(m_ir, callNode, overload->second);
        // if still not enough to get a fix (concrete=={}), it might be an ill-formed input. prefer to forfeit
    }
    else
    {
        concrete = symbolMeantUnderCallNode->first;
    }

    if (auto* funcInfo = m_ir->GetSymbolSubAs<FunctionInfo>(concrete.GetName()))
    {
        if (funcInfo->m_costScore == -1)  // cost not yet discovered for this function
        {
            // Mark as discovered before descending into the body, so that a call to the same
            // function met during the descent does not start the discovery again.
            funcInfo->m_costScore = 0;
            // Guard against functions without a definition node; presumably the case of
            // prototype-only declarations, which have no body to score. Their cost stays 0.
            if (funcInfo->m_defNode)
            {
                using AstFDef = azslParser::HlslFunctionDefinitionContext;
                AnalyzeImpact(polymorphic_downcast<AstFDef*>(funcInfo->m_defNode->parent)->block(),
                              funcInfo->m_costScore);  // recurse and cache
            }
        }
        scoreAccumulator += funcInfo->m_costScore;
    }
}

// Builds, once, the lookup caches used to map a token index back to its enclosing scope:
//  - m_functionIntervals: start token -> (interval, uid), for scopes of kind Function only,
//  - m_intervals / m_intervalToUid: every scope interval and the UID of the scope it belongs to.
// Later calls return immediately once the caches hold data.
void CodeReflection::GenerateTokenScopeIntervalToUidReverseMap() const
{
    // Checking m_functionIntervals alone is not enough: it stays empty when the program has no
    // function scope, in which case every call would add the intervals to the already sealed
    // m_intervals again. m_intervalToUid receives an entry for every scope interval, so it tells
    // whether the generation already ran.
    // NOTE(review): relies on m_intervalToUid exposing a std-like empty() - confirm against its declaration.
    const bool alreadyGenerated = !m_functionIntervals.empty() || !m_intervalToUid.empty();
    if (alreadyGenerated)
    {
        return;
    }

    for (auto& [uid, interval] : m_ir->m_scope.m_scopeIntervals)
    {
        if (m_ir->GetKind(uid) == Kind::Function)  // Filter out unnamed blocks and types.
        {
            // the reason to choose .a as the key is so we can query using Infimum (sort of lower_bound)
            m_functionIntervals[interval.a] = std::make_pair(interval, uid);
        }
        // every scope, whatever its kind, is registered in the general purpose collection
        auto i = Interval<ssize_t>{interval.a, interval.b};
        m_intervals.Add(i);
        m_intervalToUid[i] = uid;
    }
    m_intervals.Seal();
}
}
Loading

0 comments on commit 84e910f

Please sign in to comment.