# -*- coding: utf-8 -*-
"""
@brief test trilinear interpolation by comparing to scipy implementation
@author: Anna Petrasova
"""
import unittest
from random import uniform
from grass.script import core as gcore
from grass.script import array as garray
from grass.script import raster3d as grast3d

from rast3d_functions import rast3d_trilinear_interpolation


class TestTrivariateInterpolation(unittest.TestCase):
    """Compare rast3d_trilinear_interpolation with scipy's order-1 interpolation.

    A random 3D raster is generated inside a temporary computational region.
    It is sampled at random interior points by both implementations, and the
    results must agree to 4 decimal places.
    """

    def setUp(self):
        """Switch to a temporary region and define the test raster's extent."""
        gcore.use_temp_region()
        self.name = 'random_voxel'
        # The same bounds are reused for all three axes (s/w/b and n/e/t).
        self.minim = -25.
        self.maxim = 123.
        self.res = 5.

    def test_interpolation(self):
        """Interpolated values must match scipy's map_coordinates (order=1)."""
        from scipy import ndimage
        gcore.run_command('g.region', n=self.maxim, s=self.minim, e=self.maxim, w=self.minim,
                          t=self.maxim, b=self.minim, res3=self.res)
        gcore.run_command('r3.mapcalc', expression=self.name + ' = rand(-20,500)')
        rast3d = garray.array3d()
        rast3d.read(mapname=self.name)
        # Resolutions are read back from the map instead of reusing self.res.
        # NOTE: the 148-unit extent is not a multiple of 5, so these may
        # differ from self.res.
        reg = grast3d.raster3d_info(self.name)

        for _ in range(100):
            # Keep at least half a cell away from the borders. Beyond the
            # outermost cell centres, scipy's map_coordinates (default
            # mode='constant', cval=0.0) pads with zeros instead of data.
            point = [uniform(self.minim + reg['ewres'] / 2., self.maxim - reg['ewres'] / 2.),
                     uniform(self.minim + reg['nsres'] / 2., self.maxim - reg['nsres'] / 2.),
                     uniform(self.minim + reg['tbres'] / 2., self.maxim - reg['tbres'] / 2.)]
            # point is (east, north, top). The function receives the
            # northing first, presumably (north, east, top) order; confirm.
            tested = rast3d_trilinear_interpolation(reg, [rast3d, rast3d, rast3d],
                                                    point[1], point[0], point[2])
            # Fractional array indices in (depth, row, col) order, measured
            # from cell centres. Rows count downwards from the north edge.
            coordinates = [[(point[2] - 0.5 * reg['tbres'] - self.minim) / reg['tbres']],
                           [(reg['north'] - point[1] - 0.5 * reg['nsres']) / reg['nsres']],
                           [(point[0] - 0.5 * reg['ewres'] - self.minim) / reg['ewres']]]
            # map_coordinates returns a 1-element array for a single point;
            # extract the scalar so the comparison is scalar-to-scalar.
            reference = ndimage.map_coordinates(rast3d, order=1,
                                                coordinates=coordinates)[0]
            # The same map was passed for every input. Presumably one value
            # is returned per input map, and each must match the reference.
            for value in tested:
                self.assertAlmostEqual(value, reference, places=4)

    def tearDown(self):
        """Remove the generated 3D raster and discard the temporary region."""
        gcore.run_command('g.remove', rast3d=self.name)
        gcore.del_temp_region()

if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's CLI runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()
