# NumPy Where: Understanding np.where()

The NumPy `where()` function is like a vectorized switch that you can use to combine two arrays. For example, let’s say you have an array of data and you want to create a new array with 1 wherever an element is more than one standard deviation from the mean and -1 everywhere else.

This is a perfect use case for `np.where()`. First, create a boolean array for your conditional, and then call `np.where()`:

```python
import numpy as np
import pandas as pd

# Assumes `df` is an existing pandas DataFrame with a numeric `revenue` column.
condition = np.abs(df.revenue - df.revenue.mean()) > df.revenue.std()

np.where(condition, 1, -1)

# Expected result
# array([ 1, -1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,
#        -1, -1,  1, -1, -1, -1, -1,  1,  1,  1,  1,  1,  1])
```

The arguments here are:

• `condition`: an array-like whose elements evaluate to True or False
• `x`: an optional array-like result for elements that evaluate to True
• `y`: an optional array-like result for elements that evaluate to False

The elements of `condition` don’t actually need to have a boolean type as long as they can be coerced to a boolean (e.g. non-zero integers are interpreted as True). Also, both `x` and `y` are optional, but if you provide one, you need to provide both. Additionally, the inputs can have any shape, as long as `condition`, `x`, and `y` broadcast against one another, so you can use this as a multi-dimensional switch.
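To make those three points concrete, here is a small sketch with made-up values (the names `cond`, `picked`, `row`, and `broadcasted` are purely illustrative): an integer condition, inputs that broadcast, and the error raised when only one of `x` and `y` is supplied.

```python
import numpy as np

# Integer condition: non-zero entries are treated as True.
cond = np.array([[0, 1], [2, 0]])
picked = np.where(cond, 10, -10)
print(picked)
# [[-10  10]
#  [ 10 -10]]

# x and y broadcast against the condition: a row vector and a scalar here.
row = np.array([1, 2])
broadcasted = np.where(cond, row, 0)
print(broadcasted)
# [[0 2]
#  [1 0]]

# Supplying only one of x and y raises a ValueError.
try:
    np.where(cond, 10)
except ValueError as err:
    print("ValueError:", err)
```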

One thing to watch out for: the return value takes a different form if you don’t supply `x` and `y`. In that case, `np.where()` returns a tuple of index arrays, one per axis, giving the positions of the true elements (a one-element tuple for a 1-D vector). This is equivalent to `np.argwhere()` except that the index arrays are split by axis.
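A quick sketch of the condition-only form, using small made-up arrays `a` and `m` (illustrative names, not from the original example):

```python
import numpy as np

# 1-D case: a tuple holding a single index array.
a = np.array([3, 0, 5, 0, 7])
idx = np.where(a > 0)
print(idx)  # (array([0, 2, 4]),)

# 2-D case: one index array per axis (rows, then columns).
m = np.array([[0, 4, 0],
              [9, 0, 2]])
rows, cols = np.where(m)
print(rows)  # [0 1 1]
print(cols)  # [1 0 2]

# The tuple can be used directly to index the original array.
print(m[rows, cols])  # [4 9 2]
```

Because the result is a tuple of per-axis index arrays, it can be passed straight back into the array as an index, which is the main practical difference from the row-per-match layout of `np.argwhere()`.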

You can see how this works by calling `np.stack()` on the result of `np.where()`:

```python
x = np.eye(4)
np.stack(np.where(x), -1) == np.argwhere(x)

# Expected result
# array([[ True,  True],
#        [ True,  True],
#        [ True,  True],
#        [ True,  True]])
```

This makes `np.where()` without the `x` and `y` inputs equivalent to calling the `.nonzero()` method on the condition array:

```python
np.stack(x.nonzero(), -1) == np.argwhere(x)

# Expected result
# array([[ True,  True],
#        [ True,  True],
#        [ True,  True],
#        [ True,  True]])
```

## Multi-dimensional binary cross entropy

Now that we know how the API works, let’s look at another example: multi-dimensional binary cross entropy. Say we have a 3-D array of binary class probabilities `yhat` and a 3-D array of binary labels `y`. The one-liner formula for binary cross-entropy is the following:

``-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()``

This works in the multi-dimensional case because NumPy arithmetic is element-wise by default. Multiplying the log terms by `y` and `1 - y` acts like a switch: when `y == 1` only the first term contributes, and when `y == 0` only the second does:

```python
np.random.seed(1)

yhat = np.random.uniform(size=(3, 3, 3))
y = np.random.randint(0, 2, size=(3, 3, 3))

-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()

# Expected result
# 1.221865004504288
```

But we can accomplish the same thing with `np.where()`:

```python
-np.where(y, np.log(yhat), np.log(1 - yhat)).mean()

# Expected result
# 1.221865004504288
```

Pretty cool! This is not necessarily a better implementation in any important way, but it does make the function of the `y` and `1 - y` terms very clear.
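As a side note, the two versions can be compared directly. The sketch below wraps each one-liner in a hypothetical helper (`bce_product` and `bce_where` are names made up for this illustration) and checks that they agree on random inputs. It also shows one small behavioral difference you may run into: when a predicted probability is exactly 0 or 1, the product form can yield `nan` because `0 * -inf` is `nan`, while the `np.where()` form simply discards the unselected term.

```python
import numpy as np

def bce_product(y, yhat):
    # Hypothetical helper: the one-liner that multiplies by y and (1 - y).
    return -(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()

def bce_where(y, yhat):
    # Hypothetical helper: the np.where variant that picks a term per label.
    return -np.where(y, np.log(yhat), np.log(1 - yhat)).mean()

# Random inputs (different from the seeded example above).
rng = np.random.default_rng(0)
yhat = rng.uniform(0.01, 0.99, size=(3, 3, 3))
y = rng.integers(0, 2, size=(3, 3, 3))
print(np.isclose(bce_product(y, yhat), bce_where(y, yhat)))  # True

# Edge case: a probability of exactly 1.0 for a positive label.
y_edge = np.array([1, 0])
yhat_edge = np.array([1.0, 0.5])
with np.errstate(divide="ignore", invalid="ignore"):
    print(bce_product(y_edge, yhat_edge))  # nan, because 0 * -inf is nan
    print(bce_where(y_edge, yhat_edge))    # finite, about 0.347
```

In practice, probabilities are often clipped away from 0 and 1 before taking logs, which sidesteps the issue for both forms.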
