#!/usr/bin/env python
# -*- coding: utf-8 -*-
############################################################################
#
# MODULE:       r3.flow
# AUTHOR:       Anna Petrasova
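# PURPOSE:      Generates 3D flow lines and/or 3D flow accumulation from either
#               a 3D raster (its gradient is used) or three 3D rasters holding
#               the x, y and z components of a vector field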
#
# COPYRIGHT:    (c) 2014 the GRASS Development Team
#               This program is free software under the GNU General Public
#               License (>=v2). Read the file COPYING that comes with GRASS
#               for details.
#
#		This program is distributed in the hope that it will be useful,
#		but WITHOUT ANY WARRANTY; without even the implied warranty of
#		MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#		GNU General Public License for more details.
#
#############################################################################
#%module
#% description: Generate 3D flow lines from 3D raster representing velocity field
#% keywords: raster3d
#% keywords: vector
#%end
#%option G_OPT_R3_INPUT
#% key: input
#% required: no
#% description: Name of input 3D raster map
#%end
#%option G_OPT_R3_INPUTS
#% key: vector_field
#% required: no
#% description: Names of three 3D raster maps describing x, y, z components of vector field
#%end
#%option G_OPT_V_INPUT
#% key: seed_points
#% required: no
#% description: If no map is provided, flow lines are generated from each cell of the input 3D raster
#% label: Name of vector map with points from which flow lines are generated
#%end
#%option G_OPT_V_OUTPUT
#% key: flowline
#% required: no
#% description: Name of vector map of flow lines
#%end
#%option G_OPT_R3_OUTPUT
#% key: flowaccumulation
#% required: no
#% description: Name of the output flow accumulation 3D raster map
#%end
#%option
#% key: unit
#% type: string
#% required: no
#% answer: cell
#% options: time,length,cell
#% descriptions: elapsed time,length in map units,length in cells (voxels)
#% label: Unit of integration step
#% description: Default unit is cell
#% guisection: Integration
#%end
#%option
#% key: step
#% type: double
#% required: no
#% label: Integration step in selected unit
#% description: Default step is 0.25 cell
#% guisection: Integration
#%end
#%option
#% key: limit
#% type: integer
#% required: no
#% answer: 2000
#% description: Maximum number of steps
#% guisection: Integration
#%end
#%option
#% key: skip
#% type: integer
#% required: no
#% multiple: yes
#% description: Number of cells between flow lines in x, y and z direction
#%end
#%flag
#% key: u
#% description: Compute upstream flowlines instead of default downstream flowlines
#%end

import math
import numpy as np

from grass.script import core as gcore
from grass.script import array as garray
from grass.pygrass.vector import VectorTopo
from grass.pygrass.vector.geometry import Line, Point

from integrate import rk4_integrate_next, rk45_integrate_next, get_time_step, MIN_STEP, MAX_STEP
from rast3d_functions import get_velocity, norm, rast3d_location2coord, get_region_info
from voxel_traversal import traverse

# a velocity magnitude at or below this value terminates a flow line
EPSILON = 1.0e-8
# True: adaptive RK45 integration, False: fixed-step RK4 (have it as an option?)
RK45 = True


class Seed:
    """Start position of one flow line and the outputs it feeds.

    Attributes:
        x, y, z: coordinates of the starting point.
        flowline: when true, the integrated line is written to the vector
            output.
        flowaccum: when true, the voxels the line visits are counted in the
            flow accumulation output.
    """

    def __init__(self, x, y, z, flowline, flowaccum):
        self.x, self.y, self.z = x, y, z
        self.flowline, self.flowaccum = flowline, flowaccum


class Velocity:
    """Description of the field from which flow velocities are taken.

    main() fills in either ``vector_arrays`` (three component arrays,
    ``compute_gradient`` False) or ``scalar_array`` (a single array,
    ``compute_gradient`` True).
    """

    def __init__(self):
        self.compute_gradient = False
        # neighbors_values / neighbors_pos are not set anywhere in this file;
        # presumably scratch storage for the velocity helpers -- TODO confirm
        for attr_name in ('vector_arrays', 'scalar_array',
                          'neighbors_values', 'neighbors_pos'):
            setattr(self, attr_name, None)


class Integrate:
    """Settings controlling how a flow line is stepped.

    Every field starts as None; main() assigns:
        direction: 1 (upstream, -u flag) or -1 (downstream).
        unit: 'time', 'length' or 'cell' -- unit of ``step``.
        step: integration step size in ``unit``.
        cell_size: length of the voxel diagonal.
        limit: maximum number of integration steps per line.
    """

    def __init__(self):
        self.direction = self.unit = self.step = self.cell_size = self.limit = None


def _add_flowacc(flowacc, map_info, col, row, depth):
    """Increment flowacc[depth][row][col] by one if the voxel is inside the region.

    Out-of-region indices are ignored. With numpy-style indexing, a negative
    index would silently wrap to the opposite side of the array, and an index
    past the end would raise IndexError.
    """
    if (0 <= col < map_info['cols'] and 0 <= row < map_info['rows'] and
            0 <= depth < map_info['depths']):
        flowacc[depth][row][col] += 1


def compute_flowline(map_info, seed, velocity_obj, integration, flowline_vector, flowacc):
    """Integrate a single flow line starting at ``seed``.

    The position is advanced one step at a time until one of these happens:
    - ``integration.limit`` steps have been taken;
    - the velocity cannot be evaluated (get_velocity returns None, e.g.
      outside the region);
    - the velocity magnitude is at most EPSILON;
    - the integrator reports failure.
    Steps use RK4, or adaptive RK45 when the module constant RK45 is true.

    :param map_info: region description from get_region_info(); 'cols',
        'rows' and 'depths' are used to bounds-check voxel indices
    :param seed: Seed with the start coordinates and output switches
    :param velocity_obj: Velocity describing the field to integrate
    :param integration: Integrate with direction, unit, step, cell_size, limit
    :param flowline_vector: vector map open for writing. The line is written
        when ``seed.flowline`` is true and it has more than one vertex.
    :param flowacc: 3D array indexed [depth][row][col]. When
        ``seed.flowaccum`` is true, each voxel entered by a step is
        incremented; voxels skipped over by a long step are added via
        traverse(). The seed voxel itself is not counted.
    """
    count = 1
    point = seed.x, seed.y, seed.z
    line = None
    if seed.flowline:
        line = Line([Point(seed.x, seed.y, seed.z)])
    last_coords = (None, None, None)
    new_point = [None, None, None]
    while count <= integration.limit:
        velocity_vector = get_velocity(map_info, velocity_obj, point)
        if velocity_vector is None:
            break  # outside region
        velocity_norm = norm(velocity_vector)

        if velocity_norm <= EPSILON:
            break  # zero velocity means end of propagation
        # convert the step (in the selected unit) to a time step
        delta_T = get_time_step(integration.unit, integration.step, velocity_norm, integration.cell_size)
        # decide which integration method to choose; both fill new_point in place
        if not RK45:
            ret = rk4_integrate_next(map_info, velocity_obj, point, new_point,
                                     delta_T * integration.direction)
        else:
            min_step = get_time_step('cell', MIN_STEP, velocity_norm, integration.cell_size)
            max_step = get_time_step('cell', MAX_STEP, velocity_norm, integration.cell_size)
            ret = rk45_integrate_next(map_info, velocity_obj, point, new_point,
                                      delta_T * integration.direction, min_step, max_step)
        if not ret:
            break

        if seed.flowline:
            line.append(Point(new_point[0], new_point[1], new_point[2]))
        if seed.flowaccum:
            # new_point has not been checked against the region yet, so the
            # resulting indices may be out of range (handled by _add_flowacc)
            col, row, depth = rast3d_location2coord(map_info, new_point[1], new_point[0], new_point[2])
            current = (col, row, depth)
            if last_coords != current:
                _add_flowacc(flowacc, map_info, col, row, depth)
                # the new voxel is not face-adjacent to the previous one:
                # also count the voxels the step jumped over
                if last_coords[0] is not None and \
                   sum(abs(a - b) for a, b in zip(last_coords, current)) > 1:
                    for p in traverse(point, new_point, map_info):
                        _add_flowacc(flowacc, map_info, p[0], p[1], p[2])
                last_coords = current
        point = new_point[:]
        count += 1

    if seed.flowline and len(line) > 1:
        flowline_vector.write(line)
        gcore.verbose(_("Flowline ended after %s steps") % (count - 1))


def main():
    """Run r3.flow: read the inputs, integrate flow lines, write the outputs.

    Flow lines start either at the points of the ``seed_points`` vector map
    or at voxel centres of the current 3D region. Grid seeds are used for
    every voxel when flow accumulation is requested. They produce flow lines
    (at every ``skip``-th voxel) only when no seed point map was given.
    Invalid input or option combinations end the module via gcore.fatal().
    """
    options, flags = gcore.parser()

    velocity_obj = Velocity()
    integration = Integrate()
    if options['vector_field']:
        try:
            v_x, v_y, v_z = options['vector_field'].split(',')
        except ValueError:
            gcore.fatal(_("Please provide 3 input 3D raster maps representing components of vector field."))
        # read input 3D maps
        velocity_arrays = [garray.array3d(), garray.array3d(), garray.array3d()]
        for inp, vel in zip((v_x, v_y, v_z), velocity_arrays):
            if vel.read(mapname=inp) != 0:
                gcore.fatal(_("Error when reading input 3D raster maps"))
        velocity_obj.vector_arrays = velocity_arrays
        velocity_obj.compute_gradient = False

    elif options['input']:
        scalar_array = garray.array3d()
        # check the read status the same way as for the vector field maps
        if scalar_array.read(mapname=options['input']) != 0:
            gcore.fatal(_("Error when reading input 3D raster map"))
        velocity_obj.compute_gradient = True
        velocity_obj.scalar_array = scalar_array
    else:
        # this should be handled by parser (in the future)
        gcore.fatal(_("Either input or vector_field options must be specified."))

    if not (options['flowaccumulation'] or options['flowline']):
        gcore.fatal(_("At least one of output options flowline or flowaccumulation must be specified."))
    # seed points only produce flow lines (never flow accumulation), so they
    # are useless -- and would have no vector map to write to -- without it
    if options['seed_points'] and not options['flowline']:
        gcore.fatal(_("Option seed_points requires output option flowline."))

    map_info = get_region_info()

    if options['skip']:
        try:
            skipx, skipy, skipz = [int(each) for each in options['skip'].split(',')]
        except ValueError:
            gcore.fatal(_("Please provide 3 values for skip option"))
        # skip values are used as modulo divisors in the grid loop
        if min(skipx, skipy, skipz) <= 0:
            gcore.fatal(_("Values of skip option must be positive integers"))
    else:
        skipx = max(1, int(map_info['cols'] / 10))
        skipy = max(1, int(map_info['rows'] / 10))
        skipz = max(1, int(map_info['depths'] / 10))

    # initialize integration variables

    # integrating forward means upstream, backward means downstream
    integration.direction = 1 if flags['u'] else -1

    integration.limit = int(options['limit'])
    # see which units to use
    step = options['step']
    if step:
        integration.step = float(step)
        if integration.step <= 0:
            gcore.fatal(_("Integration step must be positive"))
        integration.unit = options['unit']
    else:
        integration.unit = 'cell'
        integration.step = 0.25
    # cell size is the diagonal
    integration.cell_size = math.sqrt(map_info['nsres'] ** 2 + map_info['ewres'] ** 2 +
                                      map_info['tbres'] ** 2)

    if options['seed_points'] and options['skip']:
        gcore.warning(_("Option skip is ignored because seed point map was provided."))

    # create new vector map to store flow lines (None when not requested)
    flowline_vector = None
    if options['flowline']:
        flowline_vector = VectorTopo(options['flowline'])
        # gcore.overwrite is a function: it must be called, the bare name is always truthy
        if flowline_vector.exist() and not gcore.overwrite():
            gcore.fatal(_("Vector map <{flowline}> already exists.").format(flowline=options['flowline']))
        flowline_vector.open(mode='w', with_z=True, overwrite=gcore.overwrite())

    # create the flowaccumulation 3D raster
    if options['flowaccumulation']:
        flowacc = garray.array3d()
        for depth in range(map_info['depths']):
            flowacc[depth] = np.zeros((map_info['rows'], map_info['cols']))
    else:
        flowacc = None

    # try open seed points map and get number of points
    num_points = 0
    if options['seed_points']:
        inp_seeds = VectorTopo(options['seed_points'])
        inp_seeds.open()
        if not inp_seeds.is_3D():
            inp_seeds.close()
            gcore.fatal(_("Input vector map of seed points must be 3D."))
        num_points = inp_seeds.num_primitive_of('point')
        if num_points <= 0:
            inp_seeds.close()
            gcore.fatal(_("No points found in seed vector map."))

    # grid flow lines are generated only when no seed point map was given;
    # grid seeds are needed for flow accumulation in any case
    grid_flowlines = bool(options['flowline']) and not options['seed_points']
    use_grid = bool(options['flowaccumulation']) or grid_flowlines

    # compute number of total seeds to show percent
    num_seeds = num_points
    if use_grid:
        if options['flowaccumulation']:
            num_seeds += map_info['cols'] * map_info['rows'] * map_info['depths']
        else:
            num_seeds += int(math.ceil(map_info['cols'] / float(skipx)) *
                             math.ceil(map_info['rows'] / float(skipy)) *
                             math.ceil(map_info['depths'] / float(skipz)))
    gcore.info(_("Flow lines for {count} seed points will be computed.").format(count=num_seeds))

    # compute flowlines from vector seed points
    seed_count = 0
    if options['seed_points']:
        for point in inp_seeds:
            gcore.percent(seed_count, num_seeds, 1)
            seed = Seed(point.x, point.y, point.z, flowaccum=False, flowline=True)
            compute_flowline(map_info, seed, velocity_obj, integration, flowline_vector, flowacc)
            seed_count += 1

        inp_seeds.close()

    # compute flowlines from seeds created on grid (voxel centres, top row first)
    if use_grid:
        for r in range(map_info['rows'], 0, -1):
            for c in range(0, map_info['cols']):
                for d in range(0, map_info['depths']):
                    x = map_info['west'] + c * map_info['ewres'] + map_info['ewres'] / 2
                    y = map_info['south'] + r * map_info['nsres'] - map_info['nsres'] / 2
                    z = map_info['bottom'] + d * map_info['tbres'] + map_info['tbres'] / 2
                    seed = Seed(x, y, z, False, False)
                    if options['flowaccumulation']:
                        seed.flowaccum = True
                    if grid_flowlines and c % skipx == 0 and r % skipy == 0 and d % skipz == 0:
                        seed.flowline = True
                    if seed.flowaccum or seed.flowline:
                        gcore.percent(seed_count, num_seeds, 1)
                        compute_flowline(map_info, seed, velocity_obj, integration, flowline_vector, flowacc)
                        seed_count += 1
    # finish the progress report at 100 %
    gcore.percent(1, 1, 1)

    if options['flowline']:
        flowline_vector.close()
    if options['flowaccumulation']:
        flowacc.write(mapname=options['flowaccumulation'])

if __name__ == "__main__":
    # run the module only when executed as a script, not when imported
    main()
