'''
Module Name:  Meanshift pseudo code

Created on May 24, 2016

@author: Bo Yang, Moritz Lennert, Markus Metz

GRASS GSoC 2016 project, All Rights Reserved
'''

#==================================================================================
#               Imports and test parameters (type the values in below)
#==================================================================================
import numpy as np
import math


# Counter initialised to zero; nothing in the visible code reads or updates it.
coverg_count = 0

# Convergence tolerance: a pixel is flagged as converged once its value changes
# by less than this amount from one iteration to the next.
threshold_factor = 0.01

# Spectral tolerance: a window neighbour contributes to the weighted mean only
# if its value differs from the centre pixel by less than this amount.
spectral_difference_threshold = 5


#===================================================================================
#                       input fine Raster
#===================================================================================


# Load the input raster as a NumPy array.
# NOTE(review): RasterToNumPyArray, Raster, inputRaster and windowSize are not
# defined or imported in this file -- presumably supplied by the host (ArcPy)
# environment; confirm before running.
input_raster = RasterToNumPyArray(Raster(inputRaster))  # A1

# Working copy that the filter updates in place.  The filter writes weighted
# means (floats) back into this array and subtracts pixel values from each
# other, so a non-floating input is promoted to float64: an integer array would
# silently truncate every update and an unsigned one would wrap on subtraction.
# Floating-point inputs keep their dtype.
if np.issubdtype(input_raster.dtype, np.floating):
    last_iteration_raster = input_raster.copy()
else:
    last_iteration_raster = input_raster.astype(np.float64)

# Per-pixel convergence flag: 0 = still iterating, 1 = converged.
iteration_indicator_raster = np.zeros(
    (input_raster.shape[0], input_raster.shape[1]), dtype=int)

# Per-pixel scratch accumulator for the weighted mean of the current window.
mNewRaster = np.zeros(
    (input_raster.shape[0], input_raster.shape[1]), dtype=float)

# Not referenced elsewhere in the visible code.
# NOTE(review): the filter loop visits rows - 2*(windowSize//2) pixels per axis,
# which is not the same as rows - windowSize; confirm what this count is meant
# to represent.
pixel_number = (input_raster.shape[0] - windowSize) * (input_raster.shape[1] - windowSize)

#===================================================================================
#							Mean shift
#===================================================================================


def Mean_shift_filter(Raster=last_iteration_raster, w=windowSize, max_iterations=50):
    """Iteratively smooth the module-level raster with a mean-shift style filter.

    For every interior pixel that is not yet flagged in
    ``iteration_indicator_raster``, the pixel is replaced by the weighted mean
    of the neighbours inside a ``w`` x ``w`` window whose value differs from
    the centre by less than ``spectral_difference_threshold``.  Neighbours are
    weighted by a Gaussian of their distance to the window centre, with the
    distance expressed in units of the half window width.

    The result is written in place to the module-level
    ``last_iteration_raster``; with the default ``Raster`` argument that is the
    same array that is being read, so pixels updated earlier in a sweep are
    seen by pixels processed later in that sweep.  A pixel is flagged as
    converged once its value changes by less than ``threshold_factor``, and is
    skipped from then on.  The outer ``w // 2`` rows and columns are never
    modified.

    NOTE(review): values are read from ``Raster`` but written to
    ``last_iteration_raster``; passing a different array as ``Raster`` means
    the input never changes between sweeps -- confirm this is intended.

    Parameters:
        Raster: 2-D array of pixel values to read from.
        w: window size in pixels; must be at least 2.
        max_iterations: upper bound on the number of sweeps over the raster.

    Returns:
        None.  The module-level rasters are updated in place.

    Raises:
        ValueError: if ``w`` is smaller than 2 (the half window width would be
            zero and the distance normalisation would divide by zero).
    """
    # Floor division keeps the half width an integer on both Python 2 and
    # Python 3; range() and array indexing reject the float that "/" yields on 3.
    half = w // 2
    if half < 1:
        raise ValueError("window size w must be at least 2, got %r" % (w,))

    # The spatial weight depends only on the offset inside the window, so it is
    # computed once here instead of twice per neighbour per pixel per sweep.
    kernel = [
        [
            math.exp(-0.5 * (math.sqrt((m - half) ** 2 + (n - half) ** 2) / half) ** 2)
            for n in range(w)
        ]
        for m in range(w)
    ]

    row_stop = input_raster.shape[0] - half
    col_stop = input_raster.shape[1] - half

    for _ in range(max_iterations):
        updated = 0
        for j in range(half, row_stop):
            for i in range(half, col_stop):
                if iteration_indicator_raster[j][i] != 0:
                    continue

                center = Raster[j][i]

                # Single pass over the window: remember the qualifying
                # neighbours and sum their weights for the normalisation.
                selected = []
                total_weight = 0.0
                for m in range(w):
                    for n in range(w):
                        value = Raster[j - half + m][i - half + n]
                        if abs(value - center) < spectral_difference_threshold:
                            weight = kernel[m][n]
                            total_weight = total_weight + weight
                            selected.append((value, weight))

                if not selected:
                    # No neighbour qualifies (e.g. NaN centre or non-positive
                    # threshold): keep the pixel as it is rather than
                    # overwriting it with 0.0, and stop revisiting it.
                    iteration_indicator_raster[j][i] = 1
                    continue

                for value, weight in selected:
                    mNewRaster[j][i] = mNewRaster[j][i] + value * (weight / total_weight)

                interm_shift = abs(mNewRaster[j][i] - last_iteration_raster[j][i])
                if interm_shift < threshold_factor:
                    iteration_indicator_raster[j][i] = 1
                last_iteration_raster[j][i] = mNewRaster[j][i]
                # Reset the scratch cell so the next sweep starts from zero.
                mNewRaster[j][i] = 0.0
                updated += 1

        if updated == 0:
            # Every interior pixel is flagged; further sweeps would do nothing.
            break

                    

# Run the filter with its default arguments (Raster=last_iteration_raster,
# w=windowSize); the module-level rasters are updated in place.
Mean_shift_filter()
# NOTE(review): saveRaster is not defined in the visible source -- presumably it
# writes the filtered raster out; confirm where it comes from.
saveRaster()









