Python PyTorch – rsqrt() method
PyTorch's rsqrt() method computes the reciprocal of the square root of each element of the input tensor. It accepts both real and complex-valued tensors. For real-valued input, it returns NaN (not a number) for a negative element and inf for zero. Mathematically, each output element is computed from the corresponding input element with the following formula:
$$\text{out}_i = \frac{1}{\sqrt{\text{input}_i}}$$
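To connect the definition to code, here is a minimal sketch that compares torch.rsqrt() with the formula written out by hand as 1 / torch.sqrt(). The variable names x, r and manual are only for illustration. It also shows the special values for zero and a negative real number.

```python
import torch

# Real-valued input with a positive value, zero and a negative value
x = torch.tensor([4.0, 0.0, -1.0])

r = torch.rsqrt(x)           # element-wise 1 / sqrt(x)
manual = 1 / torch.sqrt(x)   # the same formula written out by hand

print("rsqrt:   ", r)        # 0.5 for 4.0, inf for 0.0, nan for -1.0
print("1/sqrt:  ", manual)

# Both agree element-wise; equal_nan=True makes the nan entries compare equal
print(torch.allclose(r, manual, equal_nan=True))
```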
Syntax: torch.rsqrt(input, *, out=None)
- input: the input tensor.
- out: the output tensor. It’s an optional keyword argument.
Return: a tensor containing the reciprocal of the square root of each element of input (written into out if out is given).
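The optional out argument lets you write the result into a tensor you have already allocated. The sketch below assumes a float32 output tensor of the same shape as the input; it also uses the equivalent tensor method x.rsqrt(), which is part of PyTorch but not mentioned above.

```python
import torch

x = torch.tensor([1.0, 4.0, 16.0])

# Preallocate an output tensor with the same shape and dtype (float32) as x
result = torch.empty(3)
torch.rsqrt(x, out=result)   # writes 1/sqrt(x) into `result`

# The method form on the tensor computes the same values into a new tensor
same = x.rsqrt()

print(result)
print(same)
```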
In this example, we use the torch.rsqrt() method to compute the reciprocal of the square root of a one-dimensional float tensor that also contains zero and a negative number. The third element of the input tensor is zero, so its rsqrt is inf; the fourth element is negative, so its rsqrt is nan.
tensor a: tensor([ 1.2000, 0.3200, 0.0000, -32.3000, 4.0000])
rsqrt of a: tensor([0.9129, 1.7678, inf, nan, 0.5000])
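The code for this example is not shown above; a script along these lines, using the same input values, should print output like the lines above. The labels in the print calls are simply chosen to match that output.

```python
import torch

# One-dimensional float tensor; the third element is zero, the fourth is negative
a = torch.tensor([1.2, 0.32, 0.0, -32.3, 4.0])
print("tensor a:", a)

# Reciprocal of the square root of every element
r = torch.rsqrt(a)
print("rsqrt of a:", r)
```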
In the example below, we compute the rsqrt of a one-dimensional complex tensor using the torch.rsqrt() method. Because the complex numbers are generated randomly, you will likely get different values on each run.
tensor a:
tensor([-0.4207-0.9085j, -0.2920+0.0372j, 0.9237+0.2633j, -0.1313+0.5933j])
rsqrt of a:
tensor([0.5381+0.8422j, 0.1168-1.8396j, 1.0105-0.1412j, 0.8032-1.0003j])
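A minimal sketch of this example, assuming the random complex values come from torch.randn with dtype=torch.cfloat (the original generator is not shown). Since rsqrt(a) is 1/sqrt(a), multiplying the result by itself and by a should give approximately 1, which makes a handy sanity check. The last lines show that, unlike the real case, a negative real number stored in a complex tensor gives a finite result: the principal square root of -4 is 2j, so its reciprocal is -0.5j.

```python
import torch

# Four random complex numbers; values differ on every run unless a seed is set
a = torch.randn(4, dtype=torch.cfloat)
print("tensor a:", a)

r = torch.rsqrt(a)
print("rsqrt of a:", r)

# Sanity check: r * r * a should be (approximately) 1 for every element
check = r * r * a
print(check)

# A negative real number in a complex tensor does not produce nan
neg = torch.rsqrt(torch.tensor([-4.0 + 0j]))
print(neg)  # approximately -0.5j
```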
In the example below, we compute the rsqrt of a 3-D tensor using the torch.rsqrt() method. Here, too, the numbers are generated randomly, so you will likely get different values on each run. Just as with one-dimensional tensors, rsqrt is computed for each element of a multidimensional tensor.
tensor a:
tensor([[[-0.7205, -1.3897],
         [ 1.0028,  0.3652],
         [ 0.8731, -0.7459]],

        [[-0.9512,  1.8421],
         [ 0.2855,  0.3749],
         [-0.8577,  0.6472]]])
rsqrt of a:
tensor([[[   nan,    nan],
         [0.9986, 1.6547],
         [1.0702,    nan]],

        [[   nan, 0.7368],
         [1.8715, 1.6333],
         [   nan, 1.2431]]])
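A sketch that reproduces this example, assuming the random 3-D tensor of shape (2, 3, 2) comes from torch.randn (the original generator is not shown). Calling torch.manual_seed() before torch.randn would make the values repeatable. The result has the same shape as the input, with nan wherever an element is negative.

```python
import torch

# Random 3-D tensor of shape (2, 3, 2); values change on every run
a = torch.randn(2, 3, 2)
print("tensor a:", a)

# rsqrt is applied element-wise, so the output keeps the input's shape
r = torch.rsqrt(a)
print("rsqrt of a:", r)
```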