#!/apps/prod/python/anaconda/anaconda3/bin/python


import sys, os
import logging, argparse
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
   
# Module-level run settings.  Nothing in main() below reads them; the key
# names suggest ligand-parametrisation settings (presumably the AM1-BCC
# charge model and the GAFF force field) -- TODO confirm with their users.
options = {
    "qmodel": "bcc",
    "lig_ff": "gaff",
    "resp_dir": "",
    "prepin_dir": ".",
    "debug": 0,
    "resname": "MOL",
    "verbose": 3,
    # sqm: if =1 use reduced sqm convergence criteria
    "sqm": 0,
}
def ring_atoms_only(mol):
    """Return a copy of *mol* that keeps only the atoms lying in a ring.

    Non-ring atoms are removed through an ``EditableMol``.  Indices are
    deleted from highest to lowest, so a removal never shifts the index of
    an atom that is still waiting to be removed.
    """
    non_ring = sorted(
        (atom.GetIdx() for atom in mol.GetAtoms() if not atom.IsInRing()),
        reverse=True,
    )
    editable = Chem.EditableMol(mol)
    for idx in non_ring:
        editable.RemoveAtom(idx)
    return editable.GetMol()


def nearest_atom_pairs(dmat, natom1):
    """Pair every atom of the first molecule with its closest atom in the second.

    Parameters
    ----------
    dmat : array_like
        Square distance matrix of the two molecules combined.  The first
        molecule's ``natom1`` atoms occupy rows/columns ``0..natom1-1`` and
        the second molecule's atoms follow them, which is the ordering
        produced by ``Chem.CombineMols(mol1, mol2)``.
    natom1 : int
        Number of atoms in the first molecule.

    Returns
    -------
    list of (int, int)
        ``(index_in_mol1, index_in_mol2)`` pairs as built-in Python ints,
        usable as ``atomMap`` for ``AllChem.GetAlignmentTransform``.  The
        list is empty when either molecule has no atoms.
    """
    # Search only the mol1 x mol2 block.  Scanning a full row would let an
    # atom's bonded neighbour in its own molecule be picked as "nearest",
    # which yields a negative mol2 index after subtracting natom1.
    cross = np.asarray(dmat)[:natom1, natom1:]
    if cross.size == 0:
        return []
    # argmin returns numpy integer scalars.  RDKit's Python wrappers may
    # reject those inside atomMap, so convert them to built-in int.
    return [(i, int(j)) for i, j in enumerate(cross.argmin(axis=1))]


def main(fn1='BHo.mol2', fn2='PHo1.mol2'):
    """Align the ring atoms of two mol2 structures using three atom maps.

    Both molecules are reduced to their ring atoms.  An atom map is derived
    from nearest-neighbour distances and compared with two hand-written maps
    for a six-membered ring (atom i of mol1 -> atom (i + 1) % 6 of mol2).
    Each map is passed to ``AllChem.GetAlignmentTransform``, with mol1 as
    the probe and mol2 as the reference.  The RMSD and transform are printed.

    Parameters
    ----------
    fn1, fn2 : str
        mol2 files for the probe and reference molecules.

    Returns
    -------
    int
        Exit status: 0 on success, or 1 if a file could not be parsed or
        there are no ring atoms to match.
    """
    mol1 = Chem.MolFromMol2File(fn1)
    mol2 = Chem.MolFromMol2File(fn2)
    # MolFromMol2File signals a read/parse failure by returning None.
    for fn, mol in ((fn1, mol1), (fn2, mol2)):
        if mol is None:
            print('could not read molecule from %s' % fn, file=sys.stderr)
            return 1

    # keep only the ring atoms; these are the atoms to be matched
    mol1 = ring_atoms_only(mol1)
    mol2 = ring_atoms_only(mol2)
    natom1 = mol1.GetNumAtoms()

    # find the matching atoms from the distance matrix of the two molecules
    dmat = Chem.Get3DDistanceMatrix(Chem.CombineMols(mol1, mol2))
    z3 = nearest_atom_pairs(dmat, natom1)
    if not z3:
        print('no ring atoms to match', file=sys.stderr)
        return 1

    def show(label, p1, p2):
        print('\n' + label)
        print(p1)
        print(p2)
        # atomMap cares about the element types, not the container type
        print(sorted({type(v).__name__ for v in p1 + p2}))

    show('from distance matrix (pXy)',
         tuple(i for i, _ in z3), tuple(j for _, j in z3))

    # hand-written cyclic shift for a six-membered ring, truncated to mol1
    ring_size = 6
    p1x = tuple(range(ring_size))[:natom1]
    p2x = tuple((i + 1) % ring_size for i in range(ring_size))[:natom1]
    show('from explicit list (pXx)', p1x, p2x)

    p1z = (0, 1, 2, 3, 4, 5)
    p2z = (1, 2, 3, 4, 5, 0)
    show('from explicit tuple (pXz)', p1z, p2z)

    atom_maps = (
        ('explicit tuple', list(zip(p1z, p2z))),
        ('explicit list', list(zip(p1x, p2x))),
        ('distance matrix', z3),
    )
    for label, atom_map in atom_maps:
        print('\nzip from %s:' % label)
        print(atom_map)
    for label, atom_map in atom_maps:
        print('\n\nGetAlignmentTransform from %s:\n' % label)
        rmsd, tmat = AllChem.GetAlignmentTransform(mol1, mol2,
                                                   atomMap=atom_map)
        print(rmsd, tmat)
    return 0
    
if __name__ == "__main__":
    # Run the script and hand main()'s return value to the interpreter
    # as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
    

