#!/usr/bin/env python

############################################################################
#
# MODULE:       tiles.rast.stats
# AUTHOR(S):    Markus Neteler
#
# PURPOSE:      Calculates univariate statistics from a raster map based
#               on vector tiles.
# COPYRIGHT:    (C) 2016 by the GRASS Development Team
#
#               This program is free software under the GNU General Public
#               License (>=v2). Read the file COPYING that comes with GRASS
#               for details.
#
#############################################################################

#%module
#% description: Calculates univariate statistics from a raster map based on vector tiles.
#% keyword: vector
#% keyword: statistics
#% keyword: raster
#% keyword: univariate statistics
#% keyword: zonal statistics
#% keyword: sampling
#% keyword: querying
#% overwrite: yes
#%end
#%flag
#% key: l
#% label: Link tiles instead of import
#% description: Requires datasource with multiple access capabilities (typically PostGIS). Also output_layer option is ignored since result is written directly to input layer
#%end
#%option
#% key: dsn
#% description: Data source name
#%end
#%option G_OPT_V_LAYER
#% key: layer
#% description: Name of input layer
#% required: yes
#%end
#%option G_OPT_DB_COLUMN
#% key: keycolumn
#% description: Name of key column used for input
#%end
#%option G_OPT_V_LAYER
#% key: output_layer
#% description: Name for output layer
#%end
#%option G_OPT_R_INPUT
#% key: raster
#% description: Name of input raster map to calculate statistics from
#%end
#%option G_OPT_V_MAP
#% key: mask
#% description: Name of input vector map to be used as mask (otherwise current region bbox is used)
#% required: no
#%end
#%option
#% key: column_prefix
#% type: string
#% description: Column prefix for new attribute columns (default: map name)
#% required : no
#%end
#%option
#% key: method
#% type: string
#% description: The methods to use
#% required: no
#% multiple: yes
#% options: number,minimum,maximum,range,average,stddev,variance,coeff_var,sum,first_quartile,median,third_quartile,percentile
#% answer: number,minimum,maximum,range,average,stddev,variance,coeff_var,sum,first_quartile,median,third_quartile,percentile
#%end
#%option
#% key: percentile
#% type: integer
#% description: Percentile to calculate
#% options: 0-100
#% answer: 90
#% required : no
#%end
#%option
#% key: nprocs
#% description: Number of processes
#% answer: 1
#% type: integer
#%end
#%rules
#% required: output_layer,-l
#%end

import os
import sys
import atexit
import copy
import time
from subprocess import PIPE

import grass.script as grass
from grass.pygrass.modules import Module, MultiModule, ParallelModuleQueue
from grass.pygrass.gis.region import Region
from grass.exceptions import CalledModuleError
from osgeo import ogr

def cleanup():
    """Remove all temporary vector maps whose names start with ``basename``.

    Registered with atexit; does nothing when ``basename`` is empty.
    """
    if not basename:
        return
    grass.message("Cleaning...")
    Module('g.remove', flags='f', type='vector',
           pattern='{}*'.format(basename), quiet=True)
    
def filter_tiles(dsn, layer_name, keycolumn, mask):
    """Select the tiles of an OGR layer which intersect the area of interest.

    The area of interest is the first feature of the ``mask`` vector map or,
    when no mask is given, the bounding box of the current computational
    region.

    :param dsn: OGR data source name (opened with write access)
    :param layer_name: name of the tile layer inside ``dsn``
    :param keycolumn: attribute column identifying a tile; when it is empty or
                      is not an attribute field of the layer, the OGR feature
                      ID is used instead
    :param mask: name of a GRASS vector map, or empty string
    :return: tuple (datasource, layer, list of tile keys). The spatial filter
             set here is left active on the returned layer.
    """
    ds = ogr.Open(dsn, True)  # write access: an output layer may be created later
    if ds is None:
        grass.fatal("Unable to open '{}'".format(dsn))
    layer = ds.GetLayerByName(layer_name)
    if layer is None:
        grass.fatal("Unable to get <{}> layer".format(layer_name))

    if not mask:
        reg = Region()
        wkt = 'POLYGON(({w} {s}, {e} {s}, {e} {n}, {w} {n}, {w} {s}))'.format(
            n=reg.north, s=reg.south, e=reg.east, w=reg.west)
    else:
        try:
            wkt_module = Module('v.out.ascii', input=mask, format='wkt', stdout_=PIPE)
        except CalledModuleError:
            grass.fatal("Unable to export mask <{}> as WKT".format(mask))
        lines = [line for line in wkt_module.outputs.stdout.splitlines() if line.strip()]
        if not lines:
            grass.fatal("Mask <{}> does not contain any feature".format(mask))
        wkt = lines[0]  # TODO: only first feature is used

    filter_geom = ogr.CreateGeometryFromWkt(wkt)
    if filter_geom is None:
        grass.fatal("Unable to create filter geometry from '{}'".format(wkt))
    # the OGR spatial filter is too rough (it compares bounding boxes only),
    # so every candidate is tested again against the exact geometry below
    layer.SetSpatialFilter(filter_geom)

    # Resolve the key column once. -1 means it is not an attribute field of
    # the layer; presumably GDAL then uses it as the FID column (see TODO).
    key_idx = layer.GetLayerDefn().GetFieldIndex(keycolumn) if keycolumn else -1

    tiles = []
    try:
        for feature in layer:
            geom = feature.GetGeometryRef()
            if geom is None or not geom.Intersects(filter_geom):
                continue
            if key_idx >= 0:
                tiles.append(feature.GetField(key_idx))
            else:
                # TODO: test if keycolumn is really used as FID column by GDAL
                tiles.append(feature.GetFID())
    except ValueError as e:
        grass.fatal("{}".format(e))
    grass.message("{} tiles filtered".format(len(tiles)))

    return ds, layer, tiles

def import_tiles(dsn, layer_name, keycolumn, mask):
    """Import each selected tile as its own GRASS vector map via v.in.ogr.

    The tiles come from filter_tiles(). The import jobs run on the global
    ``queue`` and the maps are named ``<basename>_<tile key>``.

    :return: tuple (list of map names in tile order, OGR datasource, OGR layer)
    """
    ds, layer, tiles = filter_tiles(dsn, layer_name, keycolumn, mask)

    # database for each map is required since we call v.rast.stats in parallel
    # NOTE(review): this changes the mapset's DB connection and it is not
    # restored afterwards in this script
    Module('db.connect', database='$GISDBASE/$LOCATION_NAME/$MAPSET/$MAP/sqlite.db')

    grass.message("Importing tiles...")
    start = time.time()
    import_module = Module('v.in.ogr', input=dsn, layer=layer_name, quiet=True,
                           overwrite=True, run_=False)
    maps = []
    for tile in tiles:
        out_name = '{}_{}'.format(basename, tile)
        maps.append(out_name)
        if isinstance(tile, str):
            # text keys must be quoted (embedded quotes doubled) in the SQL filter
            value = "'{}'".format(tile.replace("'", "''"))
        else:
            value = tile
        new_import = copy.deepcopy(import_module)
        queue.put(new_import(where="{}={}".format(keycolumn, value),
                             output=out_name))
    queue.wait()
    grass.message("... done in {:.1f} sec".format(time.time() - start))

    return maps, ds, layer

def link_tiles(dsn, layer_name, keycolumn, mask):
    """Link each selected tile as its own GRASS vector map via v.external.

    The tiles come from filter_tiles(). The link jobs run on the global
    ``queue`` and the maps are named ``<basename>_<tile key>``.

    :return: tuple (list of map names in tile order, OGR datasource, OGR layer)
    """
    ds, layer, tiles = filter_tiles(dsn, layer_name, keycolumn, mask)

    grass.message("Linking tiles...")
    start = time.time()
    # NOTE(review): confirm that v.external accepts a `where` option in the
    # targeted GRASS version
    import_module = Module('v.external', input=dsn, layer=layer_name, quiet=True,
                           overwrite=True, run_=False)
    maps = []
    for tile in tiles:
        out_name = '{}_{}'.format(basename, tile)
        maps.append(out_name)
        if isinstance(tile, str):
            # text keys must be quoted (embedded quotes doubled) in the SQL filter
            value = "'{}'".format(tile.replace("'", "''"))
        else:
            value = tile
        new_import = copy.deepcopy(import_module)
        queue.put(new_import(where="{}={}".format(keycolumn, value),
                             output=out_name))
    queue.wait()
    grass.message("... done in {:.1f} sec".format(time.time() - start))

    return maps, ds, layer

def perform_statistics(maps, raster, column_prefix, method, percentile):
    """Run v.rast.stats on every tile map, in parallel on the global ``queue``.

    For each map a temporary region is first set to the map's extent, aligned
    to ``raster``. Then v.rast.stats (with -c) adds the statistics columns
    ``<column_prefix>_<method>`` to the map's attribute table.

    :param maps: names of the GRASS vector maps (one per tile)
    :param raster: raster map to compute statistics from
    :param column_prefix: prefix for the statistics columns
    :param method: comma-separated list of v.rast.stats methods
    :param percentile: percentile for the 'percentile' method
    """
    align_tmpl = Module("g.region", align=raster, run_=False)
    stats_tmpl = Module('v.rast.stats', flags='c', raster=raster,
                        column_prefix=column_prefix, method=method.split(','),
                        percentile=percentile, quiet=True, run_=False)
    grass.message("Performing statistics...")
    t0 = time.time()
    for vector in maps:
        # each job gets its own copies of the templates
        align_job = copy.deepcopy(align_tmpl)(vector=vector)
        stats_job = copy.deepcopy(stats_tmpl)(map=vector)
        queue.put(MultiModule([align_job, stats_job],
                              sync=False, set_temp_region=True))
    queue.wait()
    grass.message("... done in {:.1f} sec".format(time.time() - t0))

def write_output(maps, ds, ilayer, column_prefix, methods, percentile, output):
    """Create a new OGR layer holding the selected tiles plus their statistics.

    :param maps: GRASS vector maps with per-tile statistics, in the order in
                 which filter_tiles() selected the tiles
    :param ds: OGR datasource (opened for writing) in which ``output`` is created
    :param ilayer: input OGR layer; its spatial filter is expected to still be
                   the one set by filter_tiles()
    :param column_prefix: prefix of the statistics columns created by v.rast.stats
    :param methods: comma-separated list of statistics methods
    :param percentile: percentile used by the 'percentile' method
    :param output: name of the OGR layer to create
    """
    # create new layer
    olayer = ds.CreateLayer(output, ilayer.GetSpatialRef(), ilayer.GetGeomType())
    if olayer is None:
        grass.fatal("Unable to create output layer")

    # copy attribute definitions of the input layer
    feat_defn = ilayer.GetLayerDefn()
    nfields = feat_defn.GetFieldCount()
    for i in range(nfields):
        ifield = feat_defn.GetFieldDefn(i)
        ofield = ogr.FieldDefn(ifield.GetNameRef(), ifield.GetType())
        ofield.SetWidth(ifield.GetWidth())
        ofield.SetPrecision(ifield.GetPrecision())
        olayer.CreateField(ofield)

    # add one real-valued column per statistics method
    stats_cols = []
    for method in methods.split(','):
        fname = '{}_{}'.format(column_prefix, method)
        if method == 'percentile':
            fname += '_{}'.format(percentile)
        stats_cols.append(fname)
        olayer.CreateField(ogr.FieldDefn(fname, ogr.OFTReal))

    ofeat_defn = olayer.GetLayerDefn()

    # The layer's spatial filter compares bounding boxes only. filter_tiles()
    # additionally kept just the features that really intersect the filter
    # geometry, so the same test is repeated here. Without it the features
    # would get out of step with `maps`.
    filter_geom = ilayer.GetSpatialFilter()

    def _selected_features():
        """Yield the input features selected the same way as in filter_tiles()."""
        ilayer.ResetReading()
        for feature in ilayer:
            geom = feature.GetGeometryRef()
            if geom is None:
                continue
            if filter_geom is not None and not geom.Intersects(filter_geom):
                continue
            yield feature

    # assumes the layer returns features in the same order on each pass
    for map_name, feature in zip(maps, _selected_features()):
        ofeature = ogr.Feature(ofeat_defn)
        ofeature.SetGeometry(feature.GetGeometryRef().Clone())
        for i in range(nfields):
            ofeature.SetField(feat_defn.GetFieldDefn(i).GetNameRef(), feature.GetField(i))

        # append the statistics computed by v.rast.stats for this tile
        stats = Module('v.db.select', map=map_name, columns=stats_cols, flags='c',
                       separator=';', stdout_=PIPE)
        values = stats.outputs.stdout.rstrip(os.linesep).split(';')
        for fidx, value in enumerate(values, nfields):
            if value:  # empty value -> leave the field unset (NULL)
                ofeature.SetField(ofeat_defn.GetFieldDefn(fidx).GetNameRef(), float(value))

        olayer.CreateFeature(ofeature)
    
def main():
    """Select tiles, compute raster statistics per tile and write the results.

    With the -l flag the tiles are linked (v.external) and no output layer is
    written. Otherwise they are imported (v.in.ogr) and the statistics are
    written to the new OGR layer ``output_layer`` in the input datasource.
    """
    if not grass.find_file(options['raster'], element='raster')['fullname']:
        grass.fatal("Raster map <{}> not found".format(options['raster']))

    start = time.time()

    if flags['l']:
        if options['output_layer']:
            grass.warning("Option <{}> will be ignored".format('output_layer'))

        # link tiles
        maps, ds, ilayer = link_tiles(options['dsn'], options['layer'],
                                      options['keycolumn'], options['mask'])
    else:
        # import tiles
        maps, ds, ilayer = import_tiles(options['dsn'], options['layer'],
                                        options['keycolumn'], options['mask'])

        # Remove output if it already exists
        if ds.GetLayerByName(options['output_layer']):
            if grass.overwrite():
                ds.DeleteLayer(options['output_layer'])
            else:
                grass.fatal("option <output_layer>: <{}> exists. To overwrite, "
                            "use the --overwrite flag".format(options['output_layer']))

    # default prefix: raster name without '@mapset', which is not a valid
    # character in column names
    column_prefix = options['column_prefix'] or options['raster'].split('@')[0]

    # perform statistics
    perform_statistics(maps, options['raster'],
                       column_prefix,
                       options['method'],
                       options['percentile'])

    # write output
    if not flags['l']:
        write_output(maps, ds, ilayer, column_prefix, options['method'],
                     options['percentile'], options['output_layer'])

    grass.message("Done in {:.1f} sec".format(time.time() - start))

    # close datasource
    ds.Destroy()
    
if __name__ == "__main__":
    # parse the options/flags declared in the #% header
    options, flags = grass.parser()

    # common prefix of all temporary maps, made unique by the process ID
    basename = 'tiles_rast_stats_{}'.format(os.getpid())

    # queue used to run the per-tile jobs in parallel
    nprocs = int(options['nprocs'])
    queue = ParallelModuleQueue(nprocs=nprocs)

    # remove the temporary maps at interpreter exit
    atexit.register(cleanup)
    sys.exit(main())
