Lluvia
ComputeNode.h
Go to the documentation of this file.
1 
8 #ifndef LLUVIA_CORE_NODE_COMPUTE_NODE_H_
9 #define LLUVIA_CORE_NODE_COMPUTE_NODE_H_
10 
12 #include "lluvia/core/node/Node.h"
13 
14 #include <cstdint>
15 #include <map>
16 #include <memory>
17 #include <string>
18 
20 
21 namespace ll {
22 
// Forward declarations. The declarations visible in this header refer to these
// types only through std::shared_ptr, std::weak_ptr or references, so the names
// alone are sufficient here.
23 namespace vulkan {
24  class Device; // passed to and held by ComputeNode through std::shared_ptr
25 } // namespace vulkan
26 
27 class Buffer; // objects to manage raw portions of allocated memory
28 class CommandBuffer; // class for command buffer; target of ComputeNode::record()
29 class Image; // NOTE(review): not referenced by the declarations visible in this header
30 class ImageView; // represents portions of a ll::Image to be sent as parameter to a GLSL shader
31 class Interpreter; // referenced by ComputeNode through std::weak_ptr
32 class Object; // base class for all types that can be used in compute shaders
33 class Program; // class representing Vulkan shader modules in SPIR-V representation
34 
/**
 * Class representing compute nodes.
 *
 * Derives from ll::Node and from std::enable_shared_from_this, so an instance
 * owned by a std::shared_ptr can hand out shared pointers to itself.
 * Instances cannot be default-constructed, copied or moved.
 */
38 class ComputeNode : public Node, public std::enable_shared_from_this<ll::ComputeNode> {
39 
40 public:
    // Default construction, copy and move are all disabled; the matching
    // assignment operators are deleted further down in this class.
41  ComputeNode() = delete;
42  ComputeNode(const ComputeNode& node) = delete;
43  ComputeNode(ComputeNode&& node) = delete;
44 
    /**
     * Constructs the object.
     *
     * @param device      shared pointer to the ll::vulkan::Device.
     *                    NOTE(review): presumably kept in m_device — confirm in the .cpp.
     * @param descriptor  the ll::ComputeNodeDescriptor for this node.
     *                    NOTE(review): presumably copied into m_descriptor — confirm.
     * @param interpreter weak pointer to a ll::Interpreter.
     *                    NOTE(review): presumably kept in m_interpreter — confirm.
     */
62  ComputeNode(const std::shared_ptr<ll::vulkan::Device>& device,
63  const ll::ComputeNodeDescriptor& descriptor,
64  const std::weak_ptr<ll::Interpreter>& interpreter);
65 
    /** Virtual destructor; the class is used polymorphically through ll::Node. */
66  virtual ~ComputeNode();
67 
    // Copy and move assignment are disabled, consistent with the deleted constructors.
68  ComputeNode& operator=(const ComputeNode& node) = delete;
69  ComputeNode& operator=(ComputeNode&& node) = delete;
70 
    /** Gets the node type. */
71  ll::NodeType getType() const noexcept override;
72 
    /** Gets the function name within the shader module this node runs. */
78  std::string getFunctionName() const noexcept;
79 
    /** Gets the program object associated to this node. */
85  std::shared_ptr<ll::Program> getProgram() const noexcept;
86 
    /** Gets the descriptor (returned by const reference). */
92  const ll::ComputeNodeDescriptor& getDescriptor() const noexcept;
93 
    /** Gets the local group size in X dimension. */
101  uint32_t getLocalX() const noexcept;
102 
    /** Gets the local group size in Y dimension. */
110  uint32_t getLocalY() const noexcept;
111 
    /** Gets the local group size in Z dimension. */
119  uint32_t getLocalZ() const noexcept;
120 
    /** Gets the local group shape. */
126  ll::vec3ui getLocalShape() const noexcept;
127 
    /** Gets the grid size in X dimension. */
135  uint32_t getGridX() const noexcept;
136 
    /** Sets the grid size in the X axis. */
151  void setGridX(const uint32_t x) noexcept;
152 
    /** Gets the grid size in Y dimension. */
160  uint32_t getGridY() const noexcept;
161 
    /** Sets the grid size in the Y axis. */
176  void setGridY(const uint32_t y) noexcept;
177 
    /** Gets the grid size in Z dimension. */
185  uint32_t getGridZ() const noexcept;
186 
    /** Sets the grid size in the Z axis. */
201  void setGridZ(const uint32_t z) noexcept;
202 
    /** Sets the grid shape. */
215  void setGridShape(const ll::vec3ui& shape) noexcept;
216 
    /**
     * Configures the grid shape given a global shape.
     * NOTE(review): how the grid is derived from globalShape is not visible in
     * this header — check the implementation before relying on rounding behavior.
     */
222  void configureGridShape(const ll::vec3ui& globalShape) noexcept;
223 
    /** Gets the grid shape. */
231  ll::vec3ui getGridShape() const noexcept;
232 
    /** Returns whether or not a port exists with a given name. */
233  bool hasPort(const std::string& name) const noexcept override;
234 
    /**
     * Gets the port with the given name, as a shared pointer to ll::Object.
     * Not noexcept, unlike hasPort(). NOTE(review): behavior for an unknown
     * name is not visible here — confirm.
     */
235  std::shared_ptr<ll::Object> getPort(const std::string& name) const override;
236 
    /**
     * Sets the push constants.
     * NOTE(review): no description is available for this member; presumably
     * these are the constants used when the node is recorded — confirm.
     */
237  void setPushConstants(const ll::PushConstants& constants) noexcept;
238 
    /** Gets the push constants (returned by const reference). */
239  const ll::PushConstants& getPushConstants() const noexcept;
240 
    /** Binds a ll::Object to this node under the given port name. */
241  void bind(const std::string& name, const std::shared_ptr<ll::Object>& obj) override;
242 
    /** Records the operations required to run this node in a ll::CommandBuffer. */
243  void record(ll::CommandBuffer& commandBuffer) const override;
244 
    /** Sets a parameter. */
245  void setParameter(const std::string& name, const ll::Parameter& value) override;
246 
    /** Gets a parameter (returned by const reference). */
247  const ll::Parameter& getParameter(const std::string& name) const override;
248 
249 protected:
    /**
     * Initialization hook overriding the base-class onInit().
     * NOTE(review): when the base class invokes it is not visible here — confirm.
     */
250  void onInit() override;
251 
252 private:
    // Internal setup helpers. NOTE(review): presumably called during
    // construction/initialization — confirm call sites in the .cpp.
253  void initPortBindings();
254  void initPipeline();
255 
    // Type-specific binding helpers for ll::Buffer and ll::ImageView objects.
    // NOTE(review): presumably dispatched to from bind() — confirm.
256  void bindBuffer(const ll::PortDescriptor& port, const std::shared_ptr<ll::Buffer>& buffer);
257  void bindImageView(const ll::PortDescriptor& port, const std::shared_ptr<ll::ImageView>& imageView);
258 
    // Descriptor bookkeeping helpers. NOTE(review): presumably used to size
    // m_descriptorPool from the node's bindings — confirm.
259  std::vector<vk::DescriptorPoolSize> getDescriptorPoolSizes() const noexcept;
260  uint32_t countDescriptorType(const vk::DescriptorType type) const noexcept;
261 
    // Device passed in at construction (shared ownership).
262  std::shared_ptr<ll::vulkan::Device> m_device;
263 
    // Vulkan objects (vk:: types) held by this node.
    // NOTE(review): creation/destruction responsibility is not visible in this
    // header — confirm in the constructor/destructor.
264  vk::DescriptorSetLayout m_descriptorSetLayout;
265 
266  vk::PipelineLayout m_pipelineLayout;
267  vk::Pipeline m_pipeline;
268 
269  vk::DescriptorSet m_descriptorSet;
270  vk::DescriptorPool m_descriptorPool;
271 
    // Descriptor of this node; exposed through getDescriptor().
272  ll::ComputeNodeDescriptor m_descriptor;
273 
    // Descriptor set layout bindings. NOTE(review): presumably one per port — confirm.
274  std::vector<vk::DescriptorSetLayoutBinding> m_parameterBindings;
275 
276  // specialization constants
277  // vk::SpecializationInfo specializationInfo;
278  // std::vector<vk::SpecializationMapEntry> specializationMapEntries;
279  // // uint32_t local_x {1};
280 
    // Objects stored by string key. NOTE(review): presumably keyed by port
    // name and filled by bind() — confirm.
281  std::map<std::string, std::shared_ptr<ll::Object>> m_objects;
282 
    // Non-owning reference to the interpreter passed at construction.
283  std::weak_ptr<ll::Interpreter> m_interpreter;
284 };
285 
286 } // namespace ll
287 
288 #endif /* LLUVIA_CORE_NODE_COMPUTE_NODE_H_ */
ComputeNodeDescriptor class.
Node class and related enums.
Objects to manage raw portions of allocated memory.
Definition: Buffer.h:57
Class for command buffer.
Definition: CommandBuffer.h:63
Class for describing a compute node.
Definition: ComputeNodeDescriptor.h:33
Class representing compute nodes.
Definition: ComputeNode.h:38
void setGridX(const uint32_t x) noexcept
Sets the grid size in the X axis.
ComputeNode(const ComputeNode &node)=delete
std::shared_ptr< ll::Program > getProgram() const noexcept
Gets the program object associated to this node.
void bind(const std::string &name, const std::shared_ptr< ll::Object > &obj) override
Binds a ll::Object to the port with the given name for this node.
uint32_t getGridZ() const noexcept
Gets the grid size in Z dimension.
std::string getFunctionName() const noexcept
Gets the function name within the shader module this node runs.
ll::NodeType getType() const noexcept override
Gets the node type.
const ll::PushConstants & getPushConstants() const noexcept
ll::vec3ui getLocalShape() const noexcept
Gets the local group shape.
ComputeNode()=delete
void record(ll::CommandBuffer &commandBuffer) const override
Records the operations required to run this node in a ll::CommandBuffer.
ComputeNode & operator=(const ComputeNode &node)=delete
uint32_t getGridY() const noexcept
Gets the grid size in Y dimension.
void setGridY(const uint32_t y) noexcept
Sets the grid size in the Y axis.
ComputeNode(ComputeNode &&node)=delete
std::shared_ptr< ll::Object > getPort(const std::string &name) const override
Gets the ll::Object associated with a port given its name.
void setGridZ(const uint32_t z) noexcept
Sets the grid size in the Z axis.
void setGridShape(const ll::vec3ui &shape) noexcept
Sets the grid shape.
void onInit() override
ComputeNode(const std::shared_ptr< ll::vulkan::Device > &device, const ll::ComputeNodeDescriptor &descriptor, const std::weak_ptr< ll::Interpreter > &interpreter)
Constructs the object.
bool hasPort(const std::string &name) const noexcept override
Returns whether or not a port exists with a given name.
uint32_t getLocalZ() const noexcept
Gets the local group size in Z dimension.
virtual ~ComputeNode()
void setPushConstants(const ll::PushConstants &constants) noexcept
ll::vec3ui getGridShape() const noexcept
Gets the grid shape.
uint32_t getGridX() const noexcept
Gets the grid size in X dimension.
uint32_t getLocalX() const noexcept
Gets the local group size in X dimension.
void setParameter(const std::string &name, const ll::Parameter &value) override
Sets a parameter.
const ll::Parameter & getParameter(const std::string &name) const override
Gets a parameter.
ComputeNode & operator=(ComputeNode &&node)=delete
const ll::ComputeNodeDescriptor & getDescriptor() const noexcept
Gets the descriptor.
uint32_t getLocalY() const noexcept
Gets the local group size in Y dimension.
void configureGridShape(const ll::vec3ui &globalShape) noexcept
Configures the grid shape given a global shape.
Represents portions of a ll::Image to be sent as parameter to a GLSL shader.
Definition: ImageView.h:149
Definition: Interpreter.h:34
Definition: Node.h:30
Base class for all types that can be used in compute shaders.
Definition: Object.h:94
Definition: Parameter.h:23
Definition: PortDescriptor.h:27
Class representing Vulkan shader modules in SPIR-V representation.
Definition: Program.h:37
Definition: PushConstants.h:20
Definition: Buffer.h:28
NodeType
Class for node type.
Definition: NodeType.h:19