python – How to calculate the sum of all columns of a 2D numpy array (efficiently)

Check out the documentation for numpy.sum, paying particular attention to the axis parameter. To sum over columns:

>>> import numpy as np
>>> a = np.arange(12).reshape(4,3)
>>> a.sum(axis=0)
array([18, 22, 26])

Or, to sum over rows:

>>> a.sum(axis=1)
array([ 3, 12, 21, 30])

Other aggregate functions, such as numpy.mean, numpy.cumsum, and numpy.std, also take the axis parameter.
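As a minimal sketch of that point, the same 4x3 array from above can be passed through a few other aggregates. The variable names (col_means, row_std, running) are only illustrative and not from the original answer.

```python
import numpy as np

a = np.arange(12).reshape(4, 3)

col_means = a.mean(axis=0)   # one mean per column
row_std = a.std(axis=1)      # one (population) standard deviation per row
running = a.cumsum(axis=0)   # running total down each column

print(col_means)     # [4.5 5.5 6.5]
print(running[-1])   # last row of the running total equals a.sum(axis=0)
```

The last row of the cumulative sum along axis 0 matches the column sums shown earlier, which is a quick way to check that the axis was chosen correctly.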

From the Tentative Numpy Tutorial:

Many unary operations, such as computing the sum of all the elements
in the array, are implemented as methods of the ndarray class. By
default, these operations apply to the array as though it were a list
of numbers, regardless of its shape. However, by specifying the axis
parameter you can apply an operation along the specified axis of an
array:
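The example that followed this quote did not survive, so here is a small sketch in the same spirit. The array b and the variable names are assumptions chosen for illustration, not the tutorial's exact text.

```python
import numpy as np

b = np.arange(12).reshape(3, 4)

total = b.sum()            # no axis: the array is treated as a flat list of numbers
col_sums = b.sum(axis=0)   # one sum per column
row_mins = b.min(axis=1)   # one minimum per row

print(total)      # 66
print(col_sums)   # [12 15 18 21]
print(row_mins)   # [0 4 8]
```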

Other alternatives for summing the columns are

numpy.einsum("ij->j", a)

and

numpy.dot(a.T, numpy.ones(a.shape[0]))
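A quick way to convince yourself that the three approaches agree is to run them on the same small array. This is only a sketch; the names by_sum, by_einsum and by_dot are illustrative.

```python
import numpy as np

a = np.arange(12).reshape(4, 3)

by_sum = a.sum(axis=0)
by_einsum = np.einsum("ij->j", a)
by_dot = np.dot(a.T, np.ones(a.shape[0]))

# All three give the column sums 18, 22, 26.
assert np.array_equal(by_sum, by_einsum)
assert np.allclose(by_sum, by_dot)

# np.ones defaults to float64, so the dot variant returns floats
# even though a holds integers.
print(by_dot.dtype)
```

If an integer result matters, passing dtype=a.dtype to numpy.ones keeps the dot variant in the input's integer type.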

If the numbers of rows and columns are of the same order of magnitude, all of these approaches are roughly equally fast:

[plot: runtime of numpy_sum, einsum, and dot_ones against len(a) for n-by-n arrays, log-log axes]

If there are only a few columns, however, both the einsum and the dot solutions significantly outperform numpy's sum (note the log scale):

[plot: runtime of numpy_sum, einsum, and dot_ones against len(a) for n-by-3 arrays, log-log axes]


Code to reproduce the plots:

import numpy
import perfplot


# Note: each kernel reduces along axis 1, i.e. it returns one sum per row.
def numpy_sum(a):
    return numpy.sum(a, axis=1)


def einsum(a):
    return numpy.einsum("ij->i", a)


def dot_ones(a):
    return numpy.dot(a, numpy.ones(a.shape[1]))


perfplot.save(
    "out1.png",
    # setup=lambda n: numpy.random.rand(n, n),  # square arrays
    setup=lambda n: numpy.random.rand(n, 3),  # only a few columns
    n_range=[2**k for k in range(15)],
    kernels=[numpy_sum, einsum, dot_ones],
    logx=True,
    logy=True,
    xlabel="len(a)",
)
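If perfplot is not installed, a rough single-size comparison can be sketched with the standard-library timeit module. This is an assumption-laden stand-in, not the code behind the plots: the array size (10,000 x 3), the repeat counts and the dictionary names are all arbitrary choices, and timings will vary by machine and NumPy version.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((10_000, 3))  # many rows, only a few columns

# Same three kernels as above, each reducing along axis 1.
kernels = {
    "numpy_sum": lambda: np.sum(a, axis=1),
    "einsum": lambda: np.einsum("ij->i", a),
    "dot_ones": lambda: np.dot(a, np.ones(a.shape[1])),
}

reference = kernels["numpy_sum"]()
calls = 100
timings = {}
for name, kernel in kernels.items():
    # Make sure every kernel computes the same thing before timing it.
    assert np.allclose(kernel(), reference)
    # Best of 5 repeats, each repeat running the kernel `calls` times.
    timings[name] = min(timeit.repeat(kernel, number=calls, repeat=5))

for name, seconds in timings.items():
    print(f"{name:>10}: {seconds / calls * 1e6:.1f} us per call")
```

Taking the minimum over several repeats reduces noise from other processes; it does not replace the sweep over sizes that perfplot performs.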

Use the axis argument:

>>> import numpy
>>> numpy.sum(a, axis=0)
array([18, 22, 26])
