# -*- coding: utf-8 -*-
"""
@author: Anna Petrasova
"""

import math
import numpy
from gradient import gradient

from grass.script import core as gcore


def get_region_info():
    """Return the 3D region from gcore.region(region3d=True) as a dict.

    The single-letter extent keys n, s, e, w, t, b are renamed in place to
    north, south, east, west, top, bottom; all other keys are kept as-is.
    """
    region = gcore.region(region3d=True)
    short_to_long = (('n', 'north'), ('s', 'south'), ('e', 'east'),
                     ('w', 'west'), ('t', 'top'), ('b', 'bottom'))
    for short_key, long_key in short_to_long:
        region[long_key] = region.pop(short_key)
    return region


def get_velocity(rast3d_region, velocity_obj, point):
    """Return the velocity vector at `point`.

    `point` is indexed as (east, north, top): point[1] is passed as the
    north coordinate. When velocity_obj.compute_gradient is truthy the
    velocity is taken from get_gradient; otherwise the arrays in
    velocity_obj.vector_arrays are trilinearly interpolated. Both paths can
    return None for points outside the region.
    """
    if not velocity_obj.compute_gradient:
        return rast3d_trilinear_interpolation(rast3d_region,
                                              velocity_obj.vector_arrays,
                                              north=point[1],
                                              east=point[0],
                                              top=point[2])
    return get_gradient(rast3d_region, velocity_obj, point)


#def get_velocity(rast3d_region, velocity_arrays, point):
#    """Returns tuple of interpolated velocity values
#    in xyz components for a given point, or None if outside of region"""
#    return rast3d_trilinear_interpolation(rast3d_region, velocity_arrays,
#                                          point[1], point[0], point[2])


def _neighbor_window(low, high, size):
    """Widen the neighbour index span [low, high] on one axis to a 4-voxel window.

    Returns (start, end, shift): `start`/`end` are the inclusive indices of the
    window along this axis, and `shift` is added to the local neighbour index
    (0 or 1) to get its index inside the window. A resulting index of -1 or 4
    means that neighbour lies outside the raster.
    NOTE(review): assumes `size` >= 4; for smaller axes `start` can become
    negative and numpy slicing would wrap around.
    """
    if low == 0 or low == -1:
        # touching (or just past) the lower edge: window is [0, 3]
        end = low + 3 if low == 0 else low + 4
        return 0, end, low
    if high >= size - 1:
        # touching (or just past) the upper edge: window ends at size - 1
        start = high - 3 if high < size else high - 4
        shift = 2 if high < size else 3
        return start, size - 1, shift
    # interior: one extra voxel on each side
    return low - 1, high + 1, 1


def get_gradient(rast3d_region, velocity_obj, point):
    """Interpolate the gradient of velocity_obj.scalar_array at `point`.

    `point` is indexed as (east, north, top). The gradient is computed by
    `gradient` on a 4x4x4 block of the scalar array around the 8 nearest
    voxels, and the 8 neighbour gradients are trilinearly interpolated.
    Neighbour gradients are cached on velocity_obj (neighbors_pos /
    neighbors_values) and reused while the neighbour set does not change.

    Returns a list of 3 interpolated gradient components, or None if all 8
    neighbours lie outside the region.
    """
    nearest = find_nearest_voxels(rast3d_region, point[1], point[0], point[2])

    minx, maxx = nearest[0][0], nearest[7][0]
    # rows are counted from the north edge, so the northern neighbour
    # (nearest[7]) has the smaller row index
    miny, maxy = nearest[7][1], nearest[0][1]
    minz, maxz = nearest[0][2], nearest[7][2]

    # Check this before touching the cache: storing the cache key first and
    # then returning None here left neighbors_values stale, so a later
    # out-of-region point with the same neighbours returned old values.
    if maxx < 0 or minx >= rast3d_region['cols'] or \
       maxy < 0 or miny >= rast3d_region['rows'] or \
       maxz < 0 or minz >= rast3d_region['depths']:
        return None

    # position of x, y, z neighbouring voxels, used as the cache key
    neighbors_pos = [minx, miny, minz]
    if not velocity_obj.neighbors_pos or neighbors_pos != velocity_obj.neighbors_pos:
        startx, endx, xshift = _neighbor_window(minx, maxx, rast3d_region['cols'])
        starty, endy, yshift = _neighbor_window(miny, maxy, rast3d_region['rows'])
        startz, endz, zshift = _neighbor_window(minz, maxz, rast3d_region['depths'])

        # get the 4x4x4 block of the array
        array = velocity_obj.scalar_array[startz:endz + 1, starty:endy + 1, startx:endx + 1]
        # NOTE(review): `step` is listed x-first (ewres, nsres, tbres) while the
        # array is indexed (depth, row, col); confirm that is what `gradient` expects.
        xyz_gradients = gradient(array, step=[rast3d_region['ewres'], rast3d_region['nsres'],
                                              rast3d_region['tbres']])
        # collect each component at the 8 neighbours in the order used by
        # trilinear_interpolation: x fastest, then y (south first), then z
        array_values = [[] for _ in range(3)]
        for i in range(3):
            for dz in range(2):
                for dy in (1, 0):  # larger row index = southern neighbour first
                    for dx in range(2):
                        zi, yi, xi = dz + zshift, dy + yshift, dx + xshift
                        if 0 <= zi <= 3 and 0 <= yi <= 3 and 0 <= xi <= 3:
                            array_values[i].append(xyz_gradients[i][zi][yi][xi])
                        else:
                            # neighbour outside the raster contributes 0
                            array_values[i].append(0)

        # update key and values together so they always describe the same voxels
        velocity_obj.neighbors_values = array_values
        velocity_obj.neighbors_pos = neighbors_pos
    else:
        array_values = velocity_obj.neighbors_values

    x, y, z = get_relative_coords_for_interp(rast3d_region, point[1], point[0], point[2])

    return trilinear_interpolation(array_values, x, y, z)


def norm(vector):
    """Return the Euclidean length of the first three components of `vector`."""
    vx, vy, vz = vector[0], vector[1], vector[2]
    return math.sqrt(vx * vx + vy * vy + vz * vz)


def compute_gradient(array):
    """!Compute the gradient of a 3D array with numpy.gradient.

    The array is indexed (z, y, x); spacing is one index unit per axis.
    Returns the x, y, z components, with y negated because the row index
    grows in the opposite direction to y.
    """
    grad_z, grad_y, grad_x = numpy.gradient(array)
    return grad_x, -grad_y, grad_z

# The following functions mirror ones already in the raster3d C library; these
# are simpler Python versions. Trilinear interpolation is the exception: it is
# not in the library yet and could be added there.


def rast3d_is_valid_location(rast3d_region, north, east, top):
    """Return True if (north, east, top) lies within the region extent.

    All six region edges are inclusive.
    """
    reg = rast3d_region
    return (reg['south'] <= north <= reg['north'] and
            reg['west'] <= east <= reg['east'] and
            reg['bottom'] <= top <= reg['top'])


def rast3d_location2coord(rast3d_region, north, east, top):
    """!Return (col, row, depth) of the cell containing (north, east, top).

    Rows are counted from the north edge, columns from the west edge and
    depths from the bottom. No clamping is done, so points outside the
    region yield indices outside [0, cols/rows/depths).
    """
    reg = rast3d_region

    def cell_index(value, low, high, count):
        # math.floor rather than int(): int() truncates toward zero
        # (int(-0.8) == 0), which would put points just below `low` in cell 0
        return int(math.floor((value - low) / (high - low) * count))

    col = cell_index(east, reg['west'], reg['east'], reg['cols'])
    row_from_south = cell_index(north, reg['south'], reg['north'], reg['rows'])
    depth = cell_index(top, reg['bottom'], reg['top'], reg['depths'])
    row = reg['rows'] - row_from_south - 1
    return col, row, depth


def rast3d_get_value_region(rast3d_region, array, col, row, depth):
    """Return array[depth][row][col], or None when the indices fall outside
    the region's cols/rows/depths."""
    inside = (0 <= col < rast3d_region['cols'] and
              0 <= row < rast3d_region['rows'] and
              0 <= depth < rast3d_region['depths'])
    if inside:
        return array[depth][row][col]
    return None


def find_nearest_voxels(rast3d_region, north, east, top):
    """Find the 8 voxels whose centers surround the given point.

    The point is shifted by half a resolution in each direction and each
    shifted point is converted with rast3d_location2coord. Returns a list of
    (col, row, depth) tuples ordered east fastest, then north, then top,
    starting at the (south, west, bottom) corner. Indices are not clamped,
    so they may lie outside the region near its edges.
    """
    reg = rast3d_region
    norths = (north - reg['nsres'] / 2., north + reg['nsres'] / 2.)
    easts = (east - reg['ewres'] / 2., east + reg['ewres'] / 2.)
    tops = (top - reg['tbres'] / 2., top + reg['tbres'] / 2.)
    return [rast3d_location2coord(reg, n, e, t)
            for t in tops
            for n in norths
            for e in easts]

#def rast3d_get_window_value(rast3d_region, array, north, east, top):
#    """Returns interpolated value in a point defined by geogr. coordinates"""
#    return rast3d_trilinear_interpolation(rast3d_region, array, north, east, top)


def rast3d_trilinear_interpolation(rast3d_region, arrays, north, east, top):
    """Trilinearly interpolate several 3D arrays at one point.

    Values are taken from the 8 voxels whose centers surround the point
    (see find_nearest_voxels); neighbours outside the region count as 0.
    Returns a list with one interpolated value per array, or None when the
    cell containing the point is itself outside the region.
    """
    reg = rast3d_region

    # the point's own cell must be inside; any of the arrays works for this check
    cell_col, cell_row, cell_depth = rast3d_location2coord(reg, north, east, top)
    if rast3d_get_value_region(reg, arrays[0], cell_col, cell_row, cell_depth) is None:
        return None

    # per-array values at the 8 neighbours, in find_nearest_voxels order
    array_values = [[] for _ in arrays]
    for c, r, d in find_nearest_voxels(reg, north, east, top):
        for collected, array in zip(array_values, arrays):
            value = rast3d_get_value_region(reg, array, c, r, d)
            collected.append(0 if value is None else value)

    x, y, z = get_relative_coords_for_interp(rast3d_region, north, east, top)
    return trilinear_interpolation(array_values, x, y, z)


def get_relative_coords_for_interp(rast3d_region, north, east, top):
    """Return the fractional position (x, y, z) of a point among its 8 nearest voxels.

    Each component is the distance from the center of the lower (west,
    south, bottom) neighbour, as chosen by find_nearest_voxels, to the point,
    in units of the region resolution. In exact arithmetic it lies in [0, 1)
    and is the weight of the upper neighbour in trilinear_interpolation.
    """
    reg = rast3d_region
    # Locate the lower-neighbour voxel with the same shifted point that
    # find_nearest_voxels uses for its first neighbour. The offset is measured
    # from that voxel's center, so the weights always refer to the neighbours
    # actually chosen. The previous `temp > res / 2` test gave weight 1 to the
    # wrong voxel for points exactly on a cell center.
    col, row, depth = rast3d_location2coord(reg,
                                            north - reg['nsres'] / 2.,
                                            east - reg['ewres'] / 2.,
                                            top - reg['tbres'] / 2.)
    center_east = reg['west'] + (col + 0.5) * reg['ewres']
    # rows are counted from the north edge; convert to a south-based index
    center_north = reg['south'] + (reg['rows'] - row - 0.5) * reg['nsres']
    center_top = reg['bottom'] + (depth + 0.5) * reg['tbres']

    x = (east - center_east) / reg['ewres']
    y = (north - center_north) / reg['nsres']
    z = (top - center_top) / reg['tbres']

    return x, y, z


def trilinear_interpolation(array_values, x, y, z):
    """Blend 8 corner values with trilinear weights.

    `array_values` is a sequence of per-variable lists, each holding 8 corner
    values ordered x fastest, then y, then z (the order produced by
    find_nearest_voxels). x, y, z are fractional offsets from the lower
    corner. Returns a list with one interpolated value per entry of
    `array_values`.
    """
    corner_weights = [wx * wy * wz
                      for wz in (1 - z, z)
                      for wy in (1 - y, y)
                      for wx in (1 - x, x)]
    n_corners = len(corner_weights)
    return [sum(corner_weights[k] * values[k] for k in range(n_corners))
            for values in array_values]
