from sympy import floor, sqrt, Integer, sign, divisors, S

def descent(A, B, n=1):
    r"""
    Use Lagrange's method to find a solution to $w^2 = Ax^2 + By^2$, where
    $A \neq 0$ and $B \neq 0$.

    Returns a tuple ``(x, y, w)`` satisfying the above equation, or
    ``(None, None, None)`` when the method cannot find one (for example when
    ``A`` has no square root modulo ``B``, or for ``A = B = -1``, where only
    the trivial solution exists). This solution is used as a base from which
    all the other solutions are calculated.

    ``n`` is the recursion depth. It is incremented on every recursive call
    but does not otherwise affect the result.
    """
    if abs(A) > abs(B):
        # Solve with the roles of A and B exchanged, then swap x and y back.
        y, x, w = descent(B, A, n + 1)
        return (x, y, w)

    # From here on |A| <= |B|.
    if A == 1:
        return (S.One, 0, S.One)

    if B == 1:
        return (0, S.One, S.One)

    if A == -1 and B == -1:
        # -x^2 - y^2 = w^2 has only the trivial solution. Without this guard
        # the reduction below maps (-1, -1) back to (-1, -1) and never ends.
        return None, None, None

    # Smallest r >= 0 with r^2 == A (mod B); without one the descent stops.
    r = quadratic_congruence(A, B, 0)
    if r is None:
        return None, None, None

    # Exact division: B divides r^2 - A by the choice of r.
    Q = (r**2 - A) // B

    B_0 = None
    if Q == 0:
        # r^2 == A, i.e. A is a perfect square (e.g. A = 4, B = -7).
        B_0 = 1
        d = 0
    else:
        # Write Q = B_0 * d^2 with the smallest |B_0| for which |Q| / |B_0|
        # is a perfect square. i = |Q| always matches (leaving 1), so B_0 is
        # found.
        for i in divisors(Q):
            root = sqrt(abs(Q) // i)
            if isinstance(root, Integer):
                B_0, d = sign(Q)*i, root
                break

    if B_0 is None:
        return None, None, None

    # Solve the smaller equation W^2 = A X^2 + B_0 Y^2 recursively.
    X, Y, W = descent(A, B_0, n + 1)
    if X is None:
        return None, None, None

    # Because r^2 - A = B*Q and Q = B_0*d^2, the triple below satisfies
    # w^2 = A x^2 + B y^2 whenever (X, Y, W) solves the smaller equation.
    return ((r*X - W), Y*(B_0*d), (-A*X + r*W))


def quadratic_congruence(a, m, start):
    r"""
    Solve the quadratic congruence $x^2 \equiv a \pmod{m}$ by brute force.

    The sign of ``m`` is ignored. Returns the smallest ``x`` with
    ``start <= x <= |m| // 2`` (``|m|`` even) or ``start <= x <= |m| // 2 + 1``
    (``|m|`` odd) such that ``x**2 - a`` is divisible by ``m``. Returns None
    if there is no such ``x``.

    With ``start = 0`` every solution modulo ``|m|`` has a representative in
    this range, because ``x`` and ``|m| - x`` have the same square. A None
    result therefore means the congruence is unsolvable.

    ``m`` must be non-zero (``m == 0`` raises ZeroDivisionError). Good enough
    for ``m`` sufficiently small.

    TODO: An efficient algorithm should be implemented.
    """
    m = abs(m)
    # Exclusive upper bound of the search. The odd case scans one value past
    # (m - 1) / 2; that extra value is redundant but harmless.
    stop = m // 2 + 1 if m % 2 == 0 else m // 2 + 2

    for x in range(start, stop):
        if (x**2 - a) % m == 0:
            return x

    return None


def test():
    """Check that ``descent`` returns valid, non-trivial solutions."""
    cases = [(5, 4), (13, 23), (3, -11), (41, -113), (4, -7)]
    for a, b in cases:
        x, y, w = descent(a, b)
        # Fail with a readable message instead of a TypeError on None.
        assert x is not None, "no solution found for (%s, %s)" % (a, b)
        # (0, 0, 0) satisfies every such equation, so it proves nothing.
        assert (x, y, w) != (0, 0, 0), "trivial solution for (%s, %s)" % (a, b)
        assert a*x**2 + b*y**2 == w**2, (a, b, x, y, w)

    # A larger case: check that a solution is found and that it is valid.
    x, y, w = descent(234, -65601)
    assert x is not None, "no solution found for (234, -65601)"
    assert 234*x**2 - 65601*y**2 == w**2, (x, y, w)

# Run the self-test only when executed as a script, not on import.
if __name__ == "__main__":
    test()
