Skip to content

Commit

Permalink
[SPIRV] Implement out variables in patch control functions
Browse files Browse the repository at this point in the history
The documentation for the patch control functions only mentions

> The outputs are usually defined by a structure and is
identified by HS_CONSTANT_DATA_OUTPUT in this example; the structure
depends on the domain type and would be different for triangle or
isoline domains.

This is why we did not realize that we needed to implement out variables
as well. This commit implement the alternate method of having an output.

Fixes microsoft#3743
  • Loading branch information
s-perron committed Jan 10, 2024
1 parent 79f55f2 commit 5af3703
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 24 deletions.
65 changes: 44 additions & 21 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13409,25 +13409,6 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
const QualType pcfRetType = patchConstFunc->getReturnType();

std::vector<SpirvInstruction *> pcfParams;

// A lambda for creating a stage input variable and its associated temporary
// variable for function call. Also initializes the temporary variable using
// the contents loaded from the stage input variable. Returns the <result-id>
// of the temporary variable.
const auto createParmVarAndInitFromStageInputVar =
[this](const ParmVarDecl *param) {
const QualType type = param->getType();
std::string tempVarName = "param.var." + param->getNameAsString();
auto paramLoc = param->getLocation();
auto *tempVar = spvBuilder.addFnVar(
type, paramLoc, tempVarName, param->hasAttr<HLSLPreciseAttr>(),
param->hasAttr<HLSLNoInterpolationAttr>());
SpirvInstruction *loadedValue = nullptr;
declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
spvBuilder.createStore(tempVar, loadedValue, paramLoc);
return tempVar;
};

for (const auto *param : patchConstFunc->parameters()) {
// Note: According to the HLSL reference, the PCF takes an InputPatch of
// ControlPoints as well as the PatchID (PrimitiveID). This does not
Expand All @@ -13439,14 +13420,19 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
pcfParams.push_back(hullMainOutputPatch);
} else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
if (!primitiveId) {
primitiveId = createParmVarAndInitFromStageInputVar(param);
primitiveId = createPCFParmVarAndInitFromStageInputVar(param);
}
pcfParams.push_back(primitiveId);
} else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID)) {
if (!viewId) {
viewId = createParmVarAndInitFromStageInputVar(param);
viewId = createPCFParmVarAndInitFromStageInputVar(param);
}
pcfParams.push_back(viewId);
} else if (param->hasAttr<HLSLOutAttr>()) {
// Create a temporary function scope variable to pass to the PCF function
// for the output. The value of this variable should be copied to an
// output variable for the param after the function call.
pcfParams.push_back(createFunctionScopeTempFromParameter(param));
} else {
emitError("patch constant function parameter '%0' unknown",
param->getLocation())
Expand All @@ -13459,6 +13445,18 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
/*forPCF*/ true))
return false;

// Traverse all of the parameters for the patch constant function and copy out
// all of the output variables.
for (uint32_t idx = 0; idx < patchConstFunc->parameters().size(); idx++) {
const auto *param = patchConstFunc->parameters()[idx];
if (param->hasAttr<HLSLOutAttr>()) {
SpirvInstruction *pcfParam = pcfParams[idx];
SpirvInstruction *loadedValue = spvBuilder.createLoad(
pcfParam->getAstResultType(), pcfParam, param->getLocation());
declIdMapper.createStageOutputVar(param, loadedValue, /*forPCF*/ true);
}
}

spvBuilder.createBranch(mergeBB, locEnd);
spvBuilder.addSuccessor(mergeBB);
spvBuilder.setInsertPoint(mergeBB);
Expand Down Expand Up @@ -14419,6 +14417,31 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() {
spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation());
}

SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar(
const ParmVarDecl *param) {
const QualType type = param->getType();
std::string tempVarName = "param.var." + param->getNameAsString();
auto paramLoc = param->getLocation();
auto *tempVar = spvBuilder.addFnVar(
type, paramLoc, tempVarName, param->hasAttr<HLSLPreciseAttr>(),
param->hasAttr<HLSLNoInterpolationAttr>());
SpirvInstruction *loadedValue = nullptr;
declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
spvBuilder.createStore(tempVar, loadedValue, paramLoc);
return tempVar;
}

SpirvVariable *
SpirvEmitter::createFunctionScopeTempFromParameter(const ParmVarDecl *param) {
const QualType type = param->getType();
std::string tempVarName = "param.var." + param->getNameAsString();
auto paramLoc = param->getLocation();
auto *tempVar = spvBuilder.addFnVar(
type, paramLoc, tempVarName, param->hasAttr<HLSLPreciseAttr>(),
param->hasAttr<HLSLNoInterpolationAttr>());
return tempVar;
}

bool SpirvEmitter::spirvToolsTrimCapabilities(std::vector<uint32_t> *mod,
std::string *messages) {
spvtools::Optimizer optimizer(featureManager.getTargetEnv());
Expand Down
13 changes: 10 additions & 3 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,6 @@ class SpirvEmitter : public ASTConsumer {
/// statement.
void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt);

private:
/// Handles the offset argument in the given method call at the given argument
/// index. Panics if the argument at the given index does not exist. Writes
/// the <result-id> to either *constOffset or *varOffset, depending on the
Expand Down Expand Up @@ -1157,7 +1156,6 @@ class SpirvEmitter : public ASTConsumer {
const CXXMethodDecl *memberFn,
SourceLocation loc);

private:
/// \brief Takes a vector of size 4, and returns a vector of size 1 or 2 or 3
/// or 4. Creates a CompositeExtract or VectorShuffle instruction to extract
/// a scalar or smaller vector from the beginning of the input vector if
Expand Down Expand Up @@ -1195,7 +1193,6 @@ class SpirvEmitter : public ASTConsumer {
/// execution mode, if it has not already been added.
void beginInvocationInterlock(SourceLocation loc, SourceRange range);

private:
/// \brief If the given FunctionDecl is not already in the workQueue, creates
/// a FunctionInfo object for it, and inserts it into the workQueue. It also
/// updates the functionInfoMap with the proper mapping.
Expand Down Expand Up @@ -1241,6 +1238,16 @@ class SpirvEmitter : public ASTConsumer {
/// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html.
void addDerivativeGroupExecutionMode();

/// Creates an input variable for `param` that will be used by the patch
/// constant function. The parameter is also added to the patch constant
/// function. The wrapper function will copy the input variable to the
/// parameter.
SpirvVariable *
createPCFParmVarAndInitFromStageInputVar(const ParmVarDecl *param);

/// Returns a function scope parameter with the same type as |param|.
SpirvVariable *createFunctionScopeTempFromParameter(const ParmVarDecl *param);

public:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
Expand Down
39 changes: 39 additions & 0 deletions tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s

struct ControlPoint { float4 position : POSITION; };

// CHECK: %param_var_edge = OpVariable %_ptr_Function__arr_float_uint_3 Function
// CHECK: %param_var_inside = OpVariable %_ptr_Function_float Function
// CHECK: %param_var_myFloat = OpVariable %_ptr_Function_float Function
// CHECK: OpFunctionCall %void %HullConst %param_var_edge %param_var_inside %param_var_myFloat
// CHECK: [[edges:%[0-9]+]] = OpLoad %_arr_float_uint_3 %param_var_edge
// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_0
// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 0
// CHECK: OpStore [[addr]] [[val]]
// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_1
// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 1
// CHECK: OpStore [[addr]] [[val]]
// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_2
// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 2
// CHECK: OpStore [[addr]] [[val]]
// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_inside
// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelInner %uint_0
// CHECK: OpStore [[addr]] [[val]]
// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_myFloat
// CHECK: OpStore %out_var_MY_FLOAT [[val]]

void HullConst (out float edge [3] : SV_TessFactor, out float inside : SV_InsideTessFactor, out float myFloat : MY_FLOAT)
{
edge[0] = 2;
edge[1] = 2;
edge[2] = 2;
inside = 2;
myFloat = .2;
}

[domain("tri")]
[partitioning("fractional_odd")]
[outputtopology("triangle_ccw")]
[patchconstantfunc("HullConst")]
[outputcontrolpoints(3)]
ControlPoint Hull (InputPatch<ControlPoint,3> v, uint id : SV_OutputControlPointID) { return v[id]; }

0 comments on commit 5af3703

Please sign in to comment.