# -*- coding: utf-8 -*-
"""
@brief Voxel traversal algorithm for precise flow accumulation computing
@author: Anna Petrasova
"""

import numpy as np
import math
import unittest
from random import uniform


# based on Amanatides, John, and Andrew Woo. "A fast voxel traversal algorithm for ray tracing."
# Proceedings of EUROGRAPHICS. Vol. 87. 1987.
def traverse(start, end, region_info):
    """!Computes which voxel cells are traversed by a line segment.

    The segment is walked cell by cell from the voxel containing ``start``
    towards the voxel containing ``end``. The voxels containing the two end
    points themselves are NOT part of the result; only the cells strictly
    between them (in traversal order) are returned.

    :param start: start point (3 coordinates: x, y, z)
    :param end: end point (3 coordinates: x, y, z)
    :param dict region_info: information about region; the keys read here are
                             'west', 'south', 'bottom' (origin of the grid),
                             'ewres', 'nsres', 'tbres' (cell sizes) and 'rows'
    :return: list of (col, row, depth) tuples of transected voxel cells,
             with the row index counted from north (``rows - Y - 1``),
             i.e. coordinates from north west bottom
    """
    # initialize per-axis state: current cell, end cell, step direction,
    # parametric distance between two cell borders and to the first border
    X, X_end, stepX, tDeltaX, tMaxX = _init_axis(
        start[0], end[0], region_info['west'], region_info['ewres'])
    Y, Y_end, stepY, tDeltaY, tMaxY = _init_axis(
        start[1], end[1], region_info['south'], region_info['nsres'])
    Z, Z_end, stepZ, tDeltaZ, tMaxZ = _init_axis(
        start[2], end[2], region_info['bottom'], region_info['tbres'])

    # both points in the same voxel: there is nothing in between
    if (X, Y, Z) == (X_end, Y_end, Z_end):
        return []

    rows = region_info['rows']
    coordinates = []
    while True:
        # advance along the axis whose next cell border is the closest
        if tMaxX < tMaxY:
            if tMaxX < tMaxZ:
                tMaxX = tMaxX + tDeltaX
                X = X + stepX
            else:
                tMaxZ = tMaxZ + tDeltaZ
                Z = Z + stepZ
        else:
            if tMaxY < tMaxZ:
                tMaxY = tMaxY + tDeltaY
                Y = Y + stepY
            else:
                tMaxZ = tMaxZ + tDeltaZ
                Z = Z + stepZ

        # stop at the end voxel; the overshoot tests are a safety net that
        # guarantees termination if rounding makes us miss the end voxel
        if (X, Y, Z) == (X_end, Y_end, Z_end) or \
           (stepX * (X - X_end) > 0 or
            stepY * (Y - Y_end) > 0 or
            stepZ * (Z - Z_end) > 0):
            break
        # flip Y so that the row index is counted from north
        coordinates.append((X, rows - Y - 1, Z))
    return coordinates


def _init_axis(p_start, p_end, origin, res):
    """Return (cell, end_cell, step, tDelta, tMax) of one axis for traverse()."""
    relative = (p_start - origin) / res
    cell = int(math.floor(relative))
    cell_end = int(math.floor((p_end - origin) / res))
    step = 1 if p_start < p_end else -1
    delta = p_end - p_start
    if delta == 0:
        # The line never crosses a border on this axis. Using infinity (and
        # not a large finite number multiplied by the fractional position,
        # which is ~0 for a start lying on a grid line) ensures this axis is
        # never chosen for stepping.
        infinity = float('inf')
        return cell, cell_end, step, infinity, infinity
    t_delta = abs(res / float(delta))
    fraction = relative - math.floor(relative)
    if step > 0:
        t_max = t_delta * (1.0 - fraction)
    else:
        t_max = t_delta * fraction
    return cell, cell_end, step, t_delta, t_max


class TestTraverseVoxel(unittest.TestCase):
    """Randomized comparison of traverse() with a densely sampled line."""

    def test_traverse(self):
        """Check 100 random segments in random regions.

        For each segment the reference result is obtained by sampling the
        segment at 10000 parameter values, converting every sample to its
        voxel index and keeping the first occurrence of each voxel. The
        first and last voxel (those of the end points) are dropped before
        comparing with the output of traverse().
        """
        params = np.linspace(0, 1, 10000)

        def sampled_index(p0, p1, origin, res):
            # voxel index along one axis for every sampled point
            positions = p0 + (p1 - p0) * params
            return np.floor((positions - origin) / res).astype(int)

        for _ in range(100):
            # random region; the order of the draws matters for reproducing
            # a run with a seeded generator
            rows = int(uniform(1, 50))
            cols = int(uniform(1, 50))
            depths = int(uniform(1, 50))
            nsres = uniform(0.5, 10)
            ewres = uniform(0.5, 10)
            tbres = uniform(0.5, 10)
            south = uniform(-100, 100)
            west = uniform(-100, 100)
            bottom = uniform(-100, 100)
            region_info = {'rows': rows, 'cols': cols, 'depths': depths,
                           'nsres': nsres, 'ewres': ewres, 'tbres': tbres,
                           'south': south, 'west': west, 'bottom': bottom}

            # opposite faces of the region
            east = west + cols * ewres
            north = south + rows * nsres
            top = bottom + depths * tbres

            # two random points inside the region
            start = [uniform(west, east), uniform(south, north), uniform(bottom, top)]
            end = [uniform(west, east), uniform(south, north), uniform(bottom, top)]

            coords = traverse(start, end, region_info)

            col_idx = sampled_index(start[0], end[0], west, ewres)
            row_idx = sampled_index(start[1], end[1], south, nsres)
            # rows are counted from north
            row_idx = rows - row_idx - 1
            depth_idx = sampled_index(start[2], end[2], bottom, tbres)

            # order-preserving removal of duplicates (a plain set would
            # lose the traversal order)
            visited = set()
            ordered = []
            for cell in zip(col_idx, row_idx, depth_idx):
                if cell not in visited:
                    visited.add(cell)
                    ordered.append(cell)
            expected = ordered[1:-1]

            # The sampled reference is less precise than traverse(); per the
            # original author a failure reporting
            # 'First list contains 1 additional elements.' is accepted.
            self.assertListEqual(coords, expected)

# Run the unit tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
