""" Utilities that manipulate strides to achieve desirable effects. An explanation of strides can be found in the "ndarray.rst" file in the NumPy reference guide. """ from __future__ import division, absolute_import, print_function import numpy as np __all__ = ['broadcast_arrays'] class DummyArray(object): """Dummy object that just exists to hang __array_interface__ dictionaries and possibly keep alive a reference to a base array. """ def __init__(self, interface, base=None): self.__array_interface__ = interface self.base = base def as_strided(x, shape=None, strides=None): """ Make an ndarray from the given array with the given shape and strides. """ interface = dict(x.__array_interface__) if shape is not None: interface['shape'] = tuple(shape) if strides is not None: interface['strides'] = tuple(strides) array = np.asarray(DummyArray(interface, base=x)) # Make sure dtype is correct in case of custom dtype if array.dtype.kind == 'V': array.dtype = x.dtype return array def broadcast_arrays(*args): """ Broadcast any number of arrays against each other. Parameters ---------- `*args` : array_likes The arrays to broadcast. Returns ------- broadcasted : list of arrays These arrays are views on the original arrays. They are typically not contiguous. Furthermore, more than one element of a broadcasted array may refer to a single memory location. If you need to write to the arrays, make copies first. Examples -------- >>> x = np.array([[1,2,3]]) >>> y = np.array([[1],[2],[3]]) >>> np.broadcast_arrays(x, y) [array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]), array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])] Here is a useful idiom for getting contiguous copies instead of non-contiguous views. >>> [np.array(a) for a in np.broadcast_arrays(x, y)] [array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]), array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])] """ args = [np.asarray(_m) for _m in args] shapes = [x.shape for x in args] if len(set(shapes)) == 1: # Common case where nothing needs to be broadcasted. return args shapes = [list(s) for s in shapes] strides = [list(x.strides) for x in args] nds = [len(s) for s in shapes] biggest = max(nds) # Go through each array and prepend dimensions of length 1 to each of # the shapes in order to make the number of dimensions equal. for i in range(len(args)): diff = biggest - nds[i] if diff > 0: shapes[i] = [1] * diff + shapes[i] strides[i] = [0] * diff + strides[i] # Chech each dimension for compatibility. A dimension length of 1 is # accepted as compatible with any other length. common_shape = [] for axis in range(biggest): lengths = [s[axis] for s in shapes] unique = set(lengths + [1]) if len(unique) > 2: # There must be at least two non-1 lengths for this axis. raise ValueError("shape mismatch: two or more arrays have " "incompatible dimensions on axis %r." % (axis,)) elif len(unique) == 2: # There is exactly one non-1 length. The common shape will take # this value. unique.remove(1) new_length = unique.pop() common_shape.append(new_length) # For each array, if this axis is being broadcasted from a # length of 1, then set its stride to 0 so that it repeats its # data. for i in range(len(args)): if shapes[i][axis] == 1: shapes[i][axis] = new_length strides[i][axis] = 0 else: # Every array has a length of 1 on this axis. Strides can be # left alone as nothing is broadcasted. common_shape.append(1) # Construct the new arrays. broadcasted = [as_strided(x, shape=sh, strides=st) for (x, sh, st) in zip(args, shapes, strides)] return broadcasted