Program Listing for File attention.h
↰ Return to documentation for file (include/converter/include/ops/attention.h)
#ifndef MINDSPORE_CORE_OPS_ATTENTION_H_
#define MINDSPORE_CORE_OPS_ATTENTION_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
// Name string handed to the BaseOperator constructor by Attention.
constexpr const char *kNameAttention = "Attention";
/// \brief Declaration of the Attention operator, built on BaseOperator.
///
/// Only declarations are visible in this header; Init and the set_*/get_*
/// accessors are defined elsewhere.
class MIND_API Attention : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Attention);

  /// \brief Constructs the operator under the name kNameAttention and passes
  /// twelve input names and one output name to InitIOName.
  Attention() : BaseOperator(kNameAttention) {
    InitIOName({"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o",
                "mask"},
               {"output"});
  }

  /// \brief Initializes the operator from its five parameters.
  ///
  /// NOTE(review): presumably forwards each argument to the matching set_*
  /// method below; confirm against the implementation file.
  ///
  /// \param[in] head_num Value for the head_num setting.
  /// \param[in] head_size Value for the head_size setting.
  /// \param[in] position_bias Value for the position_bias setting.
  /// \param[in] cross Value for the cross setting; defaults to false.
  /// \param[in] scale Value for the scale setting; defaults to 1.0f.
  void Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross = false, float scale = 1.0f);

  // Each setter is listed next to its getter. The get_* methods are const;
  // where the values are stored is not visible in this header.

  // head_num
  void set_head_num(int64_t head_num);
  int64_t get_head_num() const;

  // head_size
  void set_head_size(int64_t head_size);
  int64_t get_head_size() const;

  // cross
  void set_cross(bool cross);
  bool get_cross() const;

  // position_bias
  void set_position_bias(bool position_bias);
  bool get_position_bias() const;

  // scale
  void set_scale(float scale);
  float get_scale() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ATTENTION_H_