Skip to content

Commit

Permalink
Merge MTDirectArguments and MTArgumentBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
andrejnau committed Jan 9, 2024
1 parent 9dff16e commit 703b904
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 230 deletions.
5 changes: 1 addition & 4 deletions src/FlyCube/BindingSet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@ endif()

if (METAL_SUPPORT)
list(APPEND headers
MTArgumentBuffer.h
MTBindingSet.h
MTDirectArguments.mm
)
list(APPEND sources
MTArgumentBuffer.mm
MTDirectArguments.mm
MTBindingSet.mm
)
endif()

Expand Down
29 changes: 0 additions & 29 deletions src/FlyCube/BindingSet/MTArgumentBuffer.h

This file was deleted.

129 changes: 0 additions & 129 deletions src/FlyCube/BindingSet/MTArgumentBuffer.mm

This file was deleted.

20 changes: 18 additions & 2 deletions src/FlyCube/BindingSet/MTBindingSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,26 @@

#import <Metal/Metal.h>

class MTDevice;
class MTBindingSetLayout;
class Pipeline;

class MTBindingSet : public BindingSet {
public:
virtual void Apply(id<MTLRenderCommandEncoder> render_encoder, const std::shared_ptr<Pipeline>& state) = 0;
virtual void Apply(id<MTLComputeCommandEncoder> compute_encoder, const std::shared_ptr<Pipeline>& state) = 0;
MTBindingSet(MTDevice& device, const std::shared_ptr<MTBindingSetLayout>& layout);

void WriteBindings(const std::vector<BindingDesc>& bindings) override;

void Apply(id<MTLRenderCommandEncoder> render_encoder, const std::shared_ptr<Pipeline>& state);
void Apply(id<MTLComputeCommandEncoder> compute_encoder, const std::shared_ptr<Pipeline>& state);

private:
MTDevice& m_device;
std::shared_ptr<MTBindingSetLayout> m_layout;
std::map<std::pair<ShaderType, uint32_t>, id<MTLBuffer>> m_argument_buffers;
std::map<std::pair<ShaderType, uint32_t>, uint32_t> m_slots_count;
std::map<MTLResourceUsage, std::vector<id<MTLResource>>> m_compure_resouces;
std::map<std::pair<MTLRenderStages, MTLResourceUsage>, std::vector<id<MTLResource>>> m_graphics_resouces;
std::vector<BindKey> m_direct_bind_keys;
std::vector<BindingDesc> m_direct_bindings;
};
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
#include "BindingSet/MTDirectArguments.h"
#include "BindingSet/MTBindingSet.h"

#include "BindingSetLayout/MTBindingSetLayout.h"
#include "Device/MTDevice.h"
#include "HLSLCompiler/MSLConverter.h"
#include "Pipeline/MTPipeline.h"
#include "Shader/MTShader.h"
#include "View/MTView.h"

namespace {

MTLRenderStages GetStage(ShaderType type)
{
switch (type) {
case ShaderType::kPixel:
return MTLRenderStageFragment;
case ShaderType::kVertex:
return MTLRenderStageVertex;
default:
assert(false);
return 0;
}
}

template <typename CommandEncoderType>
constexpr bool is_compute_encoder()
{
Expand Down Expand Up @@ -100,24 +114,23 @@ void SetView(ShaderType shader_type, CommandEncoderType encoder, const std::shar
}
}

} // namespace

MTDirectArguments::MTDirectArguments(MTDevice& device, const std::shared_ptr<MTBindingSetLayout>& layout)
: m_device(device)
, m_layout(layout)
void ValidateRemappedSlots(const std::shared_ptr<Pipeline>& state, const std::vector<BindKey>& bind_keys)
{
}

void MTDirectArguments::WriteBindings(const std::vector<BindingDesc>& bindings)
{
m_bindings = bindings;
#ifndef NDEBUG
decltype(auto) program = state->As<MTPipeline>().GetProgram();
for (const auto& bind_key : bind_keys) {
decltype(auto) shader = program->GetShader(bind_key.shader_type);
uint32_t index = bind_key.GetRemappedSlot();
assert(index == shader->As<MTShader>().GetIndex(bind_key));
}
#endif
}

template <typename CommandEncoderType>
void MTDirectArguments::ApplyDirectArgs(CommandEncoderType encoder,
const std::vector<BindKey>& bind_keys,
const std::vector<BindingDesc>& bindings,
MTDevice& device)
void ApplyDirectArguments(CommandEncoderType encoder,
const std::vector<BindKey>& bind_keys,
const std::vector<BindingDesc>& bindings,
MTDevice& device)
{
for (const auto& binding : bindings) {
const BindKey& bind_key = binding.bind_key;
Expand Down Expand Up @@ -157,27 +170,98 @@ void SetView(ShaderType shader_type, CommandEncoderType encoder, const std::shar
}
}

void MTDirectArguments::ValidateRemappedSlots(const std::shared_ptr<Pipeline>& state,
const std::vector<BindKey>& bind_keys)
} // namespace

MTBindingSet::MTBindingSet(MTDevice& device, const std::shared_ptr<MTBindingSetLayout>& layout)
: m_device(device)
, m_layout(layout)
{
#ifndef NDEBUG
decltype(auto) program = state->As<MTPipeline>().GetProgram();
for (const auto& bind_key : bind_keys) {
decltype(auto) shader = program->GetShader(bind_key.shader_type);
if (!UseArgumentBuffers()) {
return;
}

const std::vector<BindKey>& bind_keys = m_layout->GetBindKeys();
for (BindKey bind_key : bind_keys) {
if (bind_key.space >= spirv_cross::kMaxArgumentBuffers || bind_key.count == ~0) {
m_direct_bind_keys.push_back(bind_key);
continue;
}
auto shader_space = std::make_pair(bind_key.shader_type, bind_key.space);
m_slots_count[shader_space] =
std::max(m_slots_count[shader_space], bind_key.GetRemappedSlot() + bind_key.count);
}
for (const auto& [shader_space, slots] : m_slots_count) {
m_argument_buffers[shader_space] = [m_device.GetDevice() newBufferWithLength:slots * sizeof(uint64_t)
options:MTLResourceStorageModeShared];
}
}

void MTBindingSet::WriteBindings(const std::vector<BindingDesc>& bindings)
{
if (!UseArgumentBuffers()) {
m_direct_bindings = bindings;
return;
}

m_direct_bindings.clear();
m_compure_resouces.clear();
m_graphics_resouces.clear();

for (const auto& binding : bindings) {
decltype(auto) bind_key = binding.bind_key;
if (bind_key.space >= spirv_cross::kMaxArgumentBuffers) {
if (bind_key.count != ~0) {
m_direct_bindings.push_back(binding);
}
continue;
}
decltype(auto) view = std::static_pointer_cast<MTView>(binding.view);
assert(view->GetViewDesc().view_type == bind_key.view_type);

uint32_t index = bind_key.GetRemappedSlot();
assert(index == shader->As<MTShader>().GetIndex(bind_key));
uint32_t slots = m_slots_count[{ bind_key.shader_type, bind_key.space }];
assert(index < slots);
uint64_t* arguments =
static_cast<uint64_t*>(m_argument_buffers[{ bind_key.shader_type, bind_key.space }].contents);
arguments[index] = view->GetGpuAddress();

id<MTLResource> resource = view->GetNativeResource();
if (!resource) {
continue;
}

MTLResourceUsage usage = view->GetUsage();
if (bind_key.shader_type == ShaderType::kCompute) {
m_compure_resouces[usage].push_back(resource);
} else {
m_graphics_resouces[{ GetStage(bind_key.shader_type), usage }].push_back(resource);
}
}
#endif
}

void MTDirectArguments::Apply(id<MTLRenderCommandEncoder> render_encoder, const std::shared_ptr<Pipeline>& state)
void MTBindingSet::Apply(id<MTLRenderCommandEncoder> render_encoder, const std::shared_ptr<Pipeline>& state)
{
ValidateRemappedSlots(state, m_layout->GetBindKeys());
ApplyDirectArgs(render_encoder, m_layout->GetBindKeys(), m_bindings, m_device);
for (const auto& [key, slots] : m_slots_count) {
SetBuffer(key.first, render_encoder, m_argument_buffers[key], 0, key.second);
}
for (const auto& [stages_usage, resources] : m_graphics_resouces) {
[render_encoder useResources:resources.data()
count:resources.size()
usage:stages_usage.second
stages:stages_usage.first];
}
ApplyDirectArguments(render_encoder, m_direct_bind_keys, m_direct_bindings, m_device);
}

void MTDirectArguments::Apply(id<MTLComputeCommandEncoder> compute_encoder, const std::shared_ptr<Pipeline>& state)
void MTBindingSet::Apply(id<MTLComputeCommandEncoder> compute_encoder, const std::shared_ptr<Pipeline>& state)
{
ValidateRemappedSlots(state, m_layout->GetBindKeys());
ApplyDirectArgs(compute_encoder, m_layout->GetBindKeys(), m_bindings, m_device);
for (const auto& [key, slots] : m_slots_count) {
SetBuffer(key.first, compute_encoder, m_argument_buffers[key], 0, key.second);
}
for (const auto& [usage, resources] : m_compure_resouces) {
[compute_encoder useResources:resources.data() count:resources.size() usage:usage];
}
ApplyDirectArguments(compute_encoder, m_direct_bind_keys, m_direct_bindings, m_device);
}
Loading

0 comments on commit 703b904

Please sign in to comment.