Program Listing for File net.h
↰ Return to documentation for file (include/runtime/include/api/net.h)
#ifndef MINDSPORE_INCLUDE_API_NET_H
#define MINDSPORE_INCLUDE_API_NET_H
#include <memory>
#include <vector>
#include <unordered_set>
#include <string>
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/api/cfg.h"
namespace mindspore {
// Expands to a call to Register(), passing the argument itself together with
// its source spelling as a string literal, e.g. REG(conv1) becomes
// Register(conv1, "conv1"). Macros are not scoped by the enclosing namespace.
#define REG(member) Register(member, #member)
// Forward declarations. This header only uses these types through pointers,
// references, smart pointers, and friend declarations, so incomplete types
// are sufficient here.
class Expr;
class NodeImpl;
class NetImpl;
class NodeSet;
class Graph;
class NetData;
class MS_API NetBase {
public:
NetBase() = default;
virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
virtual uint32_t type() = 0;
};
/// NetBase implementation whose state lives behind a NodeImpl
/// (pimpl-style: the implementation type is only forward-declared here).
class MS_API Node : public NetBase {
 public:
  Node();
  virtual ~Node();

  /// Takes a name by value and returns a raw Expr pointer.
  /// NOTE(review): ownership and lifetime of the returned Expr are not
  /// visible from this header; confirm against the implementation.
  Expr *Create(std::string name);

  /// NetBase call operator, implemented by Node.
  std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;

  /// NetBase type tag; declared final, so subclasses of Node cannot override it.
  uint32_t type() final;

 private:
  friend class NodeImpl;
  std::shared_ptr<NodeImpl> impl_{nullptr};  // starts out empty
};
/// NetBase implementation that can register child Nets and Nodes and is
/// backed by a NetImpl (only forward-declared in this header).
///
/// Inherits std::enable_shared_from_this, so shared_from_this() is only valid
/// on instances that are already owned by a std::shared_ptr.
class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
 public:
  Net();
  virtual ~Net();

  /// Constructs a Net from a name.
  explicit Net(std::string name);

  /// Constructs a Net from an existing Graph.
  explicit Net(const Graph &g);

  /// Overridable hook that maps input expressions to output expressions.
  /// NOTE(review): presumably user-defined networks override this to describe
  /// their forward computation; confirm against the implementation.
  virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);

  /// NetBase call operator. Marked override so the compiler verifies that the
  /// signature keeps matching NetBase::operator().
  std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;

  /// Registers a child Net under the given name. The REG(x) macro expands to
  /// Register(x, "x"), which binds the string literal to the rvalue parameter.
  void Register(Net *net, std::string &&name);

  /// Registers a child Node under the given name (see the Net overload).
  void Register(Node *node, std::string &&name);

  /// @return A shared NodeSet. NOTE(review): by its name this is the set of
  ///         trainable parameters; contents are not visible from this header.
  std::shared_ptr<NodeSet> trainable_params();

  /// Adds an element (a Node or a Net) through its NetBase interface.
  /// NOTE(review): whether the Net takes ownership of the pointer is not
  /// visible from this header.
  virtual void Add(NetBase *element);

  /// @param idx Index of the input to query.
  /// @return The shape as a vector of ints. The const-qualified by-value
  ///         return type is kept as-is so the out-of-line definition still matches.
  const std::vector<int> InputShape(int idx);

  /// @param idx Index of the output to query.
  /// @return The shape as a vector of ints.
  const std::vector<int> OutputShape(int idx);

  /// NetBase type tag; declared final, so subclasses of Net cannot override it.
  uint32_t type() final;

 private:
  friend NetImpl;
  friend NetData;
  std::shared_ptr<NetImpl> impl_;
};
/// Options accepted by NN::SoftmaxCrossEntropy.
class MS_API SoftMaxCrossEntropyCfg {
 public:
  /// Reduction mode; defaults to "mean".
  /// NOTE(review): the set of accepted values is not visible from this header.
  std::string reduction{"mean"};
};
/// Options accepted by NN::Adam. Every field has an in-class default, so a
/// default-constructed AdamConfig is directly usable.
class MS_API AdamConfig {
 public:
  float learning_rate_{1e-3};  // default 1e-3
  float beta1_{0.9};           // default 0.9
  float beta2_{0.999};         // default 0.999
  float eps_{1e-08};           // default 1e-08
  bool use_nesterov_{false};   // off by default
};
// Free factory functions for building networks, losses, optimizers and inputs.
//
// NOTE(review): most of these return raw pointers while Input() returns a
// std::unique_ptr; who owns the raw-pointer results is not expressed in the
// signatures and should be confirmed against the implementation.
namespace NN {
/// Combines a Net with a loss Node and returns the resulting Net.
MS_API Net *NetWithLoss(Net *net, Node *loss);

/// Combines a Graph with a loss Node and returns the resulting Graph.
MS_API Graph *GraphWithLoss(Graph *g, Node *loss);

/// Creates an optimizer Node from a NodeSet and an AdamConfig.
/// @param learn Shared NodeSet passed to the optimizer (presumably the
///              parameters to update, e.g. from Net::trainable_params(); confirm).
/// @param cfg   Optimizer hyper-parameters.
MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);

/// Creates a loss Node configured by a SoftMaxCrossEntropyCfg.
MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);

/// Creates an input Node owned by the caller.
/// @param dims      Dimensions of the input.
/// @param data_type Element type; defaults to DataType::kNumberTypeFloat32.
/// @param fmt       Format identifier; defaults to NHWC.
MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
                                   int fmt = NHWC);
}  // namespace NN
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_NET_H