'''This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.'''
    
import os, sys
import processing as st
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (QgsField, QgsFeature, QgsPointXY, QgsProcessing,QgsWkbTypes, QgsGeometry, QgsProcessingAlgorithm, QgsProcessingParameterFeatureSource, QgsProcessingParameterFeatureSink,QgsFeatureSink,QgsFeatureRequest,QgsFields)

class Branches_Nodes(QgsProcessingAlgorithm):
    """Processing algorithm that derives branches and nodes from a line network.

    The network is split at every self-intersection with
    ``native:splitwithlines``.  Each resulting line is a branch and its two
    end vertices are nodes.  A node is labelled from the number of branch
    ends that meet at it (1 -> 'I', 3 -> 'Y', 4 -> 'X', anything else ->
    'Error'), 'U' when it lies on the edge of an interpretation boundary and
    'E' when it lies on the edge of a sample area.
    """

    # Parameter / output identifiers.
    Network='Network'
    Sample_Area='Sample_Area'
    IB = 'Interpretation Boundary'
    Branches='Branches'
    Nodes='Nodes'

    # Decimal places used when a coordinate pair is used as a lookup key.
    _PRECISION = 10
    # Node label by number of branch ends meeting at the node.
    _NODE_CLASSES = {1: 'I', 3: 'Y', 4: 'X'}
    # Labels that reduce a branch weight by 0.5 per affected end.
    _CENSORED = ('E', 'U')

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the algorithm identifier."""
        # NOTE(review): contains spaces and capitals; left unchanged because
        # renaming would change the algorithm id that callers may rely on.
        return "Branches and Nodes"

    def tr(self, text):
        """Translate ``text`` in the "Branches_Nodes" context."""
        return QCoreApplication.translate("Branches_Nodes", text)

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr("Branches and Nodes")

    def group(self):
        """Return the translated name of the group the algorithm belongs to."""
        return self.tr("NetworkGT")

    def shortHelpString(self):
        """Return the translated short help text."""
        return self.tr("Create branches and nodes of a fracture network")

    def groupId(self):
        """Return the group identifier."""
        # NOTE(review): differs from group() ("NetworkGT") - confirm intended.
        return "Topology"

    def helpUrl(self):
        """Return the URL of the project page."""
        return "https://github.com/BjornNyberg/NetworkGT"

    def createInstance(self):
        """Return a new instance of this algorithm class."""
        return type(self)()

    def initAlgorithm(self, config=None):
        """Declare the three input sources and the two output sinks."""
        self.addParameter(QgsProcessingParameterFeatureSource(
            self.Network,
            self.tr("Network"),
            [QgsProcessing.TypeVectorLine]))
        self.addParameter(QgsProcessingParameterFeatureSource(
            self.Sample_Area,
            self.tr("Sample_Area"),
            [QgsProcessing.TypeVectorPolygon]))
        self.addParameter(QgsProcessingParameterFeatureSource(
            self.IB,
            self.tr("Interpretation Boundary"),
            [QgsProcessing.TypeVectorPolygon]))
        self.addParameter(QgsProcessingParameterFeatureSink(
            self.Branches,
            self.tr("Branches"),
            QgsProcessing.TypeVectorLine))
        self.addParameter(QgsProcessingParameterFeatureSink(
            self.Nodes,
            self.tr("Nodes"),
            QgsProcessing.TypeVectorPoint))

    def _rounded(self, node):
        """Return the (x, y) pair rounded so it can be used as a lookup key."""
        x, y = node
        return (round(x, self._PRECISION), round(y, self._PRECISION))

    @staticmethod
    def _end_nodes(line):
        """Return the first and last vertex of a single-part line as (x, y) tuples.

        Raises whatever ``asPolyline()`` raises when ``line`` is not a
        single-part line, and IndexError when it has no vertices.
        """
        vertices = line.asPolyline()
        first, last = vertices[0], vertices[-1]
        return [(first.x(), first.y()), (last.x(), last.y())]

    @staticmethod
    def _line_parts(geometry):
        """Return the single-part pieces of ``geometry`` that have non-zero length."""
        if QgsWkbTypes.isSingleType(geometry.wkbType()):
            pieces = [geometry]
        else:
            # asGeometryCollection() yields one single-part QgsGeometry per
            # member, so each piece can be handled like a plain line below.
            pieces = geometry.asGeometryCollection()
        return [piece for piece in pieces if piece.length() != 0.0]

    def _interior_label(self, node, graph, unknown_nodes):
        """Label a (rounded) node that is not on a sample-area edge."""
        if node in unknown_nodes:
            return 'U'
        return self._NODE_CLASSES.get(graph.get(node), 'Error')

    def _build_graph(self, split_layer, boundary_source, feedback):
        """Count branch ends per node and collect nodes on a boundary edge.

        Returns ``(graph, unknown_nodes)``: ``graph`` maps a rounded (x, y)
        node to the number of branch ends found there, ``unknown_nodes`` is
        the set of rounded (x, y) nodes where a branch clipped to an
        interpretation boundary ends outside the slightly shrunken boundary.
        """
        graph = {}
        unknown_nodes = set()
        # Pair every boundary with its inward buffer once, instead of
        # re-buffering the same polygon for every tested point.
        masks = [(mask, mask.buffer(-0.0001, 5))
                 for mask in (f.geometry() for f in
                              boundary_source.getFeatures(QgsFeatureRequest()))]

        for feature in split_layer.getFeatures(QgsFeatureRequest()):
            if feedback.isCanceled():
                break
            try:
                line = feature.geometry()
                ends = [self._rounded(node) for node in self._end_nodes(line)]
                for node in ends:  # node count
                    graph[node] = graph.get(node, 0) + 1
                for mask, interior in masks:
                    if line.within(mask):
                        continue
                    for piece in self._line_parts(line.intersection(mask)):
                        for node in self._end_nodes(piece):
                            point = QgsGeometry.fromPointXY(QgsPointXY(*node))
                            # Test if point is on edge of interpretation boundary
                            if not point.within(interior):
                                # Stored rounded: every later lookup uses
                                # rounded coordinates.
                                unknown_nodes.add(self._rounded(node))
            except Exception as e:
                # Best effort: report the problem and continue with the next line.
                feedback.pushInfo(QCoreApplication.translate('Interpretation Boundary','test%s'%(e)))
        return graph, unknown_nodes

    def processAlgorithm(self, parameters, context, feedback):
        """Split the network, classify branches/nodes and write both outputs.

        Returns a dict mapping the ``Branches`` and ``Nodes`` output names to
        the ids of the sinks that were filled.
        """
        network = self.parameterAsSource(parameters, self.Network, context)
        sample_source = self.parameterAsSource(parameters, self.Sample_Area, context)
        # Read the interpretation boundary from its own parameter (the sample
        # area used to be read a second time here).
        boundary_source = self.parameterAsSource(parameters, self.IB, context)

        branch_fields = QgsFields()
        branch_fields.append(QgsField("Class", QVariant.String))
        branch_fields.append(QgsField("Connection", QVariant.String))
        branch_fields.append(QgsField("B_Weight", QVariant.Double))
        branch_fields.append(QgsField("Sample_No_", QVariant.Int))
        (branch_sink, branch_dest) = self.parameterAsSink(
            parameters, self.Branches, context,
            branch_fields, QgsWkbTypes.LineString, network.sourceCrs())

        node_fields = QgsFields()
        node_fields.append(QgsField("Class", QVariant.String))
        node_fields.append(QgsField("Sample_No_", QVariant.Int))
        (node_sink, node_dest) = self.parameterAsSink(
            parameters, self.Nodes, context,
            node_fields, QgsWkbTypes.Point, network.sourceCrs())

        # Split the network with itself so every line ends at an intersection.
        split_params = {'INPUT': parameters[self.Network],
                        'LINES': parameters[self.Network],
                        'OUTPUT': 'memory:'}
        split_layer = st.run('native:splitwithlines', split_params)['OUTPUT']

        graph, unknown_nodes = self._build_graph(split_layer, boundary_source, feedback)

        # (geometry, inward buffer, feature id) for every sample area.
        areas = [(f.geometry(), f.geometry().buffer(-0.001, 2), f.id())
                 for f in sample_source.getFeatures(QgsFeatureRequest())]

        node_rows = []    # [label, sample id, QgsPointXY] per emitted node
        seen_nodes = {}   # sample id -> set of rounded nodes already emitted

        for feature in split_layer.getFeatures(QgsFeatureRequest()):
            if feedback.isCanceled():
                break
            try:
                line = feature.geometry()
                branch = [self._rounded(node) for node in self._end_nodes(line)]
                labels = sorted(self._interior_label(node, graph, unknown_nodes)
                                for node in branch)
                branch_class = " - ".join(labels)  # Organize the order of names
                # 'X' and 'Y' both count as connected ('C') in the connection summary.
                connection = " - ".join(sorted(
                    'C' if label in ('X', 'Y') else label for label in labels))

                for area, area_interior, sample_no in areas:
                    seen = seen_nodes.setdefault(sample_no, set())

                    if line.within(area):
                        weight = 1
                        for node in branch:
                            point = QgsGeometry.fromPointXY(QgsPointXY(*node))
                            # Test if point is on edge of sample area
                            if not point.within(area_interior):
                                label = 'E'
                            else:
                                label = self._interior_label(node, graph, unknown_nodes)
                            if label in self._CENSORED:
                                weight -= 0.5
                            if node not in seen:
                                seen.add(node)
                                node_rows.append([label, sample_no, QgsPointXY(*node)])
                        fet = QgsFeature(branch_fields)
                        fet.setGeometry(line)
                        fet.setAttributes([branch_class, connection, weight, sample_no])
                        branch_sink.addFeature(fet, QgsFeatureSink.FastInsert)

                    elif line.intersects(area):
                        # The branch crosses the sample-area edge: emit each
                        # clipped piece as its own branch.
                        for piece in self._line_parts(line.intersection(area)):
                            weight = 1
                            for x, y in self._end_nodes(piece):
                                node = self._rounded((x, y))
                                point = QgsGeometry.fromPointXY(QgsPointXY(x, y))
                                if point.within(area):
                                    label = self._interior_label(node, graph, unknown_nodes)
                                else:
                                    label = 'E'
                                if node not in seen:
                                    seen.add(node)
                                    node_rows.append([label, sample_no, QgsPointXY(x, y)])
                                if label in self._CENSORED:
                                    weight -= 0.5
                            fet = QgsFeature(branch_fields)
                            fet.setGeometry(piece)
                            fet.setAttributes([branch_class, connection, weight, sample_no])
                            branch_sink.addFeature(fet, QgsFeatureSink.FastInsert)
            except Exception as e:
                # Best effort: report the problem and continue with the next line.
                feedback.pushInfo(QCoreApplication.translate('Sample_Area','%s'%(e)))

        for label, sample_no, point in node_rows:
            # Node features use the node schema, matching the node sink.
            fet = QgsFeature(node_fields)
            fet.setGeometry(QgsGeometry.fromPointXY(point))
            fet.setAttributes([label, sample_no])
            node_sink.addFeature(fet, QgsFeatureSink.FastInsert)

        return {self.Branches: branch_dest, self.Nodes: node_dest}