Source code for tike.communicators.mpi

"""Define a MPI wrapper for inter-node communications.."""

__author__ = "Xiaodong Yu"
__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC."
__docformat__ = 'restructuredtext en'

from mpi4py import MPI

import numpy as np


class MPIComm:
    """A wrapper for mpi4py inter-node communications.

    Many clusters do not support inter-node GPU-GPU communications, so we
    first gather the data into main memory, then communicate them.

    Attributes
    ----------
    rank : int
        The identity of this process.
    size : int
        The total number of MPI processes.
    """

    def __init__(self):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        pass
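A minimal usage sketch (the launch command and process count are illustrative): every process constructs its own wrapper around ``MPI.COMM_WORLD`` when the script is launched with, e.g., ``mpiexec -n 4 python example.py``:

    from tike.communicators.mpi import MPIComm

    with MPIComm() as mpi:
        # Each process reports its own identity within the world communicator.
        print(f"rank {mpi.rank} of {mpi.size}")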
    def p2p(self, sendbuf, src=0, dest=1, tg=0, **kwargs):
        """Send data from a source to a designated destination."""
        if sendbuf is None:
            raise ValueError("Sendbuf can't be empty.")

        if self.rank == src:
            self.comm.Send(sendbuf, dest=dest, tag=tg, **kwargs)
        elif self.rank == dest:
            info = MPI.Status()
            recvbuf = np.empty(sendbuf.shape, sendbuf.dtype)
            self.comm.Recv(recvbuf, source=src, tag=tg, status=info, **kwargs)
            return recvbuf
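A sketch of the point-to-point path (array contents are illustrative): every rank passes a buffer of the same shape, rank ``src`` sends its values, and only rank ``dest`` gets a non-``None`` return value:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    payload = np.full(4, fill_value=mpi.rank, dtype=np.float32)
    received = mpi.p2p(payload, src=0, dest=1)
    if mpi.rank == 1:
        print(received)  # all zeros: the values sent by rank 0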
    def Bcast(self, data, root: int = 0):
        """Send data from a root to all processes."""
        if data is None:
            raise ValueError("Broadcast data can't be empty.")

        if self.rank != root:
            data = np.empty(data.shape, data.dtype)
        self.comm.Bcast(data, root)
        return data
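A broadcast sketch (shapes and values are illustrative): every rank passes an array of the same shape and dtype, and after the call each rank holds the root's values:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    x = np.full(3, fill_value=mpi.rank, dtype=np.float32)
    x = mpi.Bcast(x, root=0)
    # On every rank, x is now [0., 0., 0.], the values held by rank 0.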
    def Gather(self, sendbuf, dest: int = 0):
        """Take data from all processes into one destination."""
        if sendbuf is None:
            raise ValueError("Gather data can't be empty.")

        recvbuf = None
        if self.rank == dest:
            # The gathered buffer stacks one slice per process along a new
            # leading axis.
            recvbuf = np.empty((self.size, *sendbuf.shape), sendbuf.dtype)
        self.comm.Gather(sendbuf, recvbuf, dest)
        if self.rank == dest:
            return recvbuf
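A gather sketch (shapes are illustrative): each rank contributes its local array, and only the destination rank receives the stacked result with one slice per process along the new leading axis:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    local = np.full((2, 2), fill_value=mpi.rank, dtype=np.float32)
    gathered = mpi.Gather(local, dest=0)
    if mpi.rank == 0:
        print(gathered.shape)  # (size, 2, 2); slice i came from rank i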
    def Scatter(self, sendbuf, src: int = 0):
        """Spread data from a source to all processes."""
        if sendbuf is None:
            raise ValueError("Scatter data can't be empty.")

        # Each process receives one slice of sendbuf along its leading axis,
        # so the leading dimension of sendbuf must equal the number of
        # processes.
        recvbuf = np.empty(sendbuf.shape[1:], sendbuf.dtype)
        self.comm.Scatter(sendbuf, recvbuf, src)
        return recvbuf
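A scatter sketch following the leading-axis convention noted above (values are illustrative): the buffer holds one slice per process, and each rank gets back its own slice:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    # Every rank passes a buffer of the same shape; only the values on the
    # source process are actually distributed.
    chunks = np.arange(mpi.size * 3, dtype=np.float32).reshape(mpi.size, 3)
    mine = mpi.Scatter(chunks, src=0)
    # Rank i now holds chunks[i] from the source process.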
    def Allreduce(self, sendbuf, op=MPI.SUM):
        """Combine data from all processes and distribute the result back to
        all processes."""
        if sendbuf is None:
            raise ValueError("Allreduce data can't be empty.")

        recvbuf = np.empty(sendbuf.shape, sendbuf.dtype)
        self.comm.Allreduce(sendbuf, recvbuf, op=op)
        return recvbuf
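An all-reduce sketch (values are illustrative): with the default ``MPI.SUM`` operator every rank ends up with the elementwise sum of all contributions:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    contribution = np.full(5, fill_value=mpi.rank + 1, dtype=np.float32)
    total = mpi.Allreduce(contribution)
    # On every rank, each element of total equals 1 + 2 + ... + size.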
    def MPIio(self, scan, data):
        """Split the scan positions and data so each process keeps one stripe."""
        # Determine the edges of the stripes
        edges = np.linspace(
            scan[..., 0].min(),
            scan[..., 0].max(),
            self.size + 1,
            endpoint=True,
        )
        # Move the outer edges to include all points
        edges[0] -= 1
        edges[-1] += 1
        # Generate the mask
        mask = np.logical_and(edges[self.rank] < scan[0, :, 0],
                              scan[0, :, 0] <= edges[self.rank + 1])

        scan = scan[:, mask]
        data = data[:, mask]

        return scan, data
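A partitioning sketch with synthetic arrays (all shapes and the fixed seed are illustrative; the seed only ensures every rank builds the same full arrays, as if each had read the whole file): each rank keeps the scan positions, and the matching measurements, that fall inside its stripe of the first scan coordinate:

    import numpy as np
    from tike.communicators.mpi import MPIComm

    mpi = MPIComm()
    rng = np.random.default_rng(seed=0)
    # One view, 1000 scan positions with two coordinates each, and one
    # 64 x 64 measurement per position.
    scan = 100 * rng.random((1, 1000, 2), dtype=np.float32)
    data = rng.random((1, 1000, 64, 64), dtype=np.float32)
    my_scan, my_data = mpi.MPIio(scan, data)
    # Together the ranks cover all 1000 positions; each keeps one stripe.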