mindspore.nn.probability.infer.SVI

class mindspore.nn.probability.infer.SVI(net_with_loss, optimizer)

Stochastic Variational Inference (SVI).

Variational inference casts the inference problem as an optimization problem. A family of distributions over the hidden variables is indexed by a set of free parameters, and these parameters are optimized so that the resulting distribution is as close as possible to the posterior of interest. For more details, refer to Variational Inference: A Review for Statisticians.
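The idea of inference as optimization can be illustrated without MindSpore. The following is a minimal sketch under stated assumptions, not the library's implementation: the toy model (z ~ N(0, 1), x_i | z ~ N(z, 1)), the data values, and the function name fit_gaussian_vi are all hypothetical. The free parameters mu and log_sigma index a Gaussian q(z), and plain gradient ascent on the closed-form ELBO moves q toward the exact posterior, which is known for this conjugate model.

```python
import math

# Toy model (an assumption for illustration): z ~ N(0, 1), x_i | z ~ N(z, 1).
DATA = [1.8, 2.3, 1.9, 2.1, 2.4, 1.6, 2.0, 2.2, 1.7, 2.5]


def fit_gaussian_vi(data, steps=500, learning_rate=0.05):
    """Fit q(z) = N(mu, sigma^2) by gradient ascent on the closed-form ELBO."""
    n = len(data)
    total = sum(data)
    mu, log_sigma = 0.0, 0.0  # free (variational) parameters
    for _ in range(steps):
        sigma_sq = math.exp(2.0 * log_sigma)
        # For this model, up to an additive constant:
        # ELBO = -0.5*(mu^2 + sigma^2) - 0.5*sum((x - mu)^2 + sigma^2) + log(sigma)
        grad_mu = total - (n + 1) * mu
        grad_log_sigma = 1.0 - (n + 1) * sigma_sq
        mu += learning_rate * grad_mu
        log_sigma += learning_rate * grad_log_sigma
    return mu, math.exp(log_sigma)


mu, sigma = fit_gaussian_vi(DATA)

# Exact posterior of the conjugate model, used only as a reference point.
exact_mean = sum(DATA) / (len(DATA) + 1)
exact_std = (len(DATA) + 1) ** -0.5

print(f"fitted q:        mean={mu:.4f}, std={sigma:.4f}")
print(f"exact posterior: mean={exact_mean:.4f}, std={exact_std:.4f}")
```

Because the variational family here contains the true posterior, the optimized q coincides with it. In realistic models the family is only an approximation and the ELBO gradient has to be estimated, which is where the stochastic part of SVI comes in.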

Parameters
  • net_with_loss (Cell) – Cell with loss function.

  • optimizer (Cell) – Optimizer for updating the weights.

Supported Platforms:

Ascend, GPU

get_train_loss()
Returns

numpy.dtype, the loss after training.

run(train_dataset, epochs=10)

Optimize the parameters by training the probability network, and return the trained network.

Parameters
  • train_dataset (Dataset) – A training dataset iterator.

  • epochs (int) – Total number of passes (epochs) over the training data. Default: 10.

Returns

Cell, the trained probability network.
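The run / get_train_loss workflow described above can be mimicked in plain Python. The sketch below is a hypothetical stand-in, not MindSpore's SVI class: ToySVI, its constructor arguments, the loss_history attribute, and the toy model (z ~ N(0, 1), x_i | z ~ N(z, 1)) are assumptions made for illustration, and how the real class aggregates its loss may differ. It shows the general pattern only: each epoch is one pass over the mini-batches, each step takes a noisy gradient step on a Monte Carlo estimate of the negative ELBO, and the loss from the last epoch is kept for later retrieval.

```python
import math
import random


class ToySVI:
    """Hypothetical stand-in for an SVI trainer (not the MindSpore class)."""

    def __init__(self, num_data, learning_rate=0.005, num_samples=8, seed=0):
        self.num_data = num_data          # size of the full training set
        self.lr = learning_rate
        self.num_samples = num_samples    # Monte Carlo samples per step
        self.rng = random.Random(seed)
        self.mu, self.log_sigma = 0.0, 0.0  # parameters of q(z) = N(mu, sigma^2)
        self.loss_history = []

    def _step(self, batch):
        sigma = math.exp(self.log_sigma)
        scale = self.num_data / len(batch)  # rescale mini-batch likelihood
        grad_mu = grad_log_sigma = neg_log_joint = 0.0
        for _ in range(self.num_samples):
            eps = self.rng.gauss(0.0, 1.0)
            z = self.mu + sigma * eps       # reparameterized sample from q
            # log p(x, z) up to an additive constant, and its derivative in z
            log_joint = -0.5 * z * z + scale * sum(-0.5 * (x - z) ** 2 for x in batch)
            dlog_joint = -z + scale * sum(x - z for x in batch)
            grad_mu += dlog_joint
            grad_log_sigma += dlog_joint * sigma * eps
            neg_log_joint -= log_joint
        n = self.num_samples
        # The entropy of q contributes log(sigma) to the ELBO, gradient 1 in log_sigma.
        self.mu += self.lr * grad_mu / n
        self.log_sigma += self.lr * (grad_log_sigma / n + 1.0)
        return neg_log_joint / n - math.log(sigma)  # negative ELBO estimate

    def run(self, train_dataset, epochs=10):
        for _ in range(epochs):
            losses = [self._step(batch) for batch in train_dataset]
            self.loss_history.append(sum(losses) / len(losses))
        return self.mu, math.exp(self.log_sigma)

    def get_train_loss(self):
        return self.loss_history[-1]


# Two mini-batches of five observations each (made-up values).
batches = [[1.8, 2.3, 1.9, 2.1, 2.4], [1.6, 2.0, 2.2, 1.7, 2.5]]
svi = ToySVI(num_data=10)
mu, sigma = svi.run(batches, epochs=2000)
final_loss = svi.get_train_loss()
print(f"mean={mu:.3f}, std={sigma:.3f}, loss={final_loss:.3f}")
```

With the actual class, net_with_loss and optimizer take the place of the hand-written gradient step, and run returns the trained network rather than a pair of numbers.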