# Source code for histpy.axes

import logging
logger = logging.getLogger(__name__)

import numpy as np

from copy import copy,deepcopy

from .axis import Axis

from sparse import SparseArray

import astropy.units as u

class Axes:
    """
    Holds a list of axes.

    The operator :code:`Axes[key]` returns a subset of these. Key can be
    either the index or the label. If the key is a single index, a single
    Axis object will be returned.

    Args:
        edges (array or list of arrays or Axis): Definition of bin edges.
        labels (array of str): Optionally label the axes for easier indexing.
            Will override the labels of edges, if they are Axis objects.
        axis_scale (str or array): Bin center mode e.g. `"linear"` or `"log"`.
            See Axis.axis_scale. If not an array, all axes will have this mode.
    """

    def __init__(self, edges, labels=None, axis_scale = None):

        # Standardize axes as a list of Axis objects
        if isinstance(edges, Axes):
            # From another Axes object. Shallow copy: the list is new, but
            # the Axis objects themselves are shared with `edges`.
            self._axes = copy(edges._axes)
        elif isinstance(edges, Axis):
            self._axes = [edges]
        elif np.isscalar(edges):
            raise TypeError("'edges' can't be a scalar")
        elif all(np.ndim(a) == 0 for a in edges):
            # A flat sequence of numbers: the edges of a single (1D) axis
            self._axes = [Axis(edges)]
        else:
            # Multi-dimensional: one entry (edges or Axis) per axis
            self._axes = [axis if isinstance(axis, Axis) else Axis(axis)
                          for axis in edges]

        # Override labels if needed
        if labels is not None:

            if np.isscalar(labels):
                labels = [labels]

            if len(labels) != self.ndim:
                raise ValueError("Edges - labels size mismatch")

            for n, label in enumerate(labels):
                self._axes[n] = Axis(self._axes[n], label)

        self._update_labels_index()

        # Override scale if needed
        if axis_scale is not None:

            if np.isscalar(axis_scale):
                axis_scale = self.ndim*[axis_scale]

            if len(axis_scale) != self.ndim:
                raise ValueError("Edges - axis_scale size mismatch")

            # NOTE(review): this sets the scale in place on the Axis objects,
            # which may be shared with the caller's Axis/Axes objects.
            for mode, ax in zip(axis_scale, self._axes):
                ax.axis_scale = mode

    def _update_labels_index(self):
        """
        Rebuild the map from axis label to axis index.

        Axes whose label is None are skipped.

        Raises:
            ValueError: If two axes share the same non-None label. The
                previous map is left untouched in that case.
        """

        label_index = {}

        for n, axis in enumerate(self._axes):

            label = axis.label

            if label is None:
                continue

            if label in label_index:
                raise ValueError("Labels can't repeat")

            label_index[label] = n

        # Only commit once the whole map is known to be valid
        self._labels = label_index

    def __len__(self):
        return self.ndim

    def __iter__(self):
        return iter(self._axes)

    @property
    def ndim(self):
        """
        Number of axes
        """
        return len(self._axes)

    def label_to_index(self, key):
        """
        Turn a key or list of keys, either indices or labels, into indices.

        Args:
            key (int, str, slice or sequence): Index, label, slice, or a
                sequence of indices/labels.

        Return:
            int, tuple of int, or array: Index, indices (for a sequence),
            or an array of indices (for a slice).

        Raises:
            KeyError: If a label is not found.
        """

        if isinstance(key, (int, np.integer)):
            return key

        if isinstance(key, (np.ndarray, list, tuple, range)):
            return tuple(self.label_to_index(k) for k in key)

        if isinstance(key, slice):
            return np.arange(*key.indices(self.ndim))

        # Label
        try:
            return self._labels[key]
        except KeyError:
            logger.error("Axis with label {} not found".format(key))
            raise

    def __getitem__(self, key):

        indices = self.label_to_index(key)

        if np.isscalar(indices):
            return self._axes[indices]
        else:
            return Axes([self._axes[i] for i in indices])

    def __setitem__(self, key, new):
        """
        Replace a single axis, keeping the label index in sync.

        If the new axis has no label, it inherits the label of the axis it
        replaces, so lookups by that label keep working.

        Raises:
            ValueError: If the number of bins differs from the replaced axis,
                or if the new label duplicates another axis' label (in which
                case nothing is changed).
        """

        if not isinstance(new, Axis):
            new = Axis(new)

        key = self.label_to_index(key)

        old = self._axes[key]

        if new.nbins != old.nbins:
            raise ValueError("Can't assign new axis with different number of bins")

        if new.label is None and old.label is not None:
            new = Axis(new, old.label)

        self._axes[key] = new

        try:
            self._update_labels_index()
        except ValueError:
            # Restore the previous axis so the object stays consistent
            self._axes[key] = old
            raise

    def __eq__(self, other):
        """
        Two Axes are equal if they have the same number of axes and every
        pair of corresponding axes compares equal.
        """

        if not isinstance(other, Axes):
            return NotImplemented

        # zip() would silently ignore trailing axes otherwise
        if self.ndim != other.ndim:
            return False

        return all(a1 == a2 for a1, a2 in zip(self._axes, other._axes))

    def __array__(self, dtype=None, copy=None):
        # NumPy may pass `dtype` and (NumPy >= 2) `copy`. A new array is
        # always built here, so `copy` has no further effect.
        return np.array(self._axes, dtype=dtype)

    def find_bin(self, *values):
        """
        Return one or more indices corresponding to the bin this value or
        set of values correspond to.

        You can pass either an array, or specify the values as different
        arguments. i.e. :code:`h.find_bin(x,y,z)` = :code:`h.find_bin([x,y,z])`

        Multiple entries can be passed at once. e.g.
        :code:`h.find_bin([x0, x1, x2])`,
        :code:`h.find_bin([x0, x1],[y0, y1],[z0, z1])`,
        :code:`h.find_bin([[x0, x1],[y0, y1],[z0, z1]])`

        Args:
            values (float or array): Value or list of values. Either shape N
                or ndim x N, where N is the number of entries.

        Return:
            int or tuple: Bin index
        """

        # Handle 1D
        if self.ndim == 1:
            # 1D hist, any shape of values works. The output has the same shape
            if len(values) != 1:
                raise ValueError("Mismatch between values shape and number of axes")

            return self[0].find_bin(values[0])

        # >=2D case
        # Sanitize and standardize
        if len(values) == 1:
            # e.g. ([x,y]) or ([[x0,x1], [y0,y1]]), NOT (x,y,z), [[x0,x1], [y0,y1]]
            values = tuple(values[0])

        if len(values) != self.ndim:
            raise ValueError("Mismatch between values shape and number of axes")

        return tuple(axis.find_bin(val) for val, axis in zip(values, self._axes))

    def interp_weights(self, *values):
        """
        Get the bins and weights to linearly interpolate between bins.

        The bin contents are assigned to the center of the bin.

        Args:
            values (float or array): Coordinates within the axes to interpolate.

        Returns:
            array of tuples of int, array of floats: Bins and weights to use.
            Shaped (2^ndim, N). Bins is an array of tuples for
            multi-dimensional histograms.
        """

        # Handle 1D
        if self.ndim == 1:
            # 1D hist, any shape of values works. The output has the same shape
            if len(values) != 1:
                raise ValueError("Mismatch between values shape and number of axes")

            return self._axes[0].interp_weights(values[0])

        # >=2D case
        # Sanitize and standardize
        if len(values) == 1:
            # e.g. ([x,y]) or ([[x0,x1], [y0,y1]]), NOT (x,y,z), [[x0,x1], [y0,y1]]
            values = tuple(values[0])

        if len(values) != self.ndim:
            raise ValueError("Mismatch between values shape and number of axes")

        # Strip units. Accounts for conversion factor, before stripped by broadcast
        values = tuple(axis._strip_units(value)
                       for axis, value in zip(self._axes, values))

        # Broadcast: all entries share the same shape afterwards
        values = np.broadcast_arrays(*values)
        values_shape = values[0].shape

        # Get the bin/weights for each individual axis
        dim_bins = np.empty(self.ndim, dtype = 'O')
        dim_weights = np.empty(self.ndim, dtype = 'O')
        for dim, value in enumerate(values):
            bins, weights = self._axes[dim].interp_weights(value)
            dim_bins[dim] = bins
            dim_weights[dim] = weights

        npoints = 2**self.ndim
        bins = np.empty((npoints,) + values_shape, dtype = 'O')
        weights = np.empty((npoints,) + values_shape)

        # Combine them. e.g. for 2D this results in
        # weights = [dim_weights[0][0]*dim_weights[1][0],
        #            dim_weights[0][1]*dim_weights[1][0],
        #            dim_weights[0][0]*dim_weights[1][1],
        #            dim_weights[0][1]*dim_weights[1][1]]
        # bins = [(dim_bins[0][0], dim_bins[1][0]),
        #         (dim_bins[0][1], dim_bins[1][0]),
        #         (dim_bins[0][0], dim_bins[1][1]),
        #         (dim_bins[0][1], dim_bins[1][1])]

        # bit_masks = [0b001, 0b010, 0b100, ...]
        bit_masks = 2**np.arange(self.ndim)

        for n in range(npoints):

            weight = 1
            bin_list = np.empty((self.ndim,) + values_shape, dtype = int)

            # Since there are two weights per axis, we use bit
            # masking to loop between them instead of recursion
            for dim, bit_mask in enumerate(bit_masks):

                index = int(bool(n & bit_mask))  # Either 0 or 1

                weight *= dim_weights[dim][index]
                bin_list[dim] = dim_bins[dim][index]

            # Rearrange as an array of tuples.
            # A simple reshape would turn the tuples into an extra dimension;
            # assigning into a pre-allocated object array keeps them as tuples.
            bin_list = list(zip(*[b.flatten() for b in bin_list]))
            bin_list_aux = np.empty(np.prod(values_shape, dtype = int), dtype = 'O')
            bin_list_aux[:] = bin_list

            if values_shape:
                bins[n] = np.reshape(bin_list_aux, values_shape)
            else:
                bins[n] = bin_list_aux[0]

            weights[n] = weight

        return bins, weights

    def _get_axis_property(f):
        """
        Decorator to retrieve a property from all axes at once.

        The methods need to be declared as::

            @_get_axis_property
            def property_name(self):
                return 'property_name'

        The decorated method's docstring becomes the property's docstring.
        """

        def wrapper(self):
            attr_name = f(self)
            return np.array([getattr(axis, attr_name) for axis in self._axes])

        # property() takes its docstring from the getter; keep the original one
        wrapper.__doc__ = f.__doc__

        return property(wrapper)

    @_get_axis_property
    def units(self):
        """
        Units of all axes.
        """
        return 'unit'

    @_get_axis_property
    def labels(self):
        """
        Labels of all axes.
        """
        return 'label'

    @labels.setter
    def labels(self, new_labels):

        if len(new_labels) != self.ndim:
            raise ValueError("Number of labels do not correspond to the "
                             "number of dimensions.")

        old_labels = [axis.label for axis in self._axes]

        for axis, label in zip(self._axes, new_labels):
            axis.label = label

        try:
            self._update_labels_index()
        except ValueError:
            # e.g. repeated labels: restore the previous state before raising
            for axis, label in zip(self._axes, old_labels):
                axis.label = label
            raise

    @_get_axis_property
    def lo_lims(self):
        """
        Overall lower bounds
        """
        return 'lo_lim'

    @_get_axis_property
    def hi_lims(self):
        """
        Overall upper bounds
        """
        return 'hi_lim'

    @_get_axis_property
    def nbins(self):
        """
        Number of elements along each axis.
        """
        return 'nbins'

    def expand_dims(self, a, axis):
        """
        Insert new axes into `a` as appropriate to allow broadcasting with
        a histogram having these axes.

        Args:
            a (array): Input array. One dimension per entry in `axis`.
            axis (int, str or sequence): Axes (indices or labels) that the
                dimensions of `a` correspond to, in order.

        Return:
            array: `a` with singleton dimensions inserted, `ndim` dimensions.

        Raises:
            ValueError: On a dimension/axis count mismatch or out-of-bounds axis.
        """

        # Standardize inputs
        if not isinstance(a, (np.ndarray, SparseArray)):
            a = np.array(a)

        axis = np.array(self.label_to_index(axis))

        if axis.ndim == 0:
            # Make scalars an array
            axis = axis[None]

        # Sanity checks
        if a.ndim != len(axis):
            raise ValueError("Number of input axes ({}) "
                             "does not match number of "
                             "dimensions ({}) of the "
                             "input array".format(len(axis), a.ndim))

        if a.ndim > self.ndim:
            raise ValueError("Number of dimensions of the input array ({}) "
                             "cannot be greater than the "
                             "number of axes ({})".format(a.ndim, self.ndim))

        if any(axis >= self.ndim) or any(axis < 0):
            raise ValueError("One or more axes indices ({}) "
                             "out of bounds ({} axes)".format(axis, self.ndim))

        # Match number of axes: append trailing singleton dimensions
        orig_ndim = a.ndim
        a = a[tuple(slice(None) for _ in range(a.ndim))
              + tuple(None for _ in range(self.ndim - a.ndim))]

        # Reorganize axes: move the original dimensions to their target axes
        a = np.moveaxis(a, np.arange(orig_ndim), axis)

        return a

    def broadcast(self, a, axis):
        """
        Expand the dimensions and broadcast an array for a given set of axes
        such that it has the same dimensions as the histogram.

        Args:
            a (array): Array to broadcast
            axis (int or array): Histogram axes correspond to the array axes.

        Return:
            array: Broadcast (read-only view or Quantity) of `a`.
        """

        # Add singleton dims
        a = self.expand_dims(a, axis)

        # Account for under/overflow: an array dimension with 2 more entries
        # than the number of bins keeps that larger size
        new_shape = self.nbins
        for i, (a_nbins, h_nbins) in enumerate(zip(a.shape, new_shape)):
            if a_nbins == h_nbins+2:
                new_shape[i] += 2

        output = np.broadcast_to(a, tuple(new_shape))

        # NOTE(review): presumably np.broadcast_to returns a plain ndarray for
        # Quantity input, so the unit is re-attached here — confirm.
        if isinstance(a, u.Quantity):
            output = output*a.unit

        return output

    def expand_dict(self, axis_value, default = None):
        """
        Convert pairs of axis:value to a list of length `ndim`.

        Args:
            axis_value (dict): Dictionary with axis-value pairs (can be labels)
            default: Default filling value for unspecified axes

        Return:
            tuple
        """

        val_list = [default] * self.ndim

        for axis, value in axis_value.items():
            axis = self.label_to_index(axis)
            val_list[axis] = value

        return tuple(val_list)