{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "abf0f151",
   "metadata": {},
   "source": [
    "# Tensor索引支持\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.1/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.1/zh_cn/note/mindspore_index_support.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.1/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.1/zh_cn/note/mindspore_index_support.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.1/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/source_zh_cn/note/index_support.ipynb)\n",
    "\n",
    "Tensor 支持单层与多层索引取值,赋值以及增强赋值,支持动态图(PyNative)以及静态图(Graph)模式。\n",
    "\n",
    "## 索引取值\n",
    "\n",
    "索引值支持`int`、`bool`、`None`、`ellipsis`、`slice`、`Tensor`、`List`、`Tuple`。\n",
    "\n",
    "- `int`索引取值\n",
    "\n",
    "  支持单层和多层`int`索引取值,单层`int`索引取值:`tensor_x[int_index]`,多层`int`索引取值:`tensor_x[int_index0][int_index1]...`。\n",
    "\n",
    "  `int`索引取值操作的是第零维,索引值小于第零维长度,在取出第零维对应位置数据后,会消除第零维。\n",
    "\n",
    "  例如,如果对一个`shape`为`(3, 4, 5)`的tensor进行单层`int`索引取值,取得结果的`shape`是`(4, 5)`。\n",
    "\n",
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`int`索引取值。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7dc93570",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[0 1]\n",
      " [2 3]\n",
      " [4 5]]\n",
      "data_multi:\n",
      "[2 3]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(2 * 3 * 2).reshape((2, 3, 2)))\n",
    "data_single = tensor_x[0]\n",
    "data_multi = tensor_x[0][1]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a4410aa",
   "metadata": {},
   "source": [
    "- `bool`索引取值\n",
    "\n",
    "  支持单层和多层`bool`索引取值,单层`bool`索引取值:`tensor_x[True]`,多层`bool`索引取值:`tensor_x[True][True]...`。\n",
    "\n",
    "  `bool`索引取值操作的是第零维,在取出所有数据后,会在`axis=0`轴上扩展一维,对应`True`/`False`,该维长度分别为1/0。`False`将会在`shape`中引入`0`,因此暂只支持`True`。\n",
    "\n",
    "  例如,对一个`shape`为`(3, 4, 5)`的tensor进行单层`True`索引取值,取得结果的`shape`是`(1, 3, 4, 5)`。\n",
    "\n",
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`bool`索引取值。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1c98f2b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[[0 1 2]\n",
      "  [3 4 5]]]\n",
      "data_multi:\n",
      "[[[[0 1 2]\n",
      "   [3 4 5]]]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(2 * 3).reshape((2, 3)))\n",
    "data_single = tensor_x[True]\n",
    "data_multi = tensor_x[True][True]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e3c1be2",
   "metadata": {},
   "source": [
    "- `None`索引取值\n",
    "\n",
    "  `None`索引取值和`True`索引取值一致,可参考`True`索引取值,这里不再赘述。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0272c3cd",
   "metadata": {},
   "source": [
    "- `ellipsis`索引取值\n",
    "\n",
    "  支持单层和多层`ellipsis`索引取值,单层`ellipsis`索引取值:`tensor_x[...]`,多层`ellipsis`索引取值:`tensor_x[...][...]...`。\n",
    "\n",
    "  `ellipsis`索引取值操作在所有维度上取出所有数据。一般多作为`Tuple`索引的组成元素,`Tuple`索引将于下面介绍。\n",
    "\n",
    "  例如,对一个`shape`为`(3, 4, 5)`的tensor进行`ellipsis`索引取值,取得结果的`shape`依然是`(3, 4, 5)`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0bcdc82b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[0 1 2]\n",
      " [3 4 5]]\n",
      "data_multi:\n",
      "[[0 1 2]\n",
      " [3 4 5]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(2 * 3).reshape((2, 3)))\n",
    "data_single = tensor_x[...]\n",
    "data_multi = tensor_x[...][...]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e187fc1",
   "metadata": {},
   "source": [
    "- `slice`索引取值\n",
    "\n",
    "  支持单层和多层`slice`索引取值,单层`slice`索引取值:`tensor_x[slice_index]`,多层`slice`索引取值:`tensor_x[slice_index0][slice_index1]...`。\n",
    "\n",
    "  `slice`索引取值操作的是第零维,取出第零维所切到位置的元素,`slice`不会降维,即使切到长度为1,区别于`int`索引取值。\n",
    "\n",
    "  例如,`tensor_x[0:1:1] != tensor_x[0]`,因为`shape_former = (1,) + shape_latter`。\n",
    "\n",
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`slice`索引取值。\n",
    "\n",
    "  `slice`有`start`、`stop`和`step`组成。`start`默认值为0,`stop`默认值为该维长度,`step`默认值为1。\n",
    "\n",
    "  例如,`tensor_x[:] == tensor_x[0:length:1]`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e0bd19d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[[ 4  5]\n",
      "  [ 6  7]]\n",
      "\n",
      " [[12 13]\n",
      "  [14 15]]]\n",
      "data_multi:\n",
      "[[[12 13]\n",
      "  [14 15]]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(4 * 2 * 2).reshape((4, 2, 2)))\n",
    "data_single = tensor_x[1:4:2]\n",
    "data_multi = tensor_x[1:4:2][1:]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62de31d7",
   "metadata": {},
   "source": [
    "- `Tensor`索引取值\n",
    "\n",
    "  支持单层和多层`Tensor`索引取值,单层`Tensor`索引取值:`tensor_x[tensor_index]`,多层`Tensor`索引取值:`tensor_x[tensor_index0][tensor_index1]...`。\n",
    "\n",
    "  `Tensor`索引取值操作的是第零维,取出第零维对应位置的元素。\n",
    "\n",
    "  索引`Tensor`数据类型支持int型和bool型。\n",
    "\n",
    "  当数据类型是int型时,可以是(int8, int16, int32, int64),值必须为非负数,且小于第零维长度。\n",
    "\n",
    "  `Tensor`索引取值得到结果的`data_shape = tensor_index.shape + tensor_x.shape[1:]`。\n",
    "\n",
    "  例如,对一个`shape`为`(6, 4, 5)`的tensor通过`shape`为`(2, 3)`的tensor进行索引取值,取得结果的`shape`为`(2, 3, 4, 5)`。\n",
    "\n",
    "  当数据类型是bool型时,`Tensor`索引取值得到结果的维度是 `tensor_x.ndim - tensor_index.ndim + 1`。\n",
    "\n",
    "  设 `tensor_index` 中True的数量是 `num_true` ,`tensor_x` 的shape是 `(N0, N1, ... Ni-1, Ni, Ni+1, ..., Nk)`, `tensor_index` 的shape是 `(N0, N1, ... Ni-1)`, 则返回值的shape是 `(num_true, Ni+1, Ni+2, ... , Nk)` 。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dd101fc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3]\n"
     ]
    }
   ],
   "source": [
    "from mindspore import dtype as mstype\n",
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor([1, 2, 3])\n",
    "tensor_index = ms.Tensor([True, False, True], dtype=mstype.bool_)\n",
    "output = tensor_x[tensor_index]\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aad4d54",
   "metadata": {},
   "source": [
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`Tensor`索引取值。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae922eae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[[[ 6  7  8]\n",
      "   [ 9 10 11]]\n",
      "\n",
      "  [[12 13 14]\n",
      "   [15 16 17]]]\n",
      "\n",
      "\n",
      " [[[ 0  1  2]\n",
      "   [ 3  4  5]]\n",
      "\n",
      "  [[18 19 20]\n",
      "   [21 22 23]]]]\n",
      "data_multi:\n",
      "[[[[[ 6  7  8]\n",
      "    [ 9 10 11]]\n",
      "\n",
      "   [[12 13 14]\n",
      "    [15 16 17]]]\n",
      "\n",
      "\n",
      "  [[[ 6  7  8]\n",
      "    [ 9 10 11]]\n",
      "\n",
      "   [[12 13 14]\n",
      "    [15 16 17]]]]]\n"
     ]
    }
   ],
   "source": [
    "from mindspore import dtype as mstype\n",
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(4 * 2 * 3).reshape((4, 2, 3)))\n",
    "tensor_index0 = ms.Tensor(np.array([[1, 2], [0, 3]]), mstype.int32)\n",
    "tensor_index1 = ms.Tensor(np.array([[0, 0]]), mstype.int32)\n",
    "data_single = tensor_x[tensor_index0]\n",
    "data_multi = tensor_x[tensor_index0][tensor_index1]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8608b34",
   "metadata": {},
   "source": [
    "- `List`索引取值\n",
    "\n",
    "  支持单层和多层`List`索引取值,单层`List`索引取值:`tensor_x[list_index]`,多层`List`索引取值:`tensor_x[list_index0][list_index1]...`。\n",
    "\n",
    "  `List`索引取值操作的是第零维,取出第零维对应位置的元素。\n",
    "\n",
    "  索引`List`数据类型必须是int、bool或两者混合。若数据类型为int,则取值在[`-dimension_shape`, `dimension_shape-1`]之间;若数据类型为bool, 则限制bool个数为对应维度长度,筛选对应维度上值为`True`的元素;若值为前两者混合,则bool类型的`True/False`将转为int类型的`1/0`。\n",
    "\n",
    "  `List`索引取值得到结果的`data_shape = list_index.shape + tensor_x.shape[1:]`。\n",
    "\n",
    "  例如,对一个`shape`为`(6, 4, 5)`的tensor通过`shape`为`(3,)`的tensor进行索引取值,取得结果的`shape`为`(3, 4, 5)`。\n",
    "\n",
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`List`索引取值。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "346dff88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_single:\n",
      "[[[ 6  7  8]\n",
      "  [ 9 10 11]]\n",
      "\n",
      " [[12 13 14]\n",
      "  [15 16 17]]\n",
      "\n",
      " [[ 0  1  2]\n",
      "  [ 3  4  5]]]\n",
      "data_multi:\n",
      "[[[ 6  7  8]\n",
      "  [ 9 10 11]]\n",
      "\n",
      " [[ 0  1  2]\n",
      "  [ 3  4  5]]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(4 * 2 * 3).reshape((4, 2, 3)))\n",
    "list_index0 = [1, 2, 0]\n",
    "list_index1 = [True, False, True]\n",
    "data_single = tensor_x[list_index0]\n",
    "data_multi = tensor_x[list_index0][list_index1]\n",
    "print('data_single:')\n",
    "print(data_single)\n",
    "print('data_multi:')\n",
    "print(data_multi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6d25c33",
   "metadata": {},
   "source": [
    "- `Tuple`索引取值\n",
    "\n",
    "  索引`Tuple`的数据类型可以为`int`、`bool`、`None`、`slice`、`ellipsis`、`Tensor`、`List`、`Tuple`。支持单层和多层`Tuple`索引取值,单层`Tuple`索引取值:`tensor_x[tuple_index]`,多层`Tuple`索引取值:`tensor_x[tuple_index0][tuple_index1]...`。`Tuple`中包含的`List`与`Tuple`包含元素规则与单独的`List`规则相同,其他元素规则与单独元素也相同。\n",
    "\n",
    "  索引`Tuple`中元素按照最终索引Broadcast规则,分为`Basic Index`、`Advanced Index`两类。`Basic Index`包含`slice`、`ellipsis`、`int`与`None`四种类型,`Advanced Index`包含`bool`、`Tensor`、`List`、`Tuple`等类型。索引过程中,所有的`Advanced Index`将会做Broadcast,若`Advaned Index`连续,最终broadcast shape将插入在第一个`Advanced Index`位置;若不连续,则broadcast shape插入在`0`位置。\n",
    "\n",
    "  索引里除`None`扩展对应维度,`bool`扩展对应维度后与`Advanced Index`做Broadcast。除`ellipsis`、`bool`、`None`外每个元素操作对应位置维度,即`Tuple`中第0个元素操作第零维,第1个元素操作第一维,以此类推。每个元素的索引规则与该元素类型索引取值规则一致。\n",
    "\n",
    "  `Tuple`索引里最多只有一个`ellipsis`,`ellipsis`前半部分索引元素从前往后对应`Tensor`第零维往后,后半部分索引元素从后往前对应`Tensor`最后一维往前,其他未指定的维度,代表全取。\n",
    "\n",
    "  元素里包含的`Tensor`数据类型必须是int型或bool型,int型可以是(int8, int16, int32, int64),值必须为非负数,且小于第零维长度。\n",
    "\n",
    "  例如,`tensor_x[0:3, 1, tensor_index] == tensor_x[(0:3, 1, tensor_index)]`,因为`0:3, 1, tensor_index`就是一个`Tuple`。\n",
    "\n",
    "  多层索引取值可以理解为,后一层索引取值在前一层索引取值结果上再进行`Tuple`索引取值。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2764cb66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data:\n",
      "[[[13]\n",
      "  [14]\n",
      "  [13]]\n",
      "\n",
      " [[12]\n",
      "  [15]\n",
      "  [14]]]\n"
     ]
    }
   ],
   "source": [
    "from mindspore import dtype as mstype\n",
    "import mindspore as ms\n",
    "import mindspore.numpy as np\n",
    "tensor_x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)))\n",
    "tensor_index = ms.Tensor(np.array([[1, 2, 1], [0, 3, 2]]), mstype.int32)\n",
    "data = tensor_x[1, 0:1, tensor_index]\n",
    "print('data:')\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29242418",
   "metadata": {},
   "source": [
    "## 索引赋值\n",
    "\n",
    "对于形如: `tensor_x[index] = value`, `index`的类型支持`int`、`bool`、`ellipsis`、`slice`、`None`、`Tensor`、`List`、`Tuple`。\n",
    "\n",
    "`value`的类型支持`Number`、`Tuple`、`List`和`Tensor`。被赋的值会首先被转换为张量,数据类型与原张量(`tensor_x`)相符。\n",
    "\n",
    "当`value`为`Number`时,可以理解为将`tensor_x[index]`索引对应元素都更新为`Number`。\n",
    "\n",
    "当`value`为数组,即只包含`Number`的`Tuple`、`List`或`Tensor`时,`value.shape`需要可以与`tensor_x[index].shape`做广播,将`value`广播到`tensor_x[index].shape`后,更新`tensor_x[index]`对应的值。\n",
    "\n",
    "当`value`为`Tuple`或`List`时,若`value`中元素包含`Number`,`Tuple`,`List` 和 `Tensor`等多种类型,该`Tuple` 和 `List` 目前只支持一维。\n",
    "\n",
    "当`value`为`Tuple`或`List`,且存在`Tensor`时,非`Tensor`的元素会首先被转换为`Tensor`,然后这些`Tensor`在`axis=0`轴上打包之后成为一个新的赋值`Tensor`,这时按照`value`为`Tensor`的规则进行赋值。所有`Tensor`的数据类型必须保持一致。\n",
    "\n",
    "索引赋值可以理解为对索引到的位置元素按照一定规则进行赋值,所有索引赋值都不会改变原`Tensor`的`shape`。\n",
    "\n",
    "> 当索引中有多个元素指向原张量的同一个位置时,该值的更新受底层算子限制,可能出现随机的情况。因此暂不支持索引中重复对张量中一个位置的值反复更新。详情请见:[TensorScatterUpdate 算子介绍](https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/ops/mindspore.ops.TensorScatterUpdate.html)\n",
    ">\n",
    "> 当前只支持单层索引(`tensor_x[index] = value`),多层索引(`tensor_x[index1][index2]... = value`)暂不支持。\n",
    "\n",
    "- `int`索引赋值\n",
    "\n",
    "  支持单层`int`索引赋值:`tensor_x[int_index] = u`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e8669b85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[ 0.  1.  2.]\n",
      " [88. 88. 88.]]\n",
      "tensor_y:\n",
      "[[ 0.  1.  2.]\n",
      " [66. 88. 99.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_x[1] = 88.0\n",
    "tensor_y[1] = np.array([66, 88, 99]).astype(np.float32)\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b59501b3",
   "metadata": {},
   "source": [
    "- `bool`索引赋值\n",
    "\n",
    "  支持单层`bool`索引赋值:`tensor_x[bool_index] = u`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "345a4be6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [88. 88. 88.]]\n",
      "tensor_y:\n",
      "[[66. 88. 99.]\n",
      " [66. 88. 99.]]\n",
      "tensor_z:\n",
      "[[66. 88. 99.]\n",
      " [66. 88. 99.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_x[True] = 88.0\n",
    "tensor_y[True] = np.array([66, 88, 99]).astype(np.float32)\n",
    "tensor_z[True] = (66, 88, 99)\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5369a1c",
   "metadata": {},
   "source": [
    "- `ellipsis`索引赋值\n",
    "\n",
    "  支持单层`ellipsis`索引赋值,单层`ellipsis`索引赋值:`tensor_x[...] = u`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1210a82d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [88. 88. 88.]]\n",
      "tensor_y:\n",
      "[[22. 44. 55.]\n",
      " [22. 44. 55.]]\n",
      "tensor_z:\n",
      "[[11. 22. 33.]\n",
      " [44. 55. 66.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_x[...] = 88.0\n",
    "tensor_y[...] = np.array([[22, 44, 55], [22, 44, 55]])\n",
    "tensor_z[...] = ([11, 22, 33], [44, 55, 66])\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55c9066c",
   "metadata": {},
   "source": [
    "- `slice`索引赋值\n",
    "\n",
    "  支持单层`slice`索引赋值:`tensor_x[slice_index] = u`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "093915cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [ 3.  4.  5.]\n",
      " [ 6.  7.  8.]]\n",
      "tensor_y:\n",
      "[[88. 88. 88.]\n",
      " [88. 88. 88.]\n",
      " [ 6.  7.  8.]]\n",
      "tensor_z:\n",
      "[[11. 12. 13.]\n",
      " [11. 12. 13.]\n",
      " [ 6.  7.  8.]]\n",
      "tensor_k:\n",
      "[[11. 12. 13.]\n",
      " [14. 15. 16.]\n",
      " [ 6.  7.  8.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_k = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_x[0:1] = 88.0\n",
    "tensor_y[0:2] = 88.0\n",
    "tensor_z[0:2] = np.array([[11, 12, 13], [11, 12, 13]]).astype(np.float32)\n",
    "tensor_k[0:2] = ([11, 12, 13], (14, 15, 16))\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)\n",
    "print('tensor_k:')\n",
    "print(tensor_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ca2dc55",
   "metadata": {},
   "source": [
    "- `None`索引赋值\n",
    "\n",
    "  支持单层`None`索引赋值:`tensor_x[none_index] = u`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "387ffc8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [88. 88. 88.]]\n",
      "tensor_y:\n",
      "[[66. 88. 99.]\n",
      " [66. 88. 99.]]\n",
      "tensor_z:\n",
      "[[66. 88. 99.]\n",
      " [66. 88. 99.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)\n",
    "tensor_x[None] = 88.0\n",
    "tensor_y[None] = np.array([66, 88, 99]).astype(np.float32)\n",
    "tensor_z[None] = (66, 88, 99)\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4ed1fe0",
   "metadata": {},
   "source": [
    "- `Tensor`索引赋值\n",
    "\n",
    "  支持单层`Tensor`索引赋值,即`tensor_x[tensor_index] = u`。\n",
    "\n",
    "  当前支持索引Tensor为 `int` 型和 `bool` 型。\n",
    "\n",
    "  `int` 型示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7aecf687",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [ 3.  4.  5.]\n",
      " [88. 88. 88.]]\n",
      "tensor_y:\n",
      "[[11. 12. 13.]\n",
      " [ 3.  4.  5.]\n",
      " [11. 12. 13.]]\n",
      "tensor_z:\n",
      "[[11. 12. 13.]\n",
      " [ 3.  4.  5.]\n",
      " [11. 12. 13.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_index = np.array([[2, 0, 2], [0, 2, 0], [0, 2, 0]], np.int32)\n",
    "tensor_x[tensor_index] = 88.0\n",
    "tensor_y[tensor_index] = np.array([11.0, 12.0, 13.0]).astype(np.float32)\n",
    "tensor_z[tensor_index] = [11, 12, 13]\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "101848dc",
   "metadata": {},
   "source": [
    "  `bool` 型示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "123b7ceb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-1. -1. -1.]\n",
      " [ 3.  4.  5.]\n",
      " [-1. -1. -1.]]\n"
     ]
    }
   ],
   "source": [
    "from mindspore import dtype as mstype\n",
    "import mindspore as ms\n",
    "tensor_x = ms.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], mstype.float32)\n",
    "tensor_index = ms.Tensor([True, False, True], mstype.bool_)\n",
    "tensor_x[tensor_index] = -1\n",
    "print(tensor_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86066587",
   "metadata": {},
   "source": [
    "- `List`索引赋值\n",
    "\n",
    "  支持单层`List`索引赋值:`tensor_x[list_index] = u`。\n",
    "\n",
    "  `List`索引赋值和`List`索引取值对索引的支持一致。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2df0ba82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[88. 88. 88.]\n",
      " [88. 88. 88.]\n",
      " [ 6.  7.  8.]]\n",
      "tensor_y:\n",
      "[[11. 12. 13.]\n",
      " [ 3.  4.  5.]\n",
      " [ 6.  7.  8.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_index = np.array([[0, 1], [1, 0]]).astype(np.int32)\n",
    "tensor_x[[0, 1]] = 88.0\n",
    "tensor_y[[True, False, False]] = np.array([11, 12, 13]).astype(np.float32)\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38bcdbdd",
   "metadata": {},
   "source": [
    "- `Tuple`索引赋值\n",
    "\n",
    "  支持单层`Tuple`索引赋值:`tensor_x[tuple_index] = u`。\n",
    "\n",
    "  `Tuple`索引赋值和`Tuple`索引取值对索引的支持一致,但不支持`Tuple`中包含`None`。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "368380ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[ 0.  1.  2.]\n",
      " [ 3. 88. 88.]\n",
      " [ 6.  7.  8.]]\n",
      "tensor_y:\n",
      "[[ 0.  1.  2.]\n",
      " [88. 88.  5.]\n",
      " [88. 88.  8.]]\n",
      "tensor_z:\n",
      "[[ 0.  1.  2.]\n",
      " [11. 12.  5.]\n",
      " [11. 12.  8.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore.numpy as np\n",
    "tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)\n",
    "tensor_index = np.array([0, 1]).astype(np.int32)\n",
    "tensor_x[1, 1:3] = 88.0\n",
    "tensor_y[1:3, tensor_index] = 88.0\n",
    "tensor_z[1:3, tensor_index] = np.array([11, 12]).astype(np.float32)\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)\n",
    "print('tensor_z:')\n",
    "print(tensor_z)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab4db2f4",
   "metadata": {},
   "source": [
    "## 索引增强赋值\n",
    "\n",
    "增强索引赋值,支持`+=`、`-=`、`*=`、`/=`、`%=`、`**=`、`//=`七种类型,`index`与`value`的规则约束与索引赋值相同。索引值支持`int`、`bool`、`ellipsis`、`slice`、`None`、`Tensor`、`List`、`Tuple`八种类型,赋值支持`Number`、`Tensor`、`Tuple`、`List`四种类型。  \n",
    "\n",
    "索引增强赋值可以理解为对索引到的位置元素按照一定规则进行取值,取值所得再与`value`进行操作符运算,最终将运算结果进行赋值,所有索引增强赋值都不会改变原`Tensor`的`shape`。\n",
    "\n",
    "> 当索引中有多个元素指向原张量的同一个位置时,该值的更新受底层算子限制,可能出现随机的情况。因此暂不支持索引中重复对张量中一个位置的值反复更新。详情请见:[TensorScatterUpdate 算子介绍](https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/ops/mindspore.ops.TensorScatterUpdate.html)。\n",
    ">\n",
    "> 目前索引中包含 `True`、`False` 和 `None`的情况暂不支持。\n",
    "\n",
    "- 规则与约束\n",
    "\n",
    "  与索引赋值相比,增加了取值与运算的过程。取值过程中`index`的约束规则与索引取值中`index`相同,支持`int`、`bool`、`Tensor`、`Slice`、`Ellipsis`、`None`、`List`与`Tuple`。上述几种类型的数据中所包含`int`值,需在`[-dim_size, dim_size-1]`闭合区间内。\n",
    "  运算过程中`value`的约束规则与索引赋值中`value`的约束规则相同,`value`类型需为(`Number`、`Tensor`、`List`、`Tuple`)之一,且`value`类型不是`Number`时, `value`的形状需要可以广播到`tensor_x[index]`的形状。\n",
    "\n",
    "  示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7482ac2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_x:\n",
      "[[ 0.  3.  4.  3.]\n",
      " [ 4.  7.  8.  7.]\n",
      " [ 8.  9. 10. 11.]]\n",
      "tensor_y:\n",
      "[[ 0.  1.  2.  3.]\n",
      " [ 0.  2.  4.  6.]\n",
      " [ 8.  9. 10. 11.]]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "tensor_x = ms.Tensor(np.arange(3 * 4).reshape(3, 4).astype(np.float32))\n",
    "tensor_y = ms.Tensor(np.arange(3 * 4).reshape(3, 4).astype(np.float32))\n",
    "tensor_x[[0, 1], 1:3] += 2\n",
    "tensor_y[[1], ...] -= [4, 3, 2, 1]\n",
    "print('tensor_x:')\n",
    "print(tensor_x)\n",
    "print('tensor_y:')\n",
    "print(tensor_y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}