Program Listing for File opencl_runtime_wrapper.h

Return to documentation for file (include/opencl_runtime_wrapper.h)

#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H

#include <vector>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <type_traits>
#include "CL/cl2.hpp"
#include "include/api/allocator.h"
#include "include/api/status.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore::registry::opencl {
/// \brief Wrapper exposing OpenCL runtime operations (program loading, kernel building, argument binding,
/// kernel launch, buffer/image transfer and device queries) to code that includes this header.
///
/// Only the std::string forwarding overloads and the non-pointer SetKernelArg template are defined in this
/// header; every other member is a declaration whose implementation lives elsewhere.
class OpenCLRuntimeWrapper {
 public:
  OpenCLRuntimeWrapper() = default;
  ~OpenCLRuntimeWrapper() = default;

  /// \brief Register an OpenCL program source under a name.
  ///
  /// Converts both strings with StringToChar and forwards to the private std::vector<char> overload.
  ///
  /// \param[in] program_name Name identifying the program.
  /// \param[in] source Program source text.
  ///
  /// \return Status returned by the private overload.
  inline Status LoadSource(const std::string &program_name, const std::string &source);

  /// \brief Build a kernel from a named program.
  ///
  /// Converts the string arguments to their char-vector forms and forwards to the private overload.
  ///
  /// \param[out] kernel Kernel object handed to the private overload (presumably filled on success).
  /// \param[in] program_name Name of the program containing the kernel.
  /// \param[in] kernel_name Name of the kernel to build.
  /// \param[in] build_options_ext Extra build options; defaults to an empty list.
  ///
  /// \return Status returned by the private overload.
  inline Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
                            const std::vector<std::string> &build_options_ext = {});

  /// \brief Set a kernel argument from a pointer value (declaration only; defined outside this header).
  ///
  /// \param[in] kernel Kernel whose argument is set.
  /// \param[in] index Argument index.
  /// \param[in] value Pointer argument; how it is interpreted is decided by the implementation.
  ///
  /// \return Status of the operation.
  Status SetKernelArg(const cl::Kernel &kernel, uint32_t index, void *const value);

  /// \brief Set a kernel argument from a non-pointer value.
  ///
  /// Participates in overload resolution only when T is not a pointer type, so pointer arguments select the
  /// void * overload instead.
  ///
  /// \param[in] kernel Kernel whose argument is set.
  /// \param[in] index Argument index.
  /// \param[in] value Value passed to cl::Kernel::setArg.
  ///
  /// \return kSuccess when setArg reports CL_SUCCESS, kLiteError otherwise.
  template <typename T>
  typename std::enable_if<!std::is_pointer<T>::value, Status>::type SetKernelArg(const cl::Kernel &kernel,
                                                                                 uint32_t index, const T value) {
    // setArg is a non-const member, so constness is cast away from the caller's kernel reference.
    auto &mutable_kernel = const_cast<cl::Kernel &>(kernel);
    if (mutable_kernel.setArg(index, value) == CL_SUCCESS) {
      return kSuccess;
    }
    return kLiteError;
  }

  /// \brief Launch a kernel over the given ranges (declaration only).
  ///
  /// \param[in] kernel Kernel to run.
  /// \param[in] global Global work size.
  /// \param[in] local Local work size.
  /// \param[in] command_queue Optional command queue; defaults to nullptr (presumably selecting a default
  /// queue -- confirm against the implementation).
  /// \param[in] event Optional event object; defaults to nullptr.
  ///
  /// \return Status of the operation.
  Status RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
                   cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);

  /// \brief Synchronize the command queue (declaration only; exact semantics are in the implementation).
  ///
  /// \return Status of the operation.
  Status SyncCommandQueue();

  /// \brief Map a buffer identified by a host pointer (declaration only).
  ///
  /// \param[in] host_ptr Pointer identifying the buffer.
  /// \param[in] flags Mapping flags (presumably OpenCL map flags -- confirm against the implementation).
  /// \param[in] sync Defaults to true; presumably selects a blocking map.
  ///
  /// \return Pointer produced by the mapping.
  void *MapBuffer(void *host_ptr, int flags, bool sync = true);

  /// \brief Unmap a buffer identified by a host pointer (declaration only).
  ///
  /// \param[in] host_ptr Pointer identifying the buffer.
  ///
  /// \return Status of the operation.
  Status UnmapBuffer(void *host_ptr);

  /// \brief Read image data into dst_data (declaration only).
  ///
  /// \param[in] buffer Source image handle.
  /// \param[out] dst_data Destination memory.
  ///
  /// \return Status of the operation.
  Status ReadImage(void *buffer, void *dst_data);

  /// \brief Write src_data into an image (declaration only).
  ///
  /// \param[in] buffer Destination image handle.
  /// \param[in] src_data Source memory.
  ///
  /// \return Status of the operation.
  Status WriteImage(void *buffer, void *src_data);

  /// \brief Obtain the allocator associated with the runtime (declaration only).
  ///
  /// \return Shared pointer to an Allocator.
  std::shared_ptr<Allocator> GetAllocator();

  /// \brief Query the device's maximum work-group size (declaration only).
  uint64_t DeviceMaxWorkGroupSize();

  /// \brief Query the maximum 2D image width (declaration only).
  uint64_t GetMaxImage2DWidth();

  /// \brief Query the maximum 2D image height (declaration only).
  uint64_t GetMaxImage2DHeight();

  /// \brief Query the image pitch alignment (declaration only).
  uint64_t GetImagePitchAlignment();

 private:
  // std::vector<char> counterparts of the public std::string overloads; the inline public overloads
  // forward to these after converting their string arguments.
  Status LoadSource(const std::vector<char> &program_name, const std::vector<char> &source);

  Status BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name, const std::vector<char> &kernel_name,
                     const std::vector<std::vector<char>> &build_options_ext);
};

// Public std::string overload: converts both arguments with StringToChar and delegates to the private
// std::vector<char> overload, returning its Status unchanged.
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
  const auto name_chars = StringToChar(program_name);
  const auto source_chars = StringToChar(source);
  return LoadSource(name_chars, source_chars);
}

// Public std::string overload: converts the names with StringToChar and the option list with
// VectorStringToChar, then delegates to the private char-vector overload and returns its Status unchanged.
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
                                         const std::string &kernel_name,
                                         const std::vector<std::string> &build_options_ext) {
  const auto program_chars = StringToChar(program_name);
  const auto kernel_chars = StringToChar(kernel_name);
  const auto option_chars = VectorStringToChar(build_options_ext);
  return BuildKernel(kernel, program_chars, kernel_chars, option_chars);
}
}  // namespace mindspore::registry::opencl
#endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H