Program Listing for File lstm.h

Return to documentation for file (include/converter/include/ops/lstm.h)

#ifndef MINDSPORE_CORE_OPS_LSTM_H_
#define MINDSPORE_CORE_OPS_LSTM_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>

#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
/// @brief Operator name string; the LSTM constructor passes it to BaseOperator.
constexpr auto kNameLSTM = "LSTM";
/// @brief Operator class for "LSTM": derives from BaseOperator and declares an
/// Init method plus setter/getter pairs for the LSTM configuration values.
///
/// Only the default constructor is defined inline; every other member is
/// declared here without a body.
/// NOTE(review): presumably each set_*/get_* pair writes/reads a same-named
/// operator attribute — confirm against the implementation file.
class MIND_API LSTM : public BaseOperator {
 public:
  // Project macro from the included headers.
  // NOTE(review): presumably declares the common operator boilerplate members — confirm against its definition.
  MIND_API_BASE_MEMBER(LSTM);
  /// @brief Default constructor; constructs the BaseOperator base with kNameLSTM.
  LSTM() : BaseOperator(kNameLSTM) {}
  /// @brief Initializes the operator from the LSTM configuration values.
  ///
  /// NOTE(review): presumably forwards each argument to the matching set_*
  /// member — confirm in the implementation. num_directions has a setter below
  /// but is not a parameter here.
  ///
  /// @param input_size Value for input_size.
  /// @param hidden_size Value for hidden_size.
  /// @param num_layers Value for num_layers.
  /// @param has_bias Value for has_bias.
  /// @param dropout Value for dropout.
  /// @param bidirectional Value for bidirectional; defaults to false.
  /// @param zoneout_cell Value for zoneout_cell; defaults to 0.0f.
  /// @param zoneout_hidden Value for zoneout_hidden; defaults to 0.0f.
  void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
            const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
            const float zoneout_hidden = 0.0f);
  /// @brief Setter for input_size.
  void set_input_size(const int64_t input_size);
  /// @brief Getter for input_size.
  int64_t get_input_size() const;
  /// @brief Setter for hidden_size.
  void set_hidden_size(const int64_t hidden_size);
  /// @brief Getter for hidden_size.
  int64_t get_hidden_size() const;
  /// @brief Setter for num_layers.
  void set_num_layers(const int64_t num_layers);
  /// @brief Getter for num_layers.
  int64_t get_num_layers() const;
  /// @brief Setter for has_bias.
  void set_has_bias(const bool has_bias);
  /// @brief Getter for has_bias.
  bool get_has_bias() const;
  /// @brief Setter for dropout.
  void set_dropout(const float dropout);
  /// @brief Getter for dropout.
  float get_dropout() const;
  /// @brief Setter for bidirectional.
  void set_bidirectional(const bool bidirectional);
  /// @brief Getter for bidirectional.
  bool get_bidirectional() const;
  /// @brief Setter for num_directions.
  /// NOTE(review): presumably derived from bidirectional (1 or 2) — confirm against the implementation.
  void set_num_directions(const int64_t num_directions);
  /// @brief Getter for num_directions.
  int64_t get_num_directions() const;
  /// @brief Setter for zoneout_cell.
  void set_zoneout_cell(float zoneout_cell);
  /// @brief Getter for zoneout_cell.
  float get_zoneout_cell() const;
  /// @brief Setter for zoneout_hidden.
  void set_zoneout_hidden(float zoneout_hidden);
  /// @brief Getter for zoneout_hidden.
  float get_zoneout_hidden() const;
  /// @brief Returns an int64_t computed from dim and type_size.
  ///
  /// NOTE(review): the name suggests a "good" leading dimension (e.g. padded
  /// for alignment) — confirm against the definition. Unlike the other get_*
  /// members, this one is not declared const.
  ///
  /// @param dim Dimension value; exact meaning not visible in this header.
  /// @param type_size Presumably the element size in bytes — TODO confirm.
  int64_t get_good_ld(const int64_t dim, const int64_t type_size);
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_LSTM_H_