# -*- coding: utf-8 -*-
"""
@brief Runge-Kutta integration methods (RK4, RK45 with adaptive step size)
@author: Anna Petrasova
"""
import sys

from rast3d_functions import get_velocity, rast3d_is_valid_location, norm

# Tuning constants for the adaptive step size (RK45) integration.
# The smallest and largest allowed steps are expressed in 'cell' units.
MIN_STEP = 0.01
MAX_STEP = 1.0
# Tolerance that the RK45 error estimate is compared against.
MAX_ERROR = 1e-6


# A nice, illustrated explanation of Runge-Kutta integration:
# http://www.marekfiser.com/Projects/Vector-field-visualization-on-GPU-using-CUDA/4-Vector-field-integrators-for-stream-line-visualization
def rk4_integrate_next(rast3d_region, velocity_obj, point, next_point, delta_t):
    """Integrates new point of a flowline using 4th order Runge-Kutta method.

    Four velocity samples are taken: at the start point, at two trial
    midpoints and at one trial end point. Each sample is scaled by
    ``delta_t``, and their weighted average (weights 1/6, 1/3, 1/3, 1/6)
    displaces the start point.

    :param dict rast3d_region: information about region
    :param velocity_obj: velocity class to handle different sources of velocity field
    :param list point: list of x, y, z coordinates from where to integrate
    :param list next_point: list of x, y, z coordinates of the integrated point
    (computed here and the list is modified, even when the resulting point
    turns out to be outside the region)
    :param float delta_t: time step for integration
    :return: True if next point is found, False if not (out of region/3D raster)
    """
    x, y, z = point[0], point[1], point[2]

    def increment(position):
        # displacement over delta_t for the velocity sampled at position,
        # or None when get_velocity yields no value there
        sample = get_velocity(rast3d_region, velocity_obj, position)
        if sample is None:
            return None
        return [delta_t * sample[axis] for axis in range(3)]

    k1 = increment(point)
    if k1 is None:
        return False
    k2 = increment((x + k1[0] / 2, y + k1[1] / 2, z + k1[2] / 2))
    if k2 is None:
        return False
    k3 = increment((x + k2[0] / 2, y + k2[1] / 2, z + k2[2] / 2))
    if k3 is None:
        return False
    k4 = increment((x + k3[0], y + k3[1], z + k3[2]))
    if k4 is None:
        return False

    for axis in range(3):
        next_point[axis] = (point[axis] + k1[axis] / 6 + k2[axis] / 3
                            + k3[axis] / 3 + k4[axis] / 6)
    # the location check takes the coordinates in (y, x, z) order
    return bool(rast3d_is_valid_location(rast3d_region, next_point[1], next_point[0], next_point[2]))

# Runge-Kutta with adaptive step size, adapted from the implementation of
# VTK's vtkRungeKutta45 class.

# Cash-Karp coefficients (Butcher tableau), see
# https://en.wikipedia.org/wiki/Cash-Karp_method
# Row r of B holds the weights applied to the already computed stages when
# building the evaluation point of stage r + 1 (0-based stage numbering).
B = [
    [1.0 / 5, 0, 0, 0, 0],
    [3.0 / 40, 9.0 / 40, 0, 0, 0],
    [3.0 / 10, -9.0 / 10, 6.0 / 5, 0, 0],
    [-11.0 / 54, 5.0 / 2, -70.0 / 27, 35.0 / 27, 0],
    [1631.0 / 55296, 175.0 / 512, 575.0 / 13824, 44275.0 / 110592, 253.0 / 4096],
]
# Weights of the fifth-order solution.
C = [37.0 / 378, 0, 250.0 / 621, 125.0 / 594, 0, 512.0 / 1771]

# Error weights: fifth-order weights minus the embedded fourth-order weights.
DC = [c5 - c4 for c5, c4 in zip(C, [2825.0 / 27648, 0, 18575.0 / 48384,
                                     13525.0 / 55296, 277.0 / 14336, 1.0 / 4])]


def rk45_integrate_next(rast3d_region, velocity_obj, point, next_point, delta_t, min_step, max_step,
                        max_error=MAX_ERROR):
    """Integrates new point of a flowline using Runge-Kutta method with adaptive step size.

    Uses Cash-Karp coefficients (via rk45_next). The step size is adjusted
    with the empirical controller of VTK's vtkRungeKutta45 until the estimated
    error is at most ``max_error``. If a proposed step falls outside
    [min_step, max_step], it is clamped to the violated bound. One last step
    is then computed with the clamped size, and that result is accepted
    regardless of its error estimate.

    The adapted step size is local to this call; only the integrated point
    is reported back (through ``next_point``).

    :param dict rast3d_region: information about region
    :param velocity_obj: velocity class to handle different sources of velocity field
    :param list point: list of x, y, z coordinates from where to integrate
    :param list next_point: list of x, y, z coordinates of the integrated point
    (computed here and the list is modified)
    :param float delta_t: initial time step for integration; its sign gives
    the integration direction
    :param float min_step: min time step for integration (absolute value)
    :param float max_step: max time step for integration (absolute value)
    :param float max_error: tolerance for the estimated error
    (defaults to the module constant MAX_ERROR)
    :return: True if next point is found, False if not (out of region/3D raster)
    :raises ValueError: if max_error is not positive
    """
    if max_error <= 0:
        # the step controller divides by max_error
        raise ValueError("max_error must be positive")

    # Bring the first trial step into [min_step, max_step] while keeping the
    # integration direction, as vtkRungeKutta45 does; otherwise an
    # out-of-range delta_t could be accepted unchanged. A zero step has no
    # direction and is left as is.
    if delta_t != 0:
        direction = 1 if delta_t > 0 else -1
        if abs(delta_t) > max_step:
            delta_t = max_step * direction
        elif abs(delta_t) < min_step:
            delta_t = min_step * direction

    error = [0.0] * 3
    estimated_error = sys.float_info.max

    # try to iteratively decrease the error to max_error or below
    while estimated_error > max_error:
        # compute next point and get estimated error
        if not rk45_next(rast3d_region, velocity_obj, point, next_point, delta_t, error):
            return False
        estimated_error = norm(error)

        if estimated_error > max_error and abs(delta_t) <= min_step:
            # The step cannot shrink any further. Retrying with the clamped
            # minimum step would recompute exactly this result, so accept it.
            break

        # compute new step size (empirical formula, 0.9 is a safety factor)
        error_ratio = estimated_error / max_error
        if error_ratio == 0.0:
            # avoid pow(0, negative exponent)
            new_step = min_step if delta_t > 0 else -min_step
        elif error_ratio > 1:
            new_step = 0.9 * delta_t * pow(error_ratio, -0.25)
        else:
            new_step = 0.9 * delta_t * pow(error_ratio, -0.2)

        # adjust new step size to be within min max limits
        direction = 1 if delta_t > 0 else -1
        if abs(new_step) > max_step:
            delta_t = max_step * direction
        elif abs(new_step) < min_step:
            delta_t = min_step * direction
        else:
            delta_t = new_step
            continue

        # The step had to be clamped: take one last step of the clamped size
        # and accept it (same strategy as vtkRungeKutta45).
        return rk45_next(rast3d_region, velocity_obj, point, next_point, delta_t, error)

    return True


def rk45_next(rast3d_region, velocity_obj, point, next_point, delta_t, error):
    """The actual method which integrates new point of a flowline
    using Runge-Kutta method with Cash-Karp coefficients. Provides error estimate.

    The velocity field is sampled at the six Cash-Karp stages (tableau B).
    The stage velocities are combined with the weights C to obtain the next
    point, and with the weights DC to obtain the error vector.

    :param dict rast3d_region: information about region
    :param velocity_obj: velocity class to handle different sources of velocity field
    :param list point: list of x, y, z coordinates from where to integrate
    :param list next_point: list of x, y, z coordinates of the integrated point
    (computed here and the list is modified)
    :param float delta_t: time step for integration
    :param list error: error vector (here modified; only written when the
    integrated point is a valid location)
    :return: True if next point is found, False if not (out of region/3D raster)
    """
    # Stage velocities k1..k6. Each is stored as its own tuple of components,
    # so no two stages share a row (a ``[[0] * 3] * 6`` pre-allocation makes
    # all six rows the same list object). A stored slope also cannot change
    # later if get_velocity happens to reuse its return object (not
    # verifiable from here).
    slopes = []
    velocity = get_velocity(rast3d_region, velocity_obj, point)
    if velocity is None:
        return False
    slopes.append((velocity[0], velocity[1], velocity[2]))

    # compute k2..k6; stage s uses row s - 1 of B over stages 0..s-1
    for stage in range(1, 6):
        weights = B[stage - 1]
        # a fresh evaluation point per stage instead of one mutated buffer
        stage_point = [
            point[j] + delta_t * sum(weights[m] * slopes[m][j] for m in range(stage))
            for j in range(3)
        ]
        velocity = get_velocity(rast3d_region, velocity_obj, stage_point)
        if velocity is None:
            return False
        slopes.append((velocity[0], velocity[1], velocity[2]))

    # next point from the fifth-order weights
    for j in range(3):
        next_point[j] = point[j] + delta_t * sum(C[i] * slopes[i][j] for i in range(6))

    # the location check takes the coordinates in (y, x, z) order
    if not rast3d_is_valid_location(rast3d_region, next_point[1], next_point[0], next_point[2]):
        return False

    # error vector: fifth-order minus embedded fourth-order solution
    for j in range(3):
        error[j] = delta_t * sum(DC[i] * slopes[i][j] for i in range(6))

    return True


def get_time_step(unit, step, velocity, cell_size):
    """Convert an integration step given in ``unit`` into a time step.

    :param str unit: 'time' (step already is a time step), 'length' (step is
    a distance) or 'cell' (step is a number of cells)
    :param float step: step size in the given unit
    :param float velocity: speed dividing a distance to obtain time
    (presumably the velocity magnitude at the current point; verify with caller)
    :param float cell_size: size of one cell, only used when unit is 'cell'
    :return: time step
    :raises ValueError: if unit is not one of 'time', 'length', 'cell'
    :raises ZeroDivisionError: if velocity is zero for 'length' or 'cell'
    """
    if unit == 'time':
        return step
    if unit == 'length':
        return step / velocity
    if unit == 'cell':
        return (step * cell_size) / velocity
    # previously fell through and returned None, failing later in arithmetic
    raise ValueError("Unknown step unit: {}".format(unit))
