Program Listing for File attention.h

Return to documentation for file (include/converter/include/ops/attention.h)

#ifndef MINDSPORE_CORE_OPS_ATTENTION_H_
#define MINDSPORE_CORE_OPS_ATTENTION_H_
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
/// Operator name under which the Attention primitive is constructed.
constexpr const char *kNameAttention = "Attention";

/// @brief Operator definition for the "Attention" primitive.
///
/// Declares the operator's input/output names and the accessors for its attributes:
/// head_num, head_size, position_bias, cross and scale.
/// The accessor bodies are defined out of line and are not visible in this header.
class MIND_API Attention : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Attention);

  /// Constructs the operator as kNameAttention and registers its I/O names.
  /// Inputs, in order: q, k, v, weight_q/k/v/o, bias_q/k/v/o, mask.
  /// Single output: output.
  Attention() : BaseOperator(kNameAttention) {
    InitIOName({"q", "k", "v",                                  // primary inputs
                "weight_q", "weight_k", "weight_v", "weight_o",  // weight inputs
                "bias_q", "bias_k", "bias_v", "bias_o",          // bias inputs
                "mask"},
               {"output"});
  }

  /// @brief Sets every attribute in one call.
  ///
  /// Defined out of line. It presumably forwards to the matching set_* methods; confirm against the implementation.
  /// @param head_num      value for the head_num attribute (presumably the number of attention heads; confirm).
  /// @param head_size     value for the head_size attribute (presumably the per-head size; confirm).
  /// @param position_bias value for the position_bias flag.
  /// @param cross         value for the cross flag; defaults to false.
  /// @param scale         value for the scale attribute; defaults to 1.0f.
  void Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross = false, float scale = 1.0f);

  /// @name head_num attribute
  /// @{
  void set_head_num(int64_t head_num);
  int64_t get_head_num() const;
  /// @}

  /// @name head_size attribute
  /// @{
  void set_head_size(int64_t head_size);
  int64_t get_head_size() const;
  /// @}

  /// @name cross attribute
  /// @{
  void set_cross(bool cross);
  bool get_cross() const;
  /// @}

  /// @name position_bias attribute
  /// @{
  void set_position_bias(bool position_bias);
  bool get_position_bias() const;
  /// @}

  /// @name scale attribute
  /// @{
  void set_scale(float scale);
  float get_scale() const;
  /// @}
};
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_ATTENTION_H_