#!/usr/local/bin/python
# -*- coding: utf-8 -*-

# Gorazd Generator
# Generator of dictionary entries from ALTO XML.
# Copyright (C) 2018  Vít Tuček, Slovanský ústav AV ČR, v. v. i.

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

from multiprocessing import Pool
from itertools import repeat

import os.path as osp
import glob
import sys

from lxml import etree as et

import pandas as pd
import numpy as np

from sklearn.cluster import KMeans

# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# from matplotlib.patches import Rectangle

from PIL import Image, ImageDraw

ns = {'alto': 'http://www.loc.gov/standards/alto/ns-v3#'}


def get_scale_factors(xml, img):
    """Return the ratios between ALTO page dimensions and image dimensions.

    The first ``alto:Page`` element found in *xml* supplies the ``HEIGHT``
    and ``WIDTH`` attributes; *img* supplies its pixel size through
    ``img.size`` (indexed as ``(width, height)``).

    Note the order of the result: the vertical factor comes first.

    :param xml: parsed ALTO document supporting ``xpath``.
    :param img: image object exposing ``size``.
    :return: tuple ``(page_height / image_height, page_width / image_width)``.
    """
    pages = xml.xpath(".//alto:Page", namespaces=ns)
    first_page = pages[0]
    img_width, img_height = img.size[0], img.size[1]
    vertical = float(first_page.get("HEIGHT")) / img_height
    horizontal = float(first_page.get("WIDTH")) / img_width
    return vertical, horizontal


def put_line(draw, vpos, hpos, width, sx, sy):
    """Draw a red horizontal line on *draw*.

    The line starts at (*hpos*, *vpos*) and extends *width* to the right.
    Horizontal values are divided by *sx* and vertical values by *sy*, and
    each resulting coordinate is truncated to an integer before drawing.

    :param draw: object with a ``line(path, fill=...)`` method.
    :param vpos: vertical position of the line (unscaled).
    :param hpos: horizontal start of the line (unscaled).
    :param width: length of the line (unscaled).
    :param sx: divisor applied to horizontal coordinates.
    :param sy: divisor applied to the vertical coordinate.
    """
    y = int(vpos / sy)
    x_start = int(hpos / sx)
    x_end = int((hpos + width) / sx)
    draw.line([x_start, y, x_end, y], fill='red')


def get_sizes(xml):
    """Collect geometry of every text line in the print space of an ALTO page.

    One row is produced per ``alto:TextLine`` found under ``alto:PrintSpace``.

    Columns of the returned frame:

    * ``hpos``, ``width``, ``vpos``, ``height`` -- attributes of the text
      line itself (a missing attribute counts as 0);
    * ``sv_max``, ``sv_mean``, ``sv_min`` -- max / mean / min ``VPOS`` over
      the ``alto:String`` elements inside the line;
    * ``sh_max``, ``sh_mean``, ``sh_min`` -- max / mean / min ``HEIGHT`` over
      the same strings;
    * ``ls_o`` -- ``vpos`` minus the previous row's ``vpos``, minus ``height``;
    * ``ls_n`` -- ``sv_max`` minus the previous row's ``sv_max``, minus
      ``height``;
    * ``ls`` -- ``sv_min`` minus the previous row's ``sv_max``, minus
      ``sh_max``.

    The three ``ls*`` columns are NaN in the first row because there is no
    previous row to compare with.

    :param xml: parsed ALTO document supporting ``xpath``.
    :return: ``pandas.DataFrame`` with the columns listed above.
    """
    ATTRIBUTES = ['HPOS', 'WIDTH', 'VPOS', 'HEIGHT']
    base_columns = [a.lower() for a in ATTRIBUTES]

    def element2list(e):
        # ALTO allows coordinates to be written as floats (e.g. "123.0"),
        # which a plain int() rejects; parse as float first, then truncate.
        return [int(float(e.get(a, 0))) for a in ATTRIBUTES]

    def textline2list(tl):
        res = element2list(tl)
        strings = pd.DataFrame(
            [element2list(s) for s in tl.xpath(".//alto:String", namespaces=ns)],
            columns=base_columns,
        )
        res.extend([strings['vpos'].max(),
                    strings['vpos'].mean(),
                    strings['vpos'].min(),
                    strings['height'].max(),
                    strings['height'].mean(),
                    strings['height'].min(),
                    ])
        return res

    tl_columns = base_columns + \
        ['sv_max', 'sv_mean', 'sv_min', 'sh_max', 'sh_mean', 'sh_min']

    text_lines = xml.xpath(".//alto:PrintSpace//alto:TextLine", namespaces=ns)
    df = pd.DataFrame([textline2list(tl) for tl in text_lines],
                      columns=tl_columns)

    df['ls_o'] = df['vpos'].diff() - df['height']
    df['ls_n'] = df['sv_max'].diff() - df['height']
    df['ls'] = df.sv_min - df.shift(1).sv_max - df.sh_max

    # dff = df[df['ls'] > -1000] # filter out page breaks
    return df


def insert_separators(xml_file, img_file, output_file):
    """Draw record separators onto a page image and save the annotated copy.

    The ``ls`` column produced by :func:`get_sizes` is clustered into two
    groups with KMeans.  A split threshold is derived from the two cluster
    centres (never lower than 70), and a red line is drawn above every text
    line whose ``ls`` exceeds the threshold, provided the line starts at
    most 60 units to the right of the preceding text line.

    :param xml_file: path of the ALTO XML file describing the page.
    :param img_file: path of the page image.
    :param output_file: despite its name this is used as a *directory*; the
        result is written to ``output_file/<basename of img_file>``.
    :return: tuple ``(dff, est)`` -- the frame from :func:`get_sizes` and the
        fitted ``KMeans`` estimator.
    """
    print(xml_file)
    xml = et.parse(xml_file)

    # The context manager closes the underlying image file when done.
    with Image.open(img_file) as img:
        # get_scale_factors returns (vertical, horizontal) in that order.
        sy, sx = get_scale_factors(xml, img)

        dff = get_sizes(xml)
        # Rows with ls <= -1000 are dropped to get rid of pagebreaks; this
        # comparison also drops the NaN of the first row.  KMeans wants a
        # plain 2-D (n_samples, 1) ndarray -- np.matrix is deprecated and
        # rejected by current scikit-learn.
        spacings = dff[dff.ls > -1000].ls.values.reshape(-1, 1)
        est = KMeans(n_clusters=2)
        est.fit(spacings)

        cc = est.cluster_centers_
        print("Clusters are at:\n %s" % cc)

        center = (cc.max() - cc.min()) / 2 + cc.min()
        print("With center: %s" % center)

        t = 0.33
        split = max(70, cc.min() * t + (1 - t) * cc.max())  # TODO tune minimal threshold
        print("Split: %s" % split)

        draw = ImageDraw.Draw(img)
        nr = 0
        for i, r in dff[dff.ls > split].iterrows():
            # Place the separator halfway into the gap above this line.
            vpos = int(r.vpos - r.ls / 2)
            try:
                if r.hpos - dff.loc[i - 1].hpos <= 60:  # TODO tune this
                    put_line(draw, vpos, r.hpos, r.width, sx, sy)
                    nr += 1
            except Exception as exc:
                # Best effort: skip this row but report what went wrong.
                # Unlike a bare ``except`` this lets KeyboardInterrupt and
                # SystemExit propagate.
                print("EXCEPTION: %r" % (exc,))
        print("Inserted %s record separators" % nr)
        del draw
        img.save(osp.join(output_file, osp.basename(img_file)))

    print()
    return dff, est


# dff, est = insert_separators("in/100/SINE_DIL_1__070QY3BEX0001P.xml",
#                  "in/100/SINE_DIL_1__070QY3BN30001P.JPG",
#                  "out/test.jpg")

# Inputs are paired purely by position: the i-th XML file (sorted by name)
# is processed together with the i-th image (sorted by name).
xml_files = sorted(glob.glob("in/100/*.xml"))
jpg_files = sorted(glob.glob("in/100/*.jpg"))

# Sequential alternative to the process pool below:
# for xml, jpg in zip(xml_files, jpg_files):
#    insert_separators(xml, jpg, "out")

if __name__ == '__main__':
    # Every job writes into the same 'out' directory.
    jobs = zip(xml_files, jpg_files, repeat('out'))
    with Pool(4) as pool:
        pool.starmap(insert_separators, jobs)
    sys.stdout.flush()