# (C) Copyright 2007 
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
#  Authors: Andreas Kloeckner




import boost.mpi as mpi
import random
import sys

# A history is routed back to rank 0 once it holds at least this many entries.
MAX_GENERATIONS = 20

# Message tags used between rank 0 and the worker ranks.
# Payload is printed by rank 0 as a debug line.
TAG_DEBUG = 0
# Payload is a history: the list of worker ranks visited so far.
TAG_DATA = 1
# Sent by rank 0 to stop a worker; the worker echoes it back before exiting.
TAG_TERMINATE = 2
# Sent by a worker to rank 0 each time it receives a history.
TAG_PROGRESS_REPORT = 3




class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive is kept posted per tag; wait() blocks until
    any one of them completes.

    This is contrived: Typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- object whose irecv(tag=...) posts a non-blocking receive
        tags -- sequence of tags; it is iterated again on every wait()
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> outstanding request returned by comm.irecv.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of the tags has been received.

        Returns a (status, data) pair for the receive that completed.
        """
        # (Re-)post a receive for every tag that has none outstanding.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        # list(...) hands over a real list on Python 2 and Python 3 alike.
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(requests)
        # The completed request is used up; the next wait() posts a new one.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it."""
        # values() instead of itervalues(): the latter is Python 2 only.
        for request in self.active_requests.values():
            request.cancel()
            # NOTE(review): a wait() on each cancelled request was left
            # disabled here by the original author -- confirm whether the
            # cancelled requests still need to be waited on.
        self.active_requests = {}



def rank0():
    """Run the coordinating side of the test on rank 0.

    Sends (mpi.size-1)*15 empty histories to randomly chosen workers, then
    listens for returned histories, progress reports, debug messages and
    termination acknowledgements. Once every history has come back, all
    workers are told to terminate. Prints "OK" after every worker has
    acknowledged and every progress report has been counted.
    """
    sent_histories = (mpi.size-1)*15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    # With a single process there are no workers and nothing was sent, so
    # there is nothing to wait for.
    if sent_histories:
        mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length at the time of the report -> number of reports.
    progress_reports = {}
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    # A worker reports a history before appending to it, and forwards it to
    # rank 0 once it reaches MAX_GENERATIONS entries. Every history is thus
    # reported at each length 0 .. MAX_GENERATIONS-1 (at least at length 0).
    if sent_histories:
        expected_generations = max(MAX_GENERATIONS, 1)
    else:
        expected_generations = 0

    def is_complete():
        # Requiring every generation to be present keeps the check from
        # passing vacuously while reports for a generation are still
        # in flight.
        if len(progress_reports) != expected_generations:
            return False
        if any(count != sent_histories
                for count in progress_reports.values()):
            return False
        return len(dead_kids) == mpi.size-1

    # Checking before the first wait lets a run without workers finish
    # instead of blocking on a message that can never arrive.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            #print "received completed history %s from %d" % (data, status.source)
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            # A worker echoes TAG_TERMINATE right before it leaves its loop.
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    print("OK")

def comm_rank():
    """Run the worker loop on a rank other than 0.

    Repeatedly receives a message from any source. A history (TAG_DATA) is
    reported to rank 0, extended by this rank's number and passed on: to
    rank 0 once it holds at least MAX_GENERATIONS entries, otherwise to a
    randomly chosen worker. TAG_TERMINATE is echoed to rank 0 and ends the
    loop. Any other tag is printed as a diagnostic.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report first, so rank 0 sees the history as it arrived here.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge, so rank 0 can count this worker as finished.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))


def main():
    """Dispatch to the coordinator or the worker role based on the MPI rank.

    This program sends around messages consisting of lists of visited
    nodes randomly. After MAX_GENERATIONS, they are returned to rank 0.
    """
    # Every rank other than 0 acts as a worker.
    if mpi.rank != 0:
        comm_rank()
        return

    rank0()
        


# Run the test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
