Program Listing for File detection_post_process.h

Return to documentation for file (include/converter/include/ops/detection_post_process.h)

#ifndef MINDSPORE_CORE_OPS_DETECTION_POST_PROCESS_H_
#define MINDSPORE_CORE_OPS_DETECTION_POST_PROCESS_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindapi/base/format.h"
#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
/// Operator name string that the DetectionPostProcess constructor passes to BaseOperator.
constexpr auto kNameDetectionPostProcess = "DetectionPostProcess";
/// \brief Operator definition class for DetectionPostProcess.
///
/// Exposes one Init entry point plus a set_*/get_* pair for each attribute. Only the declarations
/// live in this header; the member definitions are provided elsewhere.
class MIND_API DetectionPostProcess : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(DetectionPostProcess);
  /// \brief Default constructor; passes kNameDetectionPostProcess to the BaseOperator constructor.
  DetectionPostProcess() : BaseOperator(kNameDetectionPostProcess) {}
  /// \brief Initializes the operator from a complete set of attribute values.
  ///
  /// NOTE(review): presumably forwards each argument to the matching set_* method below; the
  /// definition is not in this header, so confirm against the implementation file.
  ///
  /// \param[in] inputSize Value for the input-size attribute.
  /// \param[in] scale Scale values; the original note in this header gives the layout as (h, w, x, y).
  /// \param[in] NmsIouThreshold Value for the NMS IoU threshold attribute.
  /// \param[in] NmsScoreThreshold Value for the NMS score threshold attribute.
  /// \param[in] MaxDetections Value for the max-detections attribute.
  /// \param[in] DetectionsPerClass Value for the detections-per-class attribute.
  /// \param[in] MaxClassesPerDetection Value for the max-classes-per-detection attribute.
  /// \param[in] NumClasses Value for the number-of-classes attribute.
  /// \param[in] UseRegularNms Value for the use-regular-NMS flag.
  /// \param[in] OutQuantized Value for the out-quantized flag.
  /// \param[in] format Data format; defaults to NCHW when omitted.
  void Init(const int64_t inputSize, const std::vector<float> &scale, const float NmsIouThreshold,
            const float NmsScoreThreshold, const int64_t MaxDetections, const int64_t DetectionsPerClass,
            const int64_t MaxClassesPerDetection, const int64_t NumClasses, const bool UseRegularNms,
            const bool OutQuantized, const Format &format = NCHW);
  // Attribute setters, one per Init parameter. Note on scale: its layout is (h, w, x, y).
  /// \brief Sets the input-size attribute.
  void set_input_size(const int64_t inputSize);
  /// \brief Sets the scale attribute; layout is (h, w, x, y) per the note above.
  void set_scale(const std::vector<float> &scale);
  /// \brief Sets the NMS IoU threshold attribute.
  void set_nms_iou_threshold(const float NmsIouThreshold);
  /// \brief Sets the NMS score threshold attribute.
  void set_nms_score_threshold(const float NmsScoreThreshold);
  /// \brief Sets the max-detections attribute.
  void set_max_detections(const int64_t MaxDetections);
  /// \brief Sets the detections-per-class attribute.
  void set_detections_per_class(const int64_t DetectionsPerClass);
  /// \brief Sets the max-classes-per-detection attribute.
  void set_max_classes_per_detection(const int64_t MaxClassesPerDetection);
  /// \brief Sets the number-of-classes attribute.
  void set_num_classes(const int64_t NumClasses);
  /// \brief Sets the use-regular-NMS flag.
  void set_use_regular_nms(const bool UseRegularNms);
  /// \brief Sets the out-quantized flag.
  void set_out_quantized(const bool OutQuantized);
  /// \brief Sets the format attribute.
  void set_format(const Format &format);

  // Attribute getters; each mirrors the set_* method of the same name and is const-qualified.
  /// \brief Returns the input-size attribute.
  int64_t get_input_size() const;
  /// \brief Returns the scale attribute by value.
  std::vector<float> get_scale() const;
  /// \brief Returns the NMS IoU threshold attribute.
  float get_nms_iou_threshold() const;
  /// \brief Returns the NMS score threshold attribute.
  float get_nms_score_threshold() const;
  /// \brief Returns the max-detections attribute.
  int64_t get_max_detections() const;
  /// \brief Returns the detections-per-class attribute.
  int64_t get_detections_per_class() const;
  /// \brief Returns the max-classes-per-detection attribute.
  int64_t get_max_classes_per_detection() const;
  /// \brief Returns the number-of-classes attribute.
  int64_t get_num_classes() const;

  /// \brief Returns the use-regular-NMS flag.
  bool get_use_regular_nms() const;
  /// \brief Returns the out-quantized flag.
  bool get_out_quantized() const;
  /// \brief Returns the format attribute.
  Format get_format() const;
};
/// \brief Declaration of the inference function associated with DetectionPostProcess.
///
/// NOTE(review): presumably derives the operator's output abstract from its inputs; the definition
/// is not visible in this header, so confirm against the implementation file.
///
/// The first parameter (an abstract::AnalysisEnginePtr) is left unnamed in this declaration.
/// \param[in] primitive Primitive pointer for the operator.
/// \param[in] input_args Abstract values of the operator's inputs.
/// \return An abstract::AbstractBasePtr.
MIND_API abstract::AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &,
                                                             const PrimitivePtr &primitive,
                                                             const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_DETECTION_POST_PROCESS_H_