# Understanding numpy.argmax() Function In Python

## An in-depth guide to the argmax function in NumPy

NumPy is most often used to work with arrays (multidimensional and masked) and matrices. It provides a collection of functions and methods for operating on arrays, covering statistical, mathematical, and logical operations, shape manipulation, linear algebra, and much more.

## Argmax function

`numpy.argmax()` is one of the functions provided by NumPy; it is used to **return the indices of the maximum values along an axis of the specified array**.

### Syntax

`numpy.argmax(a, axis=None, out=None)`

Parameters:

- `a`: The input array we will work on.
- `axis`: Optional. By default (`None`), the index is computed over the flattened array. We can specify an axis such as 0 or 1 to find the indices of the max values vertically (down each column) or horizontally (across each row) of a 2D array.
- `out`: Optional, `None` by default. If provided, the result is inserted into this array, which must have the appropriate shape and dtype.

Return value

An array of integer indices of the max values in `a`, with the same shape as `a.shape` but with the dimension along `axis` removed. When `axis` is `None`, a single integer index into the flattened array is returned.
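To make the return-value rule concrete, here is a minimal sketch that prints the shape of the result for each axis. The 3D input built with `np.arange` is just an illustrative array, not one of the article's examples.

```python
import numpy as np

# Illustrative 3D input of shape (2, 3, 4) holding the values 0..23
a = np.arange(24).reshape(2, 3, 4)

# axis=None (the default): one index into the flattened array
print(np.argmax(a))                # 23

# With an axis, that dimension disappears from the result's shape
print(np.argmax(a, axis=0).shape)  # (3, 4)
print(np.argmax(a, axis=1).shape)  # (2, 4)
print(np.argmax(a, axis=2).shape)  # (2, 3)
```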

### Finding the index of the max element

Let's start with a basic example that finds the index of the max element in an array.

**Working with a 1D array without specifying the axis**

```
# Importing Numpy
import numpy as np
# Working with a 1D array
inp_arr = np.array([5, 2, 9, 4, 2])
# Applying argmax() function
max_elem_index = np.argmax(inp_arr)
# Printing index
print("MAX ELEMENT INDEX:", max_elem_index)
```

**Output**

```
MAX ELEMENT INDEX: 2
```
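Since `argmax()` gives us a position rather than a value, a common next step is to use that position to read the value back. The short sketch below reuses the same input; comparing against `inp_arr.max()` is only a sanity check we added.

```python
import numpy as np

inp_arr = np.array([5, 2, 9, 4, 2])
max_elem_index = np.argmax(inp_arr)

# The returned index reads the maximum value back from the array
print(inp_arr[max_elem_index])                    # 9
print(inp_arr[max_elem_index] == inp_arr.max())   # True

# argmax also accepts a plain Python list, converting it to an array first
print(np.argmax([5, 2, 9, 4, 2]))                 # 2
```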

**Working with a 2D array without specifying the axis**

When we work with a 2D array and don't specify the axis, `argmax()` treats the array as if it were flattened into a 1D array (row by row) and returns the index of the max element in that flattened array.

```
# Importing Numpy
import numpy as np
# Creating 2D array
arr = np.random.randint(16, size=(4, 4))
# Array preview
print("INPUT ARRAY: \n", arr)
# Applying argmax()
elem_index = np.argmax(arr)
# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)
```

**Output**

```
INPUT ARRAY:
 [[ 5  5  4 12]
 [12 15 13  0]
 [11 13  2  6]
 [ 6  8  8  9]]

MAX ELEMENT INDEX: 5
```
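The flat index `5` is not very readable for a 2D array. A minimal sketch of converting it back into a (row, column) pair uses `np.unravel_index`; here the random array from the output above is written out literally so the result is reproducible.

```python
import numpy as np

# The 4x4 array from the output above, written out literally
arr = np.array([[ 5,  5,  4, 12],
                [12, 15, 13,  0],
                [11, 13,  2,  6],
                [ 6,  8,  8,  9]])

flat_index = np.argmax(arr)                          # index into arr.flatten()
row, col = np.unravel_index(flat_index, arr.shape)   # convert to (row, column)

print(flat_index)        # 5
print(row, col)          # 1 1
print(arr[row, col])     # 15
```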

### Finding the index of the max element along the axis

Things change when we specify an axis: instead of a single flat index, `argmax()` returns one index for every slice along that axis.

**When the axis is 0**

When we specify `axis=0`, the **argmax** function finds the index of the max element vertically, that is, down each column of the multidimensional array, and returns one row index per column. Let's understand it with an illustration.

In the illustration, the `argmax()` function returned the index of the max element in the 1st column, which is 1, then the index of the max element in the 2nd column, which is again 1, and it does the same for the 3rd and 4th columns.

```
# Importing Numpy
import numpy as np
# Creating 2D array
arr = np.random.randint(16, size=(4, 4))
# Array preview
print("INPUT ARRAY: \n", arr)
# Applying argmax()
elem_index = np.argmax(arr, axis=0)
# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)
```

**Output**

```
INPUT ARRAY:
 [[ 8  6 10  3]
 [ 4  5  9  1]
 [ 6 15 13 13]
 [ 4 14 15 13]]

MAX ELEMENT INDEX: [0 2 3 2]
```
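To see that `axis=0` really works column by column, the sketch below recomputes the result one column at a time and then uses `np.take_along_axis` to pull out each column's maximum. The array is the one printed above; the manual loop is only there for comparison.

```python
import numpy as np

# The 4x4 array from the output above
arr = np.array([[ 8,  6, 10,  3],
                [ 4,  5,  9,  1],
                [ 6, 15, 13, 13],
                [ 4, 14, 15, 13]])

by_column = np.argmax(arr, axis=0)

# Same result computed one column at a time
manual = np.array([np.argmax(arr[:, j]) for j in range(arr.shape[1])])

print(by_column)                          # [0 2 3 2]
print(np.array_equal(by_column, manual))  # True

# take_along_axis uses the indices to pick each column's maximum
print(np.take_along_axis(arr, by_column[np.newaxis, :], axis=0))  # [[ 8 15 15 13]]
```

Note the column with two 13s: `argmax()` reports index 2, the first of them.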

**When the axis is 1**

When we specify `axis=1`, the **argmax** function finds the index of the max element horizontally, that is, across each row of the multidimensional array, and returns one column index per row. Let's understand it with an illustration.

In the illustration, the `argmax()` function returned the index of the max element in the 1st row, which is 2, then the index of the max element in the 2nd row, which is 1, and it does the same for the 3rd and 4th rows.

**Code example**

```
# Importing Numpy
import numpy as np
# Creating 2D array
arr = np.random.randint(12, size=(4, 3))
# Array preview
print("INPUT ARRAY: \n", arr)
# Applying argmax()
elem_index = np.argmax(arr, axis=1)
# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)
```

**Output**

```
INPUT ARRAY:
 [[ 7  8  0]
 [ 3  0 11]
 [ 7  6  0]
 [10  8  1]]

MAX ELEMENT INDEX: [1 2 0 0]
```
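With `axis=1` we get one column index per row, so pairing each row number with its index retrieves the row maxima. This is a minimal sketch using fancy indexing on the array printed above; the `arr.max(axis=1)` comparison is just a check we added.

```python
import numpy as np

# The 4x3 array from the output above
arr = np.array([[ 7,  8,  0],
                [ 3,  0, 11],
                [ 7,  6,  0],
                [10,  8,  1]])

by_row = np.argmax(arr, axis=1)

# Pair every row number with its argmax column to fetch the row maxima
row_max = arr[np.arange(arr.shape[0]), by_row]

print(by_row)                                    # [1 2 0 0]
print(row_max)                                   # [ 8 11  7 10]
print(np.array_equal(row_max, arr.max(axis=1)))  # True
```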

### Multiple occurrences of the highest value

Sometimes we come across arrays in which the highest value occurs more than once along a particular axis. What happens then?

The **argmax()** function returns **the index of the first occurrence of the highest value along that axis** (or in the flattened array when no axis is given).

Illustration showing multiple occurrences of the highest values along axis 0.

Illustration showing multiple occurrences of the highest values along axis 1.

**Code example**

```
# Importing Numpy
import numpy as np
# Defining the array
arr = np.array([[2, 14, 9, 4, 5],
[7, 14, 53, 10, 4],
[91, 2, 41, 6, 91]])
# Displaying the highest element index along axis 0
print("MAX ELEMENT INDEX:", np.argmax(arr, axis=0))
# Displaying the highest element index along axis 1
print("\nMAX ELEMENT INDEX:", np.argmax(arr, axis=1))
# Flattening the array, making it a 1D array
flattened_arr = arr.flatten()
print("\nThe array is flattened into 1D array:", flattened_arr)
# Displaying the highest element index
print("\nMAX ELEMENT INDEX:", np.argmax(flattened_arr))
```

**Output**

```
MAX ELEMENT INDEX: [2 0 1 1 2]

MAX ELEMENT INDEX: [1 2 0]

The array is flattened into 1D array: [ 2 14  9  4  5  7 14 53 10  4 91  2 41  6 91]

MAX ELEMENT INDEX: 10
```

**Explanation**

In the above code, when we find the indices of the max elements along **axis 0**, we get the array `[2 0 1 1 2]`. In the **2nd** column, 14 is the highest value and it appears at both index **0** and index **1**; we got `0` because the 14 at index 0 occurs first.

The same goes for the second output, where we provided **axis 1**. In the **3rd** row, 91 is the highest value and it appears at both index **0** and index **4**; the 91 at index **0** occurs first, hence we got `0`. Likewise, in the flattened array the first 91 sits at index 10, so that is what `argmax()` returns.
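If we need more than the first occurrence, `argmax()` alone is not enough. The sketch below shows two workarounds we chose for illustration (they are not part of `argmax()` itself): `np.argwhere` to list every position of the maximum, and running `argmax()` on the reversed flat array to find the last occurrence.

```python
import numpy as np

# Same array as above: 91 appears twice in the last row
arr = np.array([[2, 14, 9, 4, 5],
                [7, 14, 53, 10, 4],
                [91, 2, 41, 6, 91]])

# argmax breaks ties by returning the first occurrence (flat index 10)
first = np.argmax(arr)

# Every (row, column) position holding the maximum value
all_positions = np.argwhere(arr == arr.max())

# Last occurrence: argmax on the reversed flat array, mapped back to a forward index
flat = arr.ravel()
last = flat.size - 1 - np.argmax(flat[::-1])

print(first, last)     # 10 14
print(all_positions)   # [[2 0]
                       #  [2 4]]
```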

### Using the *out* parameter

The `out` parameter in the **numpy.argmax()** function is optional and is `None` by default.

The `out` parameter stores the output array (containing the indices of the max elements along a particular axis) in an existing NumPy array. The array passed as `out` must have the shape of the **result**, that is, the input array's shape with the `axis` dimension removed, and an integer **dtype** that can hold the indices.

**Code Example**

```
# Importing Numpy
import numpy as np
# Create an integer array of zeros that argmax() will overwrite with the indices
out_array = np.zeros((4,), dtype=int)
print("ARRAY w/ ZEROES:", out_array)
# Input array
arr = np.random.randint(16, size=(4, 4))
print("INPUT ARRAY:\n", arr)
# Storing the indices of the max elements(axis=1) in the out_array
print("\nAXIS 1:", np.argmax(arr, axis=1, out=out_array))
# Storing the indices of the max elements(axis=0) in the out_array
print("\nAXIS 0:", np.argmax(arr, axis=0, out=out_array))
```

**Output**

```
ARRAY w/ ZEROES: [0 0 0 0]
INPUT ARRAY:
 [[ 4  2 14 15]
 [ 6 15  2  1]
 [13  6 13  3]
 [ 5  1 13  9]]

AXIS 1: [3 1 0 2]

AXIS 0: [2 1 0 0]
```

**Explanation**

We created an array filled with zeros named `out_array` with shape `(4,)`, which matches the shape of the result of `argmax()` along either axis of the 4×4 input array, and with an integer *dtype*. We then used the `numpy.argmax()` function to get the indices of the max elements along axis 1 and axis 0 and stored them in `out_array`. Since both calls write into the same array, `out_array` holds the axis 0 result after the second call.

By default, `numpy.zeros()` creates an array of *dtype* `float`. That's why we specified `dtype=int` in the above code: `argmax()` always produces integer indices, regardless of the input array's dtype. **If we didn't specify the dtype in the above code, it would throw an error**.

```
# Importing Numpy
import numpy as np
# Creating array filled with zeroes without specifying dtype
out_array = np.zeros((4,))
print("ARRAY w/ ZEROES:", out_array)
# Input array
arr = np.random.randint(16, size=(4, 4))
print("INPUT ARRAY:\n", arr)
print("\nAXIS 1:", np.argmax(arr, axis=1, out=out_array))
```

**Output**

```
ARRAY w/ ZEROES: [0. 0. 0. 0.]
INPUT ARRAY:
 [[14  9  3  4]
 [ 9  2  4  8]
 [ 5  1  9  1]
 [ 6  0 10  7]]
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
```
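The shape rule for `out` can also be checked with a non-square array, where the two axes give results of different lengths. This is a minimal sketch under our own assumptions: the 3×4 array and the names `out_rows` and `out_cols` are illustrative, and we use `np.intp`, NumPy's integer type for indices, which is what `argmax()` produces. It also shows that `argmax()` returns the `out` array itself and that a wrongly shaped `out` raises a `ValueError` (the message text may vary between NumPy versions, so we only print it).

```python
import numpy as np

arr = np.array([[4, 2, 14, 15],
                [6, 15, 2, 1],
                [13, 6, 13, 3]])          # shape (3, 4)

# Result along axis=1 has shape (3,); along axis=0 it has shape (4,)
out_rows = np.empty(arr.shape[0], dtype=np.intp)
out_cols = np.empty(arr.shape[1], dtype=np.intp)

res = np.argmax(arr, axis=1, out=out_rows)
print(res is out_rows)   # True: argmax returns the out array itself
print(out_rows)          # [3 1 0]

np.argmax(arr, axis=0, out=out_cols)
print(out_cols)          # [2 1 0 0]

# An out array shaped like the wrong axis is rejected
try:
    np.argmax(arr, axis=1, out=out_cols)
except ValueError as err:
    print("ValueError:", err)
```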

## Conclusion

That was an in-depth look at the `argmax()` function in NumPy. Let's review what we've learned:

- The `numpy.argmax()` function returns the index of the highest value in an array. If the maximum value occurs more than once in an array (multidimensional or flattened), the *argmax()* function returns the index of its first occurrence.
- We can specify the `axis` parameter when working with a multidimensional array to get the result along a particular axis. With **axis=0**, we get the indices of the highest values vertically (down each column), and with **axis=1**, we get them horizontally (across each row).
- We can store the output in another array passed as the **out** parameter; however, that array must have the shape of the result and an integer dtype.

**That's all for now**

**Keep Coding✌✌**

### Did you find this article valuable?

Support **Sachin Pal** by becoming a sponsor. Any amount is appreciated!