#!/usr/bin/env python
############################################################################
#
# MODULE:       r.neighbors.mp
# AUTHOR(S):    Vaclav Petras
# PURPOSE:      Wrapper for r.neighbors for experimenting with parallelization
# COPYRIGHT:    (C) 2018 by Vaclav Petras, and the GRASS Development Team
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
############################################################################

#%module
#% description: Makes each cell category value a function of the category values assigned to the cells around it, and stores new cell values in an output raster map layer.
#% keyword: raster
#% keyword: algebra
#% keyword: statistics
#% keyword: aggregation
#% keyword: neighbor
#% keyword: focal statistics
#% keyword: filter
#%end
#%flag
#% key: a
#% description: Do not align output with the input
#%end
#%flag
#% key: c
#% description: Use circular neighborhood
#% guisection: Neighborhood
#%end
#%option G_OPT_R_INPUT
#%end
#%option G_OPT_R_INPUT
#% key: selection
#% required: no
#% description: Name of an input raster map to select the cells which should be processed
#%end
#%option G_OPT_R_OUTPUT
#% multiple: yes
#%end
#%option
#% key: method
#% type: string
#% required: no
#% multiple: yes
#% options: average,median,mode,minimum,maximum,range,stddev,sum,count,variance,diversity,interspersion,quart1,quart3,perc90,quantile
#% description: Neighborhood operation
#% answer: average
#% guisection: Neighborhood
#%end
#%option
#% key: size
#% type: integer
#% required: no
#% multiple: no
#% description: Neighborhood size
#% answer: 3
#% guisection: Neighborhood
#%end
#%option
#% key: title
#% type: string
#% required: no
#% multiple: no
#% key_desc: phrase
#% description: Title for output raster map
#%end
#%option G_OPT_F_INPUT
#% key: weight
#% required: no
#% description: Text file containing weights
#%end
#%option
#% key: gauss
#% type: double
#% required: no
#% multiple: no
#% description: Sigma (in cells) for Gaussian filter
#%end
#%option
#% key: quantile
#% type: double
#% required: no
#% multiple: yes
#% options: 0.0-1.0
#% description: Quantile to calculate for method=quantile
#%end
#%option
#% key: nprocs
#% type: integer
#% required: yes
#% description: Number of processes for parallel computing
#% guisection: Parallel
#%end
#%option
#% key: width
#% type: integer
#% required: yes
#% description: Width of a tile in cells
#% guisection: Parallel
#%end
#%option
#% key: height
#% type: integer
#% required: yes
#% description: Height of a tile in cells
#% guisection: Parallel
#%end
#%option
#% key: overlap
#% type: integer
#% required: yes
#% description: Overlap of tile in cells
#% guisection: Parallel
#%end


import os
import sys
import copy
from multiprocessing import Pool
import atexit
import uuid

from collections import namedtuple

import grass.script as gs
from grass.exceptions import CalledModuleError


def remove_keys_from_dict(dictionary, keys):
    """Delete every key listed in *keys* from *dictionary* in place.

    :raises KeyError: when a key is missing; keys processed before it
        have already been removed at that point
    """
    for unwanted in keys:
        dictionary.pop(unwanted)


def remove_false_values(dictionary):
    """Return a new dict with only the items of *dictionary* whose value is truthy.

    Used to drop options the user left empty (e.g. ``''`` or ``None``) so
    that they are not passed on to the wrapped module.

    :param dictionary: mapping to filter; it is not modified
    :return: new dict with the truthy items
    """
    return {key: value for key, value in dictionary.items() if value}


def flags_dict_to_str(flags):
    """Convert a flags dict (as returned by gs.parser) to a flags string.

    Only the keys of flags set to a truthy value are used, e.g.
    ``{'a': True, 'c': False}`` gives ``'a'``.

    :param flags: mapping of single-letter flag keys to booleans
    :return: concatenated keys of the active flags
    """
    # join the flag letters (keys), not their boolean values
    return "".join(key for key, value in flags.items() if value)


# Computational region: extent (west, east, north, south), number of rows
# and columns, and east-west / north-south resolution.
Region = namedtuple(
    'Region',
    ['west', 'east', 'north', 'south', 'rows', 'cols', 'ewres', 'nsres'])


def get_current_region():
    """Return the current computational region as a Region tuple.

    The values are read from ``g.region`` (flags ``g3``). Extents and
    resolutions are converted to float, row and column counts to int.
    """
    info = gs.parse_command('g.region', flags='g3')
    return Region(
        west=float(info['w']),
        east=float(info['e']),
        north=float(info['n']),
        south=float(info['s']),
        rows=int(info['rows']),
        cols=int(info['cols']),
        ewres=float(info['ewres']),
        nsres=float(info['nsres']),
    )


def get_tempory_region(do_cleanup=True):
    """Save the current region under a unique temporary name.

    The name is built from the script's base name and a random UUID
    (without dashes). Unless *do_cleanup* is False, deletion of the saved
    region is registered with atexit.

    :param do_cleanup: register removal of the region at interpreter exit
    :return: name of the saved region
    """
    script_name = os.path.basename(sys.argv[0])
    unique_id = uuid.uuid4().hex  # same as str(uuid4()) without dashes
    name = "tmp.%s.%s" % (script_name, unique_id)
    gs.run_command('g.region', save=name, overwrite=True)
    if do_cleanup:
        atexit.register(delete_temporary_region, name)
    return name


def delete_temporary_region(name):
    """Remove the saved region *name* on a best-effort basis.

    A failure of g.remove (CalledModuleError) is ignored on purpose, so
    this function is safe to use as a cleanup handler.
    """
    try:
        gs.run_command('g.remove', type='region', name=name,
                       flags='f', quiet=True)
    except CalledModuleError:
        # cleanup is optional; never fail because of it
        return


# Bounding box (extent) of a single tile in map coordinates.
BBox = namedtuple('BBox', ['west', 'east', 'north', 'south'])


def region_to_tiles(width, height, overlap=0, region=None):
    """Split a region into a flat list of tile bounding boxes.

    Tiles are listed row by row, starting with the north-west tile.
    Each tile is extended by *overlap* cells on every side and clipped
    to the region extent.

    :param width: width of a tile in cells
    :param height: height of a tile in cells
    :param overlap: number of cells added on each side of a tile
    :param region: Region to split; the current region when None
    :return: list of BBox

    >>> reg = Region(west=0, east=2000, north=1000, south=0,
    ...              rows=1000, cols=2000, ewres=1, nsres=1)
    >>> tiles = region_to_tiles(500, 500, 10, reg)
    >>> len(tiles)
    8
    >>> tiles[0]
    BBox(west=0, east=510, north=1000, south=490)
    """
    if region is None:
        region = get_current_region()
    rows_of_tiles = split_region_tiles(region=region, width=width,
                                       height=height, overlap=overlap)
    # flatten the nested (row -> tiles) list
    return [bbox for row in rows_of_tiles for bbox in row]


# Copy of get_bbox() from grass.pygrass.modules.grid.split
def get_bbox(reg, row, col, width, height, overlap):
    """Return the bounding box of one tile, clipped to the region.

    The tile at (*row*, *col*) is *height* x *width* cells, extended by
    *overlap* cells on every side, and then clipped to the extent of
    *reg*. Rows are counted from the north edge, columns from the west
    edge.

    :param reg: region being split
    :type reg: Region
    :param row: tile row index (0 is the northernmost row)
    :type row: int
    :param col: tile column index (0 is the westernmost column)
    :type col: int
    :param width: width of a tile in cells
    :type width: int
    :param height: height of a tile in cells
    :type height: int
    :param overlap: number of cells added on each side of the tile
    :type overlap: int
    :return: extent of the tile
    :rtype: BBox
    """
    north = reg.north - (row * height - overlap) * reg.nsres
    south = reg.north - ((row + 1) * height + overlap) * reg.nsres
    east = reg.west + ((col + 1) * width + overlap) * reg.ewres
    west = reg.west + (col * width - overlap) * reg.ewres
    # the overlap may reach outside the region; clip to the region extent
    return BBox(north=min(north, reg.north),
                south=max(south, reg.south),
                east=min(east, reg.east),
                west=max(west, reg.west))
                

# Copy of split_region_tiles() from grass.pygrass.modules.grid.split
def split_region_tiles(region, width=100, height=100, overlap=0):
    """Split a region into tiles given as a nested list of bounding boxes.

    The outer list contains one list per row of tiles (north to south),
    each inner list contains the tiles of that row (west to east). Tiles
    at the east and south edges may be smaller than *width* x *height*.

    :param region: Region to split (uses rows and cols plus what get_bbox
        reads)
    :param width: width of a tile in cells
    :param height: height of a tile in cells
    :param overlap: number of cells added on each side of a tile
    :return: list of lists of BBox
    :raises ValueError: if *width* or *height* is not positive
    """
    if width <= 0 or height <= 0:
        raise ValueError(
            "Tile width and height must be positive"
            " (got width={}, height={})".format(width, height))
    # ceiling division: a partial tile at the edge still counts as a tile
    ncols = (region.cols + width - 1) // width
    nrows = (region.rows + height - 1) // height
    return [[get_bbox(region, row, col, width, height, overlap)
             for col in range(ncols)]
            for row in range(nrows)]


def dict_parse_command(*args, **kwargs):
    """Call gs.parse_command and convert its result to a plain dict.

    A plain dict can be pickled (unlike the GRASS key-value object), so the
    result can be returned from a worker process.
    """
    parsed = gs.parse_command(*args, **kwargs)
    return dict(parsed)


# A list of ModuleCall functors. The queueing methods mirror the
# grass.script functions of the same name, but instead of running the
# module right away they append a deferred call to the list.
class ModuleCallList(list):
    """Queue of module calls to be executed later, optionally in a tile region.

    :param bbox: optional BBox. When given, a copy of the current region
        is saved under a temporary name (see get_tempory_region), selected
        through ``WIND_OVERRIDE`` in a private copy of the environment, and
        set to the bbox extent. All queued calls without an explicit
        ``env`` then use that environment.
    """

    def __init__(self, bbox=None):
        super(ModuleCallList, self).__init__()
        self.bbox = bbox
        self.env = None
        if bbox:
            self.env = os.environ.copy()
            self.env['WIND_OVERRIDE'] = get_tempory_region()
            gs.run_command('g.region', w=bbox.west, e=bbox.east,
                           n=bbox.north, s=bbox.south, env=self.env)

    def _queue(self, function, args, kwargs):
        """Append a deferred call, defaulting ``env`` to this list's env."""
        # TODO: what should happen when env is provided?
        kwargs.setdefault('env', self.env)
        self.append(ModuleCall(function, *args, **kwargs))

    def run_command(self, *args, **kwargs):
        """Queue a gs.run_command call."""
        self._queue(gs.run_command, args, kwargs)

    def read_command(self, *args, **kwargs):
        """Queue a gs.read_command call."""
        self._queue(gs.read_command, args, kwargs)

    def parse_command(self, *args, **kwargs):
        """Queue a gs.parse_command call returning a plain dict."""
        # only a normal dict can get pickled, not the GRASS KeyVal
        self._queue(dict_parse_command, args, kwargs)

    def mapcalc(self, *args, **kwargs):
        """Queue a gs.mapcalc call."""
        self._queue(gs.mapcalc, args, kwargs)

    def mapcalc3d(self, *args, **kwargs):
        """Queue a gs.mapcalc3d call."""
        self._queue(gs.mapcalc3d, args, kwargs)

    def execute(self):
        """Execute all queued calls in order.

        :return: return value of the last call (None for an empty list)
        """
        ret = None
        for call in self:
            ret = call.execute()
        return ret


# A deferred function call (functor): stores a callable together with its
# arguments so it can be invoked later, e.g. in a worker process.
class ModuleCall(object):
    """Deferred call of *function* with the given positional and keyword arguments."""

    def __init__(self, function, *call_args, **call_kwargs):
        self.function = function
        self.args = call_args
        self.kwargs = call_kwargs

    def execute(self):
        """Invoke the stored function with the stored arguments.

        :return: whatever the function returns
        """
        func, positional, keywords = self.function, self.args, self.kwargs
        return func(*positional, **keywords)


# TODO: do we need the functions twice?
def execute_module_call_list(module_list):
    """Execute a list of queued calls; meant to run in a worker process.

    Exceptions are not propagated to the caller. An interrupt makes the
    function return None. A failed module is reported with gs.warning and
    also gives None, so the failure is no longer silent.

    :param module_list: object with an ``execute()`` method
        (e.g. ModuleCallList)
    :return: result of ``module_list.execute()``, or None on failure
    """
    try:
        return module_list.execute()
    except KeyboardInterrupt:
        return None
    except CalledModuleError as error:
        gs.warning(str(error))
        return None


def execute_module_call(call):
    """Execute a single queued call; meant to run in a worker process.

    Exceptions are not propagated to the caller. An interrupt makes the
    function return None. A failed module is reported with gs.warning and
    also gives None, so the failure is no longer silent.

    :param call: object with an ``execute()`` method (e.g. ModuleCall)
    :return: result of ``call.execute()``, or None on failure
    """
    try:
        return call.execute()
    except KeyboardInterrupt:
        return None
    except CalledModuleError as error:
        gs.warning(str(error))
        return None


# TODO: introduce better naming for modules and functions
# TODO: introduce debug, dry_run and no_clean_tmp wherever appropriate
def execute_list(module_lists, nprocs, debug=False):
    """Execute lists of queued calls in a pool of *nprocs* processes.

    :param module_lists: objects with an ``execute()`` method
        (e.g. one ModuleCallList per tile)
    :param nprocs: number of worker processes
    :param debug: run sequentially in this process instead; exceptions
        then propagate unchanged
    :return: list of results, in the order of *module_lists*
    """
    if debug:
        return [module_list.execute() for module_list in module_lists]
    pool = Pool(nprocs)
    try:
        # TODO: error handling should be improved
        return pool.map_async(execute_module_call_list, module_lists).get()
    except (KeyboardInterrupt, CalledModuleError):
        # TODO: this should be done in a better way
        sys.exit("KeyboardInterrupt, CalledModuleError")
    finally:
        # the pool is not reused: stop and reap the workers in all cases
        # (success, interrupt, error) so no processes are left behind
        pool.terminate()
        pool.join()


# TODO: errors lead to strange behavior, is legal to wrap map_async like this?
def execute_function_on_list(module_lists, function, nprocs, debug=False):
    """Apply *function* to each item of *module_lists* in a process pool.

    :param module_lists: items passed one by one to *function*
    :param function: picklable single-argument function
    :param nprocs: number of worker processes
    :param debug: run sequentially in this process instead
    :return: list of results, in the order of *module_lists*
    """
    if debug:
        return [function(module_list) for module_list in module_lists]
    pool = Pool(nprocs)
    try:
        # TODO: error handling should be improved
        return pool.map_async(function, module_lists).get()
    except (KeyboardInterrupt, CalledModuleError):
        # TODO: this should be done in a better way
        sys.exit("KeyboardInterrupt, CalledModuleError")
    finally:
        # the pool is not reused: stop and reap the workers in all cases
        pool.terminate()
        pool.join()


def execute_by_module(modules, nprocs):
    """Execute single queued calls in a pool of *nprocs* processes.

    :param modules: objects with an ``execute()`` method (e.g. ModuleCall)
    :param nprocs: number of worker processes
    :return: list of results, in the order of *modules*
    """
    pool = Pool(nprocs)
    try:
        return pool.map_async(execute_module_call, modules).get()
    except (KeyboardInterrupt, CalledModuleError):
        # TODO: this should be done in a better way
        sys.exit("KeyboardInterrupt, CalledModuleError")
    finally:
        # the pool is not reused: stop and reap the workers in all cases
        pool.terminate()
        pool.join()


def generate_patch_expression(rasters, bboxes):
    """Build an r.mapcalc expression which patches tiles together.

    The result is a chain of nested ``if()`` calls: a cell gets the value
    of the first raster whose bbox contains the cell's ``x()``/``y()``
    coordinates, and null when no bbox contains it. Pairs are formed with
    zip, so surplus items of the longer argument are ignored.

    :param rasters: raster names (or other values usable in the expression)
    :param bboxes: BBox for each raster
    :return: expression string (without the ``output =`` part)
    """
    template = ("if(x() >= {b.west} && x() <= {b.east} "
                "&& y() >= {b.south} && y() <= {b.north}, {m},")
    branches = [template.format(b=box, m=name)
                for name, box in zip(rasters, bboxes)]
    # every opened if( needs one closing parenthesis after the final null()
    closing = ")" * len(branches)
    return " ".join(branches) + " null()" + closing

# TODO: patching loses the color table of the tiles

def patch_rasters(rasters, bboxes, output):
    """Patch tile rasters into the raster map *output* using r.mapcalc.

    :param rasters: tile raster names (any values usable in the
        expression, e.g. tile numbers, work too)
    :param bboxes: extent taken from each tile
    :param output: name of the resulting raster map
    """
    gs.mapcalc("%s = %s" % (output,
                            generate_patch_expression(rasters, bboxes)))


def patch_rasters3d(rasters, bboxes, output):
    """Patch 3D tile rasters into the 3D raster *output* using r3.mapcalc.

    :param rasters: tile 3D raster names
    :param bboxes: extent taken from each tile
    :param output: name of the resulting 3D raster map
    """
    gs.mapcalc3d("%s = %s" % (output,
                              generate_patch_expression(rasters, bboxes)))


def patch_rasters_2(args):
    """Single-argument form of patch_rasters for use with Pool.map_async.

    :param args: ``(rasters, bboxes, output)`` sequence
    """
    patch_rasters(*args)
    return None


def patch_rasters3d_2(args):
    """Single-argument form of patch_rasters3d for use with Pool.map_async.

    :param args: ``(rasters, bboxes, output)`` sequence
    """
    patch_rasters3d(*args)
    return None


def remove_rasters(maps):
    """Quietly force-remove raster maps with g.remove.

    :param maps: raster map name(s) passed as ``name`` to g.remove
    """
    gs.run_command(
        'g.remove', type='raster', name=maps, flags='f', quiet=True,
    )


def remove_rasters3d(maps):
    """Quietly force-remove 3D raster maps with g.remove.

    :param maps: 3D raster map name(s) passed as ``name`` to g.remove
    """
    gs.run_command(
        'g.remove', type='raster3d', name=maps, flags='f', quiet=True,
    )


class Namer(object):
    """Create tile-specific names and register them for patching.

    :param bbox: bounding box of the tile, passed on to *callback*
    :param number: number of the tile, appended to generated names
    :param callback: called as ``callback(type, name, tile_name, bbox)``
        for every name that should be patched
    """

    def __init__(self, bbox, number, callback):
        self.bbox = bbox
        self.number = number  # TODO: this is a property and needs a better name
        self.callback = callback

    def name(self, type, name, patch=True):
        """Return the tile-specific name for the element *name*.

        :param type: element type, e.g. ``'raster'``
        :param name: name of the final (patched) element
        :param patch: when False, only the name is generated and the
            callback is not called, so the tile is not registered
        :return: ``<name>_<number>``
        """
        tile_name = "%s_%d" % (name, self.number)
        if patch:
            self.callback(type, name, tile_name, self.bbox)
        return tile_name


# The class combines the general functions above: it splits a region into
# tiles, queues per-tile module calls, runs them in parallel, and patches
# the tile results into the final outputs.
class TiledWorkflow(object):
    """Run module calls tile by tile in parallel and patch the results.

    Iterating over an instance yields a ``(namer, call_list)`` pair per
    tile. ``namer`` (a Namer) creates tile-specific names and registers
    them for patching; ``call_list`` (a ModuleCallList) queues module calls
    which run in the tile's region including the overlap. :meth:`execute`
    runs all queued calls and then patches the tiles.

    :param nprocs: number of worker processes
    :param width: tile width in cells
    :param height: tile height in cells
    :param overlap: number of cells each tile is extended by on every side
    :param region: Region to split; the current region when None
        (NOTE(review): tiles are run with the resolution of the current
        region, so a custom region should match it)
    :param debug: run sequentially, keep the tiles, and create the
        ``tiles_no_overlaps`` and ``tiles_with_overlaps`` maps
    """

    def __init__(self, nprocs, width, height, overlap=0, region=None,
                 debug=False):
        self.nprocs = nprocs
        self.region = region if region is not None else get_current_region()
        self.bboxes = region_to_tiles(
            region=self.region, width=width, height=height,
            overlap=overlap)
        # same tiling without overlap; used for patching so that each
        # output cell comes from exactly one tile core
        self.bboxes_no_overlap = region_to_tiles(
            region=self.region, width=width, height=height)
        self.lists_of_rasters = []  # tuple(list, list, str)
        self.lists_of_rasters3d = []  # tuple(list, list, str)
        self.elements_to_patch = {}
        self.calls = []
        self.debug = debug
        self.iteration_index = 0
        self.iteration_max = len(self.bboxes)

    def __iter__(self):
        return self

    def next(self):
        """Return ``(namer, call_list)`` for the next tile."""
        if self.iteration_index >= self.iteration_max:
            raise StopIteration
        index = self.iteration_index
        namer = Namer(bbox=self.bboxes_no_overlap[index], number=index,
                      callback=self.add_element_to_patch)
        call = ModuleCallList(bbox=self.bboxes[index])
        self.calls.append(call)
        self.iteration_index += 1
        return namer, call

    # Python 3 iterator protocol
    __next__ = next

    def execute_tiled(self):
        """Execute the queued calls of all tiles; return their results."""
        return execute_list(self.calls, nprocs=self.nprocs, debug=self.debug)

    def add_rasters_to_patch(self, inputs, output):
        """Register one raster per tile to be patched into *output*.

        :raises ValueError: if the number of inputs differs from the
            number of tiles
        """
        if len(self.bboxes) != len(inputs):
            raise ValueError("Number of items to patch must be the same"
                             " as the number of bounding boxes")
        self.lists_of_rasters.append((inputs, self.bboxes_no_overlap, output))

    def add_rasters3d_to_patch(self, inputs, output):
        """Register one 3D raster per tile to be patched into *output*.

        :raises ValueError: if the number of inputs differs from the
            number of tiles
        """
        if len(self.bboxes) != len(inputs):
            raise ValueError("Number of items to patch must be the same"
                             " as the number of bounding boxes")
        self.lists_of_rasters3d.append((inputs, self.bboxes_no_overlap, output))

    def add_element_to_patch(self, type, name, tile_name, bbox):
        """Record *tile_name* (with *bbox*) as a tile of element *name*."""
        element = self.elements_to_patch.setdefault(type, {}).setdefault(
            name, {'tiles': [], 'bboxes': []})
        element['tiles'].append(tile_name)
        element['bboxes'].append(bbox)

    def patch(self):
        """Patch all registered tiles into their outputs.

        Elements registered through a Namer with type ``'raster'`` or
        ``'raster3d'`` are patched; other types are ignored. Unless in
        debug mode, the tiles are removed afterwards.
        """
        lists_by_type = {'raster': self.lists_of_rasters,
                         'raster3d': self.lists_of_rasters3d}
        for element_type, names in self.elements_to_patch.items():
            target = lists_by_type.get(element_type)
            if target is None:
                continue  # TODO: other element types are not patched
            for name, info in names.items():
                target.append((info['tiles'], info['bboxes'], name))

        # skip empty lists to avoid starting a process pool for nothing
        if self.lists_of_rasters:
            execute_function_on_list(self.lists_of_rasters,
                                     function=patch_rasters_2,
                                     nprocs=self.nprocs, debug=self.debug)
        if self.lists_of_rasters3d:
            execute_function_on_list(self.lists_of_rasters3d,
                                     function=patch_rasters3d_2,
                                     nprocs=self.nprocs, debug=self.debug)
        if self.debug:
            # maps whose cell values are the index of the covering tile
            patch_rasters(range(len(self.bboxes_no_overlap)),
                          self.bboxes_no_overlap, 'tiles_no_overlaps')
            patch_rasters(range(len(self.bboxes)), self.bboxes,
                          'tiles_with_overlaps')
        else:
            # TODO: parallelize
            # TODO: perhaps special option for temp files, or even preserve tiles
            for rasters, _bboxes, _output in self.lists_of_rasters:
                remove_rasters(rasters)
            for rasters, _bboxes, _output in self.lists_of_rasters3d:
                remove_rasters3d(rasters)

    def execute(self):
        """Run all tiles, patch the results, and return the tile results."""
        result = self.execute_tiled()
        self.patch()
        return result


def main(options, flags):
    """Run r.neighbors tile by tile in parallel and patch the outputs.

    :param options: options dict from gs.parser()
    :param flags: flags dict from gs.parser()
    """
    input_raster = options['input']
    outputs = options['output']
    nprocs = int(options['nprocs'])
    width = int(options['width'])
    height = int(options['height'])
    overlap = int(options['overlap'])

    # TODO: theoretically we could also just take nprocs
    # and do some guesses for the rest (including overlap)
    # TODO: check existence of input and output beforehand

    # the neighborhood window reaches size // 2 cells from its center;
    # with a smaller overlap the cells near tile edges see a cut window
    size = options['size']
    if size and overlap < int(size) // 2:
        gs.warning(_("Overlap ({overlap}) is smaller than half of the"
                     " neighborhood size ({size}); values near tile edges"
                     " may differ from processing without tiles").format(
                         overlap=overlap, size=size))

    # shallow copy to just pass the remaining options
    neighbors_options = copy.copy(options)
    # remove options for this module
    this_module_opts = [
        'input', 'output', 'nprocs', 'width', 'height', 'overlap']
    remove_keys_from_dict(neighbors_options, this_module_opts)
    # pass only options which are set
    neighbors_options = remove_false_values(neighbors_options)

    # pass active flags to the module
    neighbors_flags = flags_dict_to_str(flags)

    # TODO: Align output with the input (opposite of -a) is not
    # supported, probably a g.region call needed here

    tiled_workflow = TiledWorkflow(nprocs=nprocs, width=width,
                                   height=height, overlap=overlap)

    gs.message(_("Splitting processing into {} tiles").format(
        len(tiled_workflow.bboxes)))

    for namer, workflow in tiled_workflow:
        # get name for each of the (potential) outputs
        outs = [namer.name('raster', output) for output in outputs.split(',')]
        workflow.run_command('r.neighbors', input=input_raster,
                             output=outs, quiet=True,
                             flags=neighbors_flags, **neighbors_options)

    tiled_workflow.execute()

    # TODO: title for output not supported
    # TODO: proper history not provided


if __name__ == "__main__":
    # gs.parser() returns the (options, flags) pair for main()
    parsed_options, parsed_flags = gs.parser()
    sys.exit(main(parsed_options, parsed_flags))
