Program Listing for File net.h

Return to documentation for file (include/runtime/include/api/net.h)

#ifndef MINDSPORE_INCLUDE_API_NET_H
#define MINDSPORE_INCLUDE_API_NET_H

#include <memory>
#include <vector>
#include <unordered_set>
#include <string>
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/api/cfg.h"

namespace mindspore {
// REG(member) registers `member` under its own source spelling: it expands to
// Register(member, "member"). It is meant to be written inside a Net member
// function so that it resolves to Net::Register. Note that, like every macro,
// it ignores the enclosing namespace.
#define REG(element) Register(element, #element)

// Types this header only names through pointers, references, smart pointers
// or friend declarations; they stay incomplete here and are defined elsewhere.
class Expr;
class Graph;
class NetData;
class NetImpl;
class NodeImpl;
class NodeSet;

class MS_API NetBase {
 public:
  NetBase() = default;
  virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
  virtual uint32_t type() = 0;
};

/// A basic network element backed by a NodeImpl (pimpl).
///
/// Nodes are the pieces a Net composes: see Net::Register(Node *, ...) and
/// the NN:: factory functions, which return Node objects.
class MS_API Node : public NetBase {
 public:
  Node();
  // NOTE(review): deleting a Node through a NetBase* is only well-defined if
  // NetBase declares a virtual destructor.
  virtual ~Node();

  // Creates an expression named `name`.
  // NOTE(review): presumably this is the node's output expression and the
  // Node (or its impl) owns the returned pointer -- confirm in NodeImpl.
  Expr *Create(std::string name);
  // Applies this node to `inputs` and returns the resulting expressions.
  std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
  // Element kind tag; `final` so subclasses cannot change it.
  uint32_t type() final;

 private:
  friend NodeImpl;
  // Implementation object; explicitly starts out null.
  std::shared_ptr<NodeImpl> impl_ = nullptr;
};

/// A composite network element built from registered sub-nets and nodes.
///
/// Derives from std::enable_shared_from_this<Net>.
/// NOTE(review): if the implementation calls shared_from_this(), instances
/// must be owned by a std::shared_ptr -- confirm which members rely on it.
class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
 public:
  Net();
  virtual ~Net();
  /// Constructs a net identified by `name`.
  explicit Net(std::string name);
  /// Constructs a net from the graph `g`.
  explicit Net(const Graph &g);

  /// Computation hook mapping `inputs` to outputs; virtual so user-defined
  /// subclasses can describe their forward computation here.
  virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);

  /// Applies the net to `inputs`. Marked `override` (like Node::operator())
  /// because it implements the pure virtual NetBase::operator(); the compiler
  /// now rejects any signature drift between the two.
  std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
  /// Registers a sub-net or node under `name`. The rvalue-string parameter
  /// lets the REG(member) macro pass the stringized member name directly.
  void Register(Net *net, std::string &&name);
  void Register(Node *node, std::string &&name);
  /// Returns the trainable parameters, e.g. to feed NN::Adam's `learn` argument.
  std::shared_ptr<NodeSet> trainable_params();
  /// Adds an element to this net.
  virtual void Add(NetBase *element);
  /// Shape of input / output number `idx`.
  /// NOTE(review): the const by-value return inhibits moving the result; it is
  /// kept because MSVC mangles the return type of exported functions, so
  /// changing it would break linkage against prebuilt libraries.
  const std::vector<int> InputShape(int idx);
  const std::vector<int> OutputShape(int idx);
  /// Element kind tag; `final` so subclasses cannot change it.
  uint32_t type() final;

 private:
  friend NetImpl;
  friend NetData;
  std::shared_ptr<NetImpl> impl_;
};

/// Options consumed by NN::SoftmaxCrossEntropy.
class MS_API SoftMaxCrossEntropyCfg {
 public:
  /// Name of the loss reduction mode; "mean" unless overridden.
  /// NOTE(review): the set of accepted names is defined by the implementation,
  /// not by this header.
  std::string reduction{"mean"};
};

/// Hyper-parameters passed to NN::Adam.
///
/// The members are float, so every default is written as a float literal.
/// This avoids an implicit double->float conversion on each initializer
/// (flagged by -Wfloat-conversion) and stores the same values as before.
class MS_API AdamConfig {
 public:
  float learning_rate_ = 1e-3f;
  // NOTE(review): presumably Adam's first/second moment decay rates and the
  // denominator epsilon -- confirm against the optimizer implementation.
  float beta1_ = 0.9f;
  float beta2_ = 0.999f;
  float eps_ = 1e-8f;
  bool use_nesterov_ = false;
};

namespace NN {
/// Returns a Net combining `net` with the loss node `loss`.
/// NOTE(review): ownership of the returned raw pointer is not stated in this
/// header -- confirm whether the caller must delete it.
MS_API Net *NetWithLoss(Net *net, Node *loss);
/// Graph-level counterpart of NetWithLoss: combines `g` with the loss `loss`.
/// NOTE(review): ownership of the returned pointer is unspecified here.
MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
/// Creates an Adam optimizer node over the parameter set `learn` (for example
/// the result of Net::trainable_params()), configured by `cfg`.
MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
/// Creates a softmax cross-entropy loss node configured by `cfg`.
MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
/// Creates an input node with dimensions `dims`, element type `data_type`
/// (float32 by default) and layout `fmt` (NHWC by default). The caller owns
/// the returned node.
MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
                                   int fmt = NHWC);
}  // namespace NN
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_NET_H