#!/usr/bin/env python3

import argparse
import itertools
import os
import sqlite3
import sys


class BinaryReader:
    """Sequential reader over an in-memory ``bytes`` buffer.

    A cursor (``pos``) tracks the next unread byte; every ``read_*`` call
    advances it. Malformed or truncated input raises ``ValueError``
    (explicit checks rather than ``assert``, which ``python -O`` removes).
    """

    def __init__(self, data):
        self.data = data
        self.pos = 0  # offset of the next unread byte

    def read_bytes(self, count):
        """Return the next ``count`` bytes and advance the cursor.

        Raises:
            ValueError: if ``count`` is negative or fewer than ``count``
                bytes remain in the buffer.
        """
        if count < 0:
            # A negative count would otherwise move the cursor backwards.
            raise ValueError('negative read size: {}'.format(count))
        end = self.pos + count
        if end > len(self.data):
            raise ValueError(
                'unexpected end of data: wanted {} bytes at offset {}, '
                'only {} available'.format(count, self.pos,
                                           len(self.data) - self.pos))
        chunk = self.data[self.pos:end]
        self.pos = end
        return chunk

    def read_int32(self):
        """Read 4 bytes as a little-endian unsigned integer."""
        return int.from_bytes(self.read_bytes(4), 'little')

    def read_cstr(self, block_size=None, encoding='ascii'):
        """Read a NUL-terminated string and decode it with ``encoding``.

        If ``block_size`` is given, exactly that many bytes are consumed (a
        fixed-width field), and the text before the first NUL inside the
        field is returned. If ``block_size`` is None, bytes are consumed up
        to and including the next NUL.

        Raises:
            ValueError: if no NUL terminator is found, or the field runs
                past the end of the data.
        """
        if block_size is None:
            nul = self.data.find(0, self.pos)
            if nul == -1:
                raise ValueError(
                    'unterminated string at offset {}'.format(self.pos))
            block = self.read_bytes(nul - self.pos + 1)
        else:
            block = self.read_bytes(block_size)

        end = block.find(0)
        if end == -1:
            raise ValueError('string field of {} bytes has no NUL terminator'
                             .format(len(block)))

        return block[:end].decode(encoding)

    def finished(self):
        """Return True once every byte of the buffer has been consumed."""
        return self.pos == len(self.data)


# Value that read_dt_file requires as the first little-endian int32 of a .dt file.
MAGIC = 0x0D731337
# Width in bytes of the fixed-size, NUL-terminated operation-name field of each record.
OP_NAME_BLOCK_SIZE = 20

def read_dt_file(file_name):
    """Parse a .dt history file into a list of history-step dicts.

    Layout, as parsed below: one little-endian int32 equal to MAGIC, then
    zero or more records, each consisting of
        int32 enabled,
        OP_NAME_BLOCK_SIZE-byte NUL-terminated operation name,
        int32 module version,
        int32 params size,
        <params size> bytes of params.

    Returns:
        A list of dicts in file order with keys 'enabled' (int),
        'operation' (str), 'module' (int: the record's module version) and
        'op_params' (bytes).

    Raises:
        ValueError: if the file does not start with MAGIC. Errors raised by
            BinaryReader for truncated or malformed records propagate.
    """
    with open(file_name, "rb") as dt_data:
        reader = BinaryReader(dt_data.read())

    # Explicit check rather than `assert`, which `python -O` strips; without
    # it, any file would be parsed as history.
    magic = reader.read_int32()
    if magic != MAGIC:
        raise ValueError('{}: bad magic 0x{:08x}, expected 0x{:08x}'
                         .format(file_name, magic, MAGIC))

    history = []
    while not reader.finished():
        enabled = reader.read_int32()
        operation = reader.read_cstr(OP_NAME_BLOCK_SIZE)
        modver = reader.read_int32()
        # read_int32 decodes unsigned, so this size is never negative.
        op_params_size = reader.read_int32()
        op_params = reader.read_bytes(op_params_size)

        history.append({
            'enabled': enabled,
            'operation': operation,
            'module': modver,
            'op_params': op_params,
        })

    return history

def file_name_to_roll_and_name(filename):
    """Split a .dt file path into (containing folder, image file name).

    The path is made absolute first. The image name is the .dt file's base
    name with its trailing ``.dt`` removed, e.g. ``/p/a.jpg.dt`` ->
    ``('/p', 'a.jpg')``.

    Raises:
        ValueError: if ``filename`` does not end in ``.dt``.
    """
    abspath = os.path.abspath(filename)
    roll, dt_name = os.path.split(abspath)
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # and the slice below would then silently chop three arbitrary chars.
    if not dt_name.endswith('.dt'):
        raise ValueError('not a .dt file: {}'.format(filename))
    return roll, dt_name[:-3]

def get_image_id(conn, file_name):
    """Return the id of the library image matching a .dt file, or None.

    The match requires that the film roll's folder equals the .dt file's
    directory and that the image filename equals the .dt base name with
    ``.dt`` removed. At most one match is expected (asserted).
    """
    folder, image_name = file_name_to_roll_and_name(file_name)
    query = """select img.id from film_rolls roll, images img
                     where img.film_id = roll.id and roll.folder=? and img.filename=?"""
    with conn:
        matches = conn.execute(query, (folder, image_name)).fetchall()
        assert len(matches) <= 1
        if matches:
            return matches[0][0]
        return None

def dest_name_seq():
    """Yield backup-name suffixes without end: '' first, then '01', '02', ..."""
    yield ''
    yield from ('{:02}'.format(n) for n in itertools.count(1))

def rename_xmp_for_file(file_name):
    """Move the .xmp sidecar that sits next to a .dt file out of the way.

    For ``X.dt`` the sidecar is ``X.xmp``. If it exists, it is renamed to the
    first unused name among ``X.xmp.pre-dt``, ``X.xmp.pre-dt01``,
    ``X.xmp.pre-dt02``, ... . Nothing happens if there is no sidecar.
    """
    assert file_name.endswith(".dt")
    xmp_name = file_name[:-3] + '.xmp'
    if not os.path.exists(xmp_name):
        return

    candidates = (xmp_name + '.pre-dt{}'.format(suffix)
                  for suffix in dest_name_seq())
    backup_name = next(name for name in candidates
                       if not os.path.exists(name))
    os.rename(xmp_name, backup_name)

def rename_dt_file(file_name):
    """Mark a .dt file as consumed by renaming it to ``<file_name>.applied``."""
    assert file_name.endswith(".dt")
    applied_name = file_name + '.applied'
    os.rename(file_name, applied_name)

def replace_image_history(conn, image_id, new_history):
    """Atomically replace every history row of ``image_id`` with ``new_history``.

    Each entry of ``new_history`` must provide the keys 'module',
    'operation', 'op_params' and 'enabled'. Entries are numbered from 0 in
    list order. Blend-op params are written as '', blend-op version as 7,
    multi_priority as 0 and multi_name as ''. ``images.history_end`` is set
    to the number of entries. All three statements run inside one
    ``with conn`` block, so they commit together or roll back together.
    """
    rows = [
        (image_id, num, step['module'], step['operation'],
         step['op_params'], step['enabled'])
        for num, step in enumerate(new_history)
    ]
    delete_sql = 'delete from history where imgid = ?'
    insert_sql = '''insert into history(imgid, num, module, operation,
                                   op_params, enabled, blendop_params, blendop_version,
                                   multi_priority, multi_name)
               values (?, ?, ?, ?, ?, ?, '', 7, 0, '')'''
    update_sql = 'update images set history_end=? where id=?;'

    with conn:
        conn.execute(delete_sql, (image_id,))
        conn.executemany(insert_sql, rows)
        conn.execute(update_sql, (len(new_history), image_id))

def main():
    """Command-line entry point: apply one .dt file's history to the library.

    Finds the library image the .dt file belongs to and replaces that
    image's history with the steps read from the file. It then moves any
    existing .xmp sidecar aside and renames the .dt file to ``*.dt.applied``.

    Returns:
        Process exit status: 0 on success, 1 if the image is not found in
        the library database.
    """
    parser = argparse.ArgumentParser(description="Read a .dt file into the library database.")
    parser.add_argument('library_db', metavar='LIBRARY_FILE', type=str,
                        help="Path to Darktable's sqlite3 library db")
    parser.add_argument('dt_file', metavar='DT_FILE', type=str,
                        help="DT file")

    args = parser.parse_args()

    conn = sqlite3.connect(args.library_db)
    try:
        image_id = get_image_id(conn, args.dt_file)
        # get_image_id signals "no match" with None; compare explicitly so
        # that a falsy id is not mistaken for a missing image.
        if image_id is None:
            print("Image for file {} not found in the database".format(args.dt_file),
                  file=sys.stderr)
            return 1
        history = read_dt_file(args.dt_file)
        replace_image_history(conn, image_id, history)

        # Touch files on disk only after the database update succeeded. If it
        # raises, the .dt file stays in place and the import can be retried.
        rename_xmp_for_file(args.dt_file)
        rename_dt_file(args.dt_file)

        print("Applied {} steps from file {} (image id {})".format(len(history),
                                                                   args.dt_file,
                                                                   image_id))
    finally:
        conn.close()

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
