#!/usr/bin/env python
"""
MODULE:    i.mlpy

AUTHOR(S): Stepan Turek <stepan.turek AT seznam.cz>

PURPOSE: Classifies a segmented raster using the mlpy library.

COPYRIGHT: (C) 2012 Stepan Turek, and by the GRASS Development Team

This program is free software under the GNU General Public License
(>=v2). Read the file COPYING that comes with GRASS for details.
"""

#%module
#% description: Classifies segmented raster.
#% keywords: classification
#% keywords: imagery
#% keywords: segmentation
#%end

#%option G_OPT_I_GROUP
#%end

#%option G_OPT_R_INPUT
#% key: seg_raster
#% description: Segmented raster
#%end

#%option G_OPT_V_INPUT
#% key: training_points
#%end

#%option G_OPT_V_FIELD
#% key: trlayer
#% guidependency: class_column
#%end

#%option G_OPT_R_OUTPUT
#% description: Output classified raster
#%end

#%option G_OPT_R_OUTPUT
#% required: no
#% key: training_segments
#% description: Output raster with training segments
#%end

#%option G_OPT_DB_COLUMN
#% key: class_column
#% type: string
#% required: no
#%end

import os
import sys
import atexit

import grass.script as grass
from grass.pygrass.vector import Vector
from grass.pygrass import raster

import numpy as np
import mlpy

import string #TODO remove it

def cleanup():
    """!Remove the temporary vector maps used by this module (quietly)."""
    for tmp_vect in ('i_mlpy_tmp_seg_vect', 'i_mlpy_tmp_vect_tr'):
        grass.run_command('g.remove', vect=tmp_vect, quiet=True)

def prepareTmpVector(seg_raster, temp_vect_seg, raster_maps, col_prefs):
    """!Vectorize the segmentation raster and attach statistics to its table.

    <seg_raster> is converted by r.to.vect into the area vector map
    <temp_vect_seg> (overwritten if it exists), storing raster values in
    column 'segment_id'. v.rast.stats (flag -e) is then run once for every
    raster in <raster_maps>; the new columns get the prefix col_prefs[map].
    Any failing module call terminates the script via grass.fatal().
    """
    ret = grass.run_command('r.to.vect',
                            input=seg_raster,
                            output=temp_vect_seg,
                            type='area',
                            column='segment_id',
                            overwrite=True,
                            quiet=True)
    if ret != 0:
        grass.fatal(_('%s failed') % 'r.to.vect')

    for rast in raster_maps:
        ret = grass.run_command('v.rast.stats',
                                vector=temp_vect_seg,
                                raster=rast,
                                column_prefix=col_prefs[rast],
                                flags='e',
                                quiet=True)
        if ret != 0:
            grass.fatal(_('%s failed') % 'v.rast.stats')

def prepareTrainingSegments(seg_raster, temp_vect_seg, tr_pts):
    """!Transfer classes of training points to the segments they lie in.

    The id of the segment under every training point is written into a
    (re)created 'segment_id' column of the training points map, so that map
    is modified in place. The class column is then joined into the
    attribute table of <temp_vect_seg> through that id.

    @param seg_raster segmentation raster map
    @param temp_vect_seg segmentation vector map with a 'segment_id' column
    @param tr_pts dict with keys 'map' (training points vector map),
                  'class_col' (column holding the class labels) and
                  optionally 'layer' (attribute layer number or name,
                  default 1)
    """
    seg_id_col = 'segment_id'

    tr_pts_map = tr_pts['map']
    # attribute layer of the training points; 1 when not given
    tr_layer = str(tr_pts.get('layer') or 1)

    # Resolve the attribute table of that layer for v.db.join rather than
    # assuming the table is named after the map.
    tr_pts_table = None
    for link in grass.vector_db(tr_pts_map).values():
        if tr_layer in (str(link['layer']), link.get('name')):
            tr_pts_table = link['table']
            break
    if not tr_pts_table:
        grass.fatal(_("No attribute table linked to layer <%s> of vector "
                      "map <%s>") % (tr_layer, tr_pts_map))

    tr_pts_cols = grass.vector_columns(tr_pts_map, tr_layer).keys()

    if seg_id_col in tr_pts_cols:
        # v.db.dropcolumn follows the same naming as v.db.addcolumn below
        if grass.run_command('v.db.dropcolumn',
                             map=tr_pts_map,
                             layer=tr_layer,
                             columns=seg_id_col) != 0:
            grass.fatal(_("Error dropping column <%s>") % seg_id_col)

    if grass.run_command('v.db.addcolumn',
                         map=tr_pts_map,
                         layer=tr_layer,
                         columns="%s INT" % seg_id_col) != 0:
        grass.fatal(_("Error creating column <%s>") % seg_id_col)

    # store the segment id under every training point
    if grass.run_command('v.what.rast',
                         quiet=True,
                         map=tr_pts_map,
                         layer=tr_layer,
                         raster=seg_raster,
                         column=seg_id_col) != 0:
        grass.fatal(_('%s failed') % 'v.what.rast')

    # copy the class column into the segments' table, matching on segment id
    if grass.run_command('v.db.join',
                         quiet=True,
                         map=temp_vect_seg,
                         otable=tr_pts_table,
                         column=seg_id_col,
                         ocolumn=seg_id_col,
                         scolumns=tr_pts['class_col']) != 0:
        grass.fatal(_('%s failed') % 'v.db.join')
 
def getClassDataCols(col_prefs, used_stats):
    """!Get names of the attribute columns used as classification features.

    @param col_prefs dict mapping raster map name -> column prefix
    @param used_stats list of statistic suffixes (e.g. '_mean')

    @return list with one column name (prefix + suffix) for every
            combination of prefix and statistic; the columns of one raster
            are adjacent and follow the order of used_stats
    """
    # Only the prefixes are needed; dict.values() works on Python 2 and 3,
    # unlike dict.iteritems().
    return [pref + stat
            for pref in col_prefs.values()
            for stat in used_stats]

def classify(temp_vect_seg, class_data_cols, class_col):
    """!Classify segments by linear discriminant analysis (mlpy.LDAC).

    Rows of the segmentation attribute table whose <class_col> is not NULL
    (segments that received a class from training points) train the
    classifier. Every row where <class_col> is NULL gets a predicted class.
    The features are the values of <class_data_cols>.

    The table is taken from dblinks[1] of <temp_vect_seg> (presumably the
    table created by r.to.vect - confirm for the pygrass version in use).

    @param temp_vect_seg name of the segmentation vector map
    @param class_data_cols list of attribute columns used as features
    @param class_col name of the attribute column with training classes

    @return tuple (classify_results, classes, tr_segs_classes), where
            classify_results maps segment_id -> class id for all segments,
            classes is the list of class labels (class id i has label
            classes[i - 1]) and tr_segs_classes maps segment_id -> class
            id for the training segments only
    """
    data_cols = ', '.join(class_data_cols)

    v_segs = Vector(temp_vect_seg)
    v_segs.open()
    try:
        table = v_segs.dblinks[1].table()

        # training segments: those whose class column has been filled
        (table.filters.select('segment_id, ' + class_col + ', ' + data_cols)
                      .where('%s IS NOT NULL' % class_col)
                      .order_by('segment_id'))
        tr_rows = table.execute().fetchall()

        classes = []          # class labels; list position + 1 is the class id
        class_ids = {}        # class label -> class id (no list.index scans)
        tr_table = []         # feature vectors of training segments
        tr_classes = []       # class ids of training segments
        tr_segs_classes = {}  # segment_id -> class id of training segments

        for row in tr_rows:
            seg_id, label = row[0], row[1]
            if label not in class_ids:
                classes.append(label)
                class_ids[label] = len(classes)
            class_id = class_ids[label]
            tr_segs_classes[seg_id] = class_id
            tr_classes.append(class_id)
            tr_table.append(row[2:])

        # fail with a clear message rather than an obscure error from mlpy
        if not tr_table:
            grass.fatal(_("No training segments found (no segment received "
                          "a value in column <%s>)") % class_col)
        if len(classes) < 2:
            grass.fatal(_("At least two classes are needed for "
                          "classification, found only <%s>") % classes[0])

        class_alg = mlpy.LDAC()
        class_alg.learn(tr_table, tr_classes)

        # training segments keep their known class
        classify_results = dict(tr_segs_classes)

        (table.filters.select('segment_id, ' + data_cols)
                      .where('%s IS NULL' % class_col)
                      .order_by('segment_id'))
        for row in table.execute().fetchall():
            classify_results[row[0]] = class_alg.pred(row[1:])
    finally:
        v_segs.close()

    return classify_results, classes, tr_segs_classes


def reclassRules(seg_classes, classes):
    """!Build r.reclass rules assigning class ids and labels to segments.

    @param seg_classes dict segment id -> class id (1-based, int-like)
    @param classes list of class labels; class id i has label classes[i - 1]

    @return rules text with one "segment_id = class_id label" line per
            segment, sorted by segment id ('' for an empty dict)
    """
    lines = []
    for seg_id, class_id in sorted(seg_classes.items()):
        # ids may come back from mlpy as numpy scalars, hence int()
        class_id = int(class_id)
        lines.append('%d = %d %s\n' % (seg_id, class_id, classes[class_id - 1]))
    return ''.join(lines)


def writeSegmentClasses(seg_raster, output, seg_classes, classes):
    """!Create raster <output> assigning a class to segments of <seg_raster>.

    The map is created by r.reclass, so <output> is a reclass map based on
    <seg_raster>, with the class labels as category labels. Segments
    missing from <seg_classes> get no rule and become NULL.
    """
    temp_rules = grass.tempfile()
    try:
        with open(temp_rules, 'w') as rules_file:
            rules_file.write(reclassRules(seg_classes, classes))
    except IOError:
        grass.fatal(_("Unable to write data into tempfile."))

    if grass.run_command('r.reclass',
                         input=seg_raster,
                         output=output,
                         rules=temp_rules) != 0:
        grass.fatal(_('%s failed') % 'r.reclass')


def main():
    """!Classify segments of <seg_raster> and write the classified raster(s).

    Vectorizes the segmentation and computes per-segment statistics of all
    rasters in <group>. Transfers classes of training points to the
    segments, classifies the segments, and writes <output> (and, when
    requested, <training_segments>).
    """
    output = options['output']
    seg_raster = options['seg_raster']
    tr_segments = options['training_segments']

    # declared 'required: no' in the parser header, but classification
    # cannot run without a class column
    class_col = options['class_column']
    if not class_col:
        grass.fatal(_("Option <%s> is required") % 'class_column')

    group = options['group']
    s = grass.read_command('i.group', flags='g', group=group, quiet=True)
    raster_maps = s.splitlines()
    if not raster_maps:
        grass.fatal(_('Group <%s> contains no rasters.') % group)

    # unique column prefix for the statistics of every raster in the group
    col_prefs = {}
    for i, m in enumerate(raster_maps):
        col_prefs[m] = 'r%d_' % i

    vect_tr_pts = options['training_points']
    trlayer = options['trlayer']

    # work on a copy, because prepareTrainingSegments() alters its table
    temp_vect_tr_pts = 'i_mlpy_tmp_vect_tr'
    if grass.run_command('g.copy',
                         vect=vect_tr_pts + ',' + temp_vect_tr_pts,
                         overwrite=True) != 0:
        grass.fatal(_('%s failed') % 'g.copy')

    tr_pts = {'map': temp_vect_tr_pts,
              'layer': trlayer,
              'class_col': class_col}

    temp_vect_seg = 'i_mlpy_tmp_seg_vect'

    prepareTmpVector(seg_raster, temp_vect_seg, raster_maps, col_prefs)
    prepareTrainingSegments(seg_raster, temp_vect_seg, tr_pts)

    used_stats = ['_mean', '_variance', '_median',
                  '_third_quartile', '_first_quartile']
    class_data_cols = getClassDataCols(col_prefs, used_stats)

    classify_results, classes, tr_segs_classes = \
        classify(temp_vect_seg, class_data_cols, class_col)

    # write the results as raster maps
    writeSegmentClasses(seg_raster, output, classify_results, classes)
    if tr_segments:
        writeSegmentClasses(seg_raster, tr_segments, tr_segs_classes, classes)
    return 0

if __name__ == "__main__":
    options, flags = grass.parser()
    # remove temporary maps however the script terminates
    atexit.register(cleanup)
    # propagate main()'s return value as the process exit status
    sys.exit(main())
