Understanding numpy.argmax() Function In Python

An in-depth guide to argmax function in NumPy

Sachin Pal
Sep 22, 2022

8 min read

Table of contents

  • Argmax function
  • Conclusion

NumPy is most often used to work with arrays (multidimensional, masked) and matrices. It provides a collection of functions and methods for operating on arrays, covering statistical, mathematical, and logical operations, shape manipulation, linear algebra, and much more.

Argmax function

numpy.argmax() is a NumPy function that returns the indices of the maximum values along an axis of the specified array.

Syntax

numpy.argmax(a, axis=None, out=None)

Parameters:

  • a - The input array.

  • axis - Optional. By default (None), the index is computed over the flattened array. Specify axis=0 to search vertically (down each column) or axis=1 to search horizontally (across each row).

  • out - Optional, None by default. If provided, the result is written into this array, which must have the appropriate shape and dtype.

Return value

An array of integer indices of the max values in a, with the same shape as a.shape but with the dimension along axis removed. If axis is None, a single index into the flattened array is returned.
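To make the shape rule concrete, here is a small sketch using a hypothetical 3D array of shape (2, 3, 4) built with np.arange (the array is only for illustration). With axis=1 the middle dimension is removed, and with no axis a single index into the flattened array comes back.

```python
import numpy as np

# Hypothetical 3D array with shape (2, 3, 4), filled with 0..23 in increasing order
arr = np.arange(24).reshape(2, 3, 4)

# axis=1 removes the middle dimension: (2, 3, 4) -> (2, 4)
idx_axis1 = np.argmax(arr, axis=1)
print(idx_axis1.shape)  # (2, 4)

# axis=None (the default) searches the flattened array and returns one index
idx_flat = np.argmax(arr)
print(idx_flat)  # 23
```

Because the values increase along every axis, every entry of idx_axis1 is 2 (the last position along axis 1), and the flat index is 23, the position of the last element.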

Finding the index of the max element

Let's start with a basic example that finds the index of the max element in a 1D array.

Working with a 1D array without specifying the axis

# Importing Numpy
import numpy as np

# Working with a 1D array
inp_arr = np.array([5, 2, 9, 4, 2])

# Applying argmax() function
max_elem_index = np.argmax(inp_arr)

# Printing index
print("MAX ELEMENT INDEX:", max_elem_index)

Output

MAX ELEMENT INDEX: 2

Working with a 2D array without specifying the axis

When we call argmax() on a 2D array without specifying the axis, NumPy treats the array as if it were flattened into 1D, so the returned index refers to the element's position in the flattened array.

Finding the index of the max element in a 2D array without specifying the axis

# Importing Numpy
import numpy as np

# Creating a 4x4 array of random integers from 0 to 15 (your values will differ)
arr = np.random.randint(16, size=(4, 4))

# Array preview
print("INPUT ARRAY: \n", arr)

# Applying argmax()
elem_index = np.argmax(arr)

# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)

Output

INPUT ARRAY: 
 [[ 5  5  4 12]
 [12 15 13  0]
 [11 13  2  6]
 [ 6  8  8  9]]

MAX ELEMENT INDEX: 5
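If you need the row and column rather than the flat index, np.unravel_index() converts one into the other. The sketch below copies the printed input array above into a literal (an assumption made only so the result is reproducible) and maps flat index 5 back to its 2D position.

```python
import numpy as np

# Same values as the printed input array above, fixed for reproducibility
arr = np.array([[ 5,  5,  4, 12],
                [12, 15, 13,  0],
                [11, 13,  2,  6],
                [ 6,  8,  8,  9]])

flat_index = np.argmax(arr)                         # index into the flattened array
row, col = np.unravel_index(flat_index, arr.shape)  # convert it to (row, column)

print("FLAT INDEX:", flat_index)            # FLAT INDEX: 5
print("ROW, COLUMN:", int(row), int(col))   # ROW, COLUMN: 1 1
print("MAX VALUE:", arr[row, col])          # MAX VALUE: 15
```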

Finding the index of the max element along the axis

Things change when we specify an axis: argmax() then returns one index for each slice along that axis instead of a single index.

When the axis is 0

When we specify axis=0, argmax() searches vertically, down each column of the array, and returns the row index of the max element in every column. Let's understand it with an illustration.

Finding the indices of the max elements along axis 0

In the illustration above, argmax() returns the index of the max element in the 1st column, which is 1, then the index of the max element in the 2nd column, which is again 1, and does the same for the 3rd and 4th columns.

# Importing Numpy
import numpy as np

# Creating a 4x4 array of random integers from 0 to 15 (your values will differ)
arr = np.random.randint(16, size=(4, 4))

# Array preview
print("INPUT ARRAY: \n", arr)

# Applying argmax()
elem_index = np.argmax(arr, axis=0)

# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)

Output

INPUT ARRAY: 
 [[ 8  6 10  3]
 [ 4  5  9  1]
 [ 6 15 13 13]
 [ 4 14 15 13]]

MAX ELEMENT INDEX: [0 2 3 2]
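One way to convince yourself of what axis=0 means is to compare it with a search over each column separately. This sketch copies the printed input array above into a literal so the result is reproducible; the per-column loop is only for illustration, since the single argmax() call is the idiomatic form.

```python
import numpy as np

# Fixed copy of the input array printed above
arr = np.array([[ 8,  6, 10,  3],
                [ 4,  5,  9,  1],
                [ 6, 15, 13, 13],
                [ 4, 14, 15, 13]])

# argmax along axis 0: one row index per column
vectorized = np.argmax(arr, axis=0)

# The same result, computed one column at a time
per_column = np.array([np.argmax(arr[:, j]) for j in range(arr.shape[1])])

print(vectorized)                              # [0 2 3 2]
print(np.array_equal(vectorized, per_column))  # True
```

Note that the last column contains 13 twice (rows 2 and 3), and both approaches report row 2, the first occurrence.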

When the axis is 1

When we specify axis=1, argmax() searches horizontally, across each row of the array, and returns the column index of the max element in every row. Let's understand it with an illustration.

Finding the indices of the max elements along axis 1

In the illustration above, argmax() returns the index of the max element in the 1st row, which is 2, then the index of the max element in the 2nd row, which is 1, and does the same for the 3rd and 4th rows.

Code example

# Importing Numpy
import numpy as np

# Creating a 4x3 array of random integers from 0 to 11 (your values will differ)
arr = np.random.randint(12, size=(4, 3))

# Array preview
print("INPUT ARRAY: \n", arr)

# Applying argmax()
elem_index = np.argmax(arr, axis=1)

# Displaying max element index
print("\nMAX ELEMENT INDEX:", elem_index)

Output

INPUT ARRAY: 
 [[ 7  8  0]
 [ 3  0 11]
 [ 7  6  0]
 [10  8  1]]

MAX ELEMENT INDEX: [1 2 0 0]
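The indices returned along an axis are often used to pull out the max values themselves. A minimal sketch, assuming the printed 4x3 array above as fixed input, pairs argmax() with np.take_along_axis(); the result should match arr.max(axis=1).

```python
import numpy as np

# Fixed copy of the input array printed above
arr = np.array([[ 7,  8,  0],
                [ 3,  0, 11],
                [ 7,  6,  0],
                [10,  8,  1]])

idx = np.argmax(arr, axis=1)  # one column index per row: [1 2 0 0]

# take_along_axis needs idx to have as many dimensions as arr, so add a
# length-1 column axis, pick the values, then drop that axis again
max_vals = np.take_along_axis(arr, idx[:, np.newaxis], axis=1)[:, 0]

print(max_vals)                                   # [ 8 11  7 10]
print(np.array_equal(max_vals, arr.max(axis=1)))  # True
```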

Multiple occurrences of the highest value

Sometimes an array contains its highest value more than once along a particular axis. What happens then?

The argmax() function returns the index of the first occurrence of the highest value along that axis.

Illustration showing multiple occurrences of the highest values along axis 0.

Multiple occurrences of the highest values along axis 0

Illustration showing multiple occurrences of the highest values along axis 1.

Multiple occurrences of the highest values along axis 1

Code example

# Importing Numpy
import numpy as np

# Defining the array
arr = np.array([[2, 14, 9, 4, 5],
                [7, 14, 53, 10, 4],
                [91, 2, 41, 6, 91]])

# Displaying the highest element index along axis 0
print("MAX ELEMENT INDEX:", np.argmax(arr, axis=0))

# Displaying the highest element index along axis 1
print("\nMAX ELEMENT INDEX:", np.argmax(arr, axis=1))

# Flattening the array, making it a 1D array
flattened_arr = arr.flatten()
print("\nThe array is flattened into 1D array:", flattened_arr)

# Displaying the highest element index
print("\nMAX ELEMENT INDEX:", np.argmax(flattened_arr))

Output

MAX ELEMENT INDEX: [2 0 1 1 2]

MAX ELEMENT INDEX: [1 2 0]

The array is flattened into 1D array: [ 2 14  9  4  5  7 14 53 10  4 91  2 41  6 91]

MAX ELEMENT INDEX: 10

Explanation

In the code above, finding the indices of the max elements along axis 0 gives [2 0 1 1 2]. In the 2nd column, the highest value, 14, appears at both index 0 and index 1; argmax() returns 0 because the 14 at index 0 comes first.

The same applies to the second output, obtained with axis=1: in the 3rd row, the highest value, 91, appears at both index 0 and index 4, and argmax() returns 0 because that occurrence comes first. Likewise, in the flattened array 91 appears at indices 10 and 14, and argmax() returns 10.
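Because argmax() always reports the first occurrence, a different approach is needed if you want every position of the maximum or only the last one. The sketch below shows two common idioms on the 3rd row of the example array: np.flatnonzero() on a boolean mask to list all positions, and argmax() on the reversed row to locate the last one.

```python
import numpy as np

row = np.array([91, 2, 41, 6, 91])  # 3rd row of the example array

first = np.argmax(row)                            # first occurrence only
all_positions = np.flatnonzero(row == row.max())  # every occurrence of the max
last = len(row) - 1 - np.argmax(row[::-1])        # last occurrence, via the reversed row

print(first)          # 0
print(all_positions)  # [0 4]
print(last)           # 4
```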

Using the out parameter

The out parameter in numpy.argmax() function is optional and by default it is None.

If provided, argmax() writes its result (the indices of the max elements along the given axis) into this array. The array must have the same shape as the result, that is, the input array's shape with the searched axis removed, and an integer dtype that can hold the indices.

Code example

# Importing Numpy
import numpy as np

# Creating an integer array of zeroes that argmax() will overwrite
out_array = np.zeros((4,), dtype=int)
print("ARRAY w/ ZEROES:", out_array)

# Input array: 4x4 random integers from 0 to 15 (your values will differ)
arr = np.random.randint(16, size=(4, 4))
print("INPUT ARRAY:\n", arr)

# Storing the indices of the max elements (axis=1) in out_array
print("\nAXIS 1:", np.argmax(arr, axis=1, out=out_array))

# Storing the indices of the max elements (axis=0) in out_array
print("\nAXIS 0:", np.argmax(arr, axis=0, out=out_array))

Output

ARRAY w/ ZEROES: [0 0 0 0]
INPUT ARRAY:
 [[ 4  2 14 15]
 [ 6 15  2  1]
 [13  6 13  3]
 [ 5  1 13  9]]

AXIS 1: [3 1 0 2]

AXIS 0: [2 1 0 0]

Explanation

We created an array of zeroes named out_array with shape (4,), which matches the shape of the result along either axis of the 4x4 input array, and an integer dtype. We then called numpy.argmax() to get the indices of the max elements along axis 1 and then axis 0, storing each result in out_array. Note that the second call overwrites the result of the first.

numpy.zeros() creates a float64 array by default, which is why we set dtype=int: argmax() produces integer indices, so the out array needs an integer dtype regardless of the input array's dtype.

If we don't specify the dtype, NumPy raises a TypeError.

# Importing Numpy
import numpy as np

# Creating array filled with zeroes without specifying dtype
out_array = np.zeros((4,))
print("ARRAY w/ ZEROES:", out_array)

# Input array: 4x4 random integers from 0 to 15 (your values will differ)
arr = np.random.randint(16, size=(4, 4))
print("INPUT ARRAY:\n", arr)

print("\nAXIS 1:", np.argmax(arr, axis=1, out=out_array))

Output

ARRAY w/ ZEROES: [0. 0. 0. 0.]
INPUT ARRAY:
 [[14  9  3  4]
 [ 9  2  4  8]
 [ 5  1  9  1]
 [ 6  0 10  7]]

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
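A safer way to allocate the out array is to derive its shape from the input by dropping the searched axis, and to use np.intp, NumPy's integer type for indices. The sketch below assumes the 4x4 array printed in the earlier out example as fixed input; it fills out for axis=1 and then shows that an out array shaped like the input, rather than like the result, is rejected with a ValueError.

```python
import numpy as np

# Fixed copy of the input array from the earlier out example
arr = np.array([[ 4,  2, 14, 15],
                [ 6, 15,  2,  1],
                [13,  6, 13,  3],
                [ 5,  1, 13,  9]])
axis = 1

# Result shape = input shape with the searched axis removed: (4, 4) -> (4,)
out_shape = arr.shape[:axis] + arr.shape[axis + 1:]
out = np.empty(out_shape, dtype=np.intp)

np.argmax(arr, axis=axis, out=out)
print(out)  # [3 1 0 2]

# An out array shaped like the input (4, 4) does not match the (4,) result
bad_out = np.empty(arr.shape, dtype=np.intp)
try:
    np.argmax(arr, axis=axis, out=bad_out)
    shape_error = None
except ValueError as err:
    shape_error = err
    print("ValueError:", err)
```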

Conclusion

That was an in-depth look at the argmax() function in NumPy. Let's review what we've learned:

  • numpy.argmax() returns the index of the highest value in an array. If the maximum value occurs more than once (in a multidimensional or flattened array), argmax() returns the index of its first occurrence.

  • We can specify the axis parameter when working with a multidimensional array to get the result along a particular axis. With axis=0, we get the indices of the highest values vertically (one per column), and with axis=1, horizontally (one per row).

  • We can store the output in an existing array passed as the out parameter; however, that array must match the shape of the result and have an integer dtype.


That's all for now

Keep Coding✌✌
