//This header wraps the GSL vector and matrix structures in C++ classes
//and provides a C++ interface to some pertinent BLAS and general
//linear-algebra routines.

#ifndef __LINALG_HPP__
#define __LINALG_HPP__

#include "error.hpp"

#include <iosfwd>   //std::ostream / std::istream used by the I/O declarations.

#include <gsl/gsl_matrix.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_permutation.h>


namespace linalg{
  using namespace error_handling;

  //Forward declarations, so the classes below can name each other.
  class LUmatrix;    //LU-factorised form of a matrix.
  class slice;       //Strided index range, like a GNU Octave range
                     //(e.g. 2:2:8 = [2, 4, 6, 8]).
  class vector;
  class vector_view; //A vector presumably referring to part of a vector
                     //or matrix; built only by vector and matrix.

  //Dense matrix of doubles wrapping a gsl_matrix (member A).
  //Elements are addressed Fortran/Octave style: (row, column), 1-based.
  //The factorisation members are mutable so that const methods can
  //update them (presumably det()/inv()/cond() fill them on demand).
  class matrix{
    friend class vector_view;

  public:
    // ---- Lifetime -----------------------------------------------------
    matrix();                                          //1x1 zero matrix.
    matrix(size_t m, size_t n, double fillvalue = 0);  //Presumably m-by-n,
                                                       //every entry = fillvalue.
    //NOTE(review): not explicit, and whether M is adopted or copied is
    //decided by the out-of-line definition -- confirm before relying on it.
    matrix(gsl_matrix* M);
    matrix(const matrix& M);
    //NOTE(review): returns a copy of the result rather than matrix&;
    //changing that requires updating the out-of-line definition in step.
    matrix operator=(const matrix& M);
    ~matrix();

    // ---- Shape and output ---------------------------------------------
    size_t rows() const;
    size_t cols() const;

    //Number of decimal digits to output. set_precision() is static, so
    //the setting is shared by every matrix.
    size_t precision() const;
    static void set_precision(size_t p);

    // ---- Element access (indices start from 1) ------------------------
    double& operator()(size_t i, size_t j);
    const double operator()(size_t i, size_t j) const;

    // ---- Slice access, "for vectorisation" ----------------------------
    //Presumably row i over the columns in b, and the rows in a of column j.
    vector_view operator()(size_t i, const slice& b);
    const vector_view operator()(size_t i, const slice& b) const;
    vector_view operator()(const slice& a, size_t j);
    const vector_view operator()(const slice& a, size_t j) const;

    // ---- Arithmetic ---------------------------------------------------
    matrix operator+(const matrix& N) const;
    matrix operator-(const matrix& N) const;
    matrix operator*(const matrix& N) const;
    matrix operator*(double a) const;          //Scale the matrix by a.
    vector operator*(const vector& v) const;   //M v, v taken as a column vector.

    // ---- Linear algebra -----------------------------------------------
    //LU decomposition, pivots in U.
    //NOTE(review): who owns the returned pointer is not visible here
    //(there is a cached LUptr member) -- confirm before deleting it.
    LUmatrix* LU() const;
    matrix inv() const;                  //Inverse.
    vector inv(const vector& w) const;   //Solves M v = w for v by LU factorisation.
    matrix T() const;                    //Transpose.
    double tr() const;                   //Trace.
    double det() const;                  //Determinant.
    double cond() const;                 //L2 condition number, via SVD.

  private:
    void SVDfactor() const;

    gsl_matrix* A;
    static size_t precsn;                //Output precision shared by all matrices.

    //Presumably cached factorisation state; mutable so const members can set it.
    mutable LUmatrix* LUptr;
    mutable bool LUfactored;

    mutable double condition_number;     //L2 condition number, obtained by SVD.
    mutable gsl_vector* SVD;             //The matrix's singular values.
    mutable bool SVDfactored;
  };

  //A matrix held in LU-factorised form: L has ones on its diagonal (not
  //stored) and U carries the pivots on its diagonal. p and signum are a
  //permutation and an int sign, presumably as produced by GSL's LU routines.
  class LUmatrix{
  public:
    // ---- Lifetime -----------------------------------------------------
    LUmatrix(gsl_matrix* M);
    LUmatrix(const LUmatrix& LU);
    //NOTE(review): returns a copy rather than LUmatrix&.
    LUmatrix operator=(const LUmatrix& LU);
    ~LUmatrix();

    // ---- Raw access (presumably for handing to GSL routines) ----------
    gsl_matrix* matrix_ptr();
    gsl_permutation* perm_ptr();
    int sgn();
    int* sgn_ptr();

  private:
    //Private: outside code cannot default-construct an LUmatrix.
    LUmatrix();

    gsl_permutation* p;
    gsl_matrix* A;
    int signum;
  };

  //Vector of doubles wrapping a gsl_vector (member x), indexed from 1
  //Fortran/Octave style. Also the public base class of vector_view.
  class vector{
    friend class vector_view;
    friend class matrix;

  public:
    // ---- Lifetime -----------------------------------------------------
    vector();                                //Zero vector of size one.
    vector(size_t n, double fillvalue = 0);
    vector(gsl_vector* y);
    vector(const gsl_vector* y);
    //NOTE(review): none of the three constructors above is explicit, so
    //integers and gsl_vector pointers convert silently -- e.g. `v + 3`
    //compiles by building vector(3) and adding it.
    vector(const vector& y);
    vector& operator=(const vector& y);
    //NOTE(review): not virtual although vector_view derives from vector;
    //never delete a vector_view through a vector*.
    ~vector();

    // ---- Size and output ----------------------------------------------
    size_t size() const;
    //Number of decimal digits to output; set_precision() is static, so the
    //setting is shared by every vector.
    size_t precision() const;
    static void set_precision(size_t p);

    // ---- Element access (indices start at 1) --------------------------
    double& operator()(size_t i);
    const double operator()(size_t i) const;

    // ---- Slice access, "for vectorisation" ----------------------------
    vector_view operator()(const slice& v);
    const vector_view operator()(const slice& v) const;

    // ---- Arithmetic ---------------------------------------------------
    vector operator+(const vector& w) const;
    vector operator-(const vector& w) const;
    vector operator*(double a) const;          //Scale the vector by a.
    double operator*(const vector& w) const;   //Dot product.
    vector operator*(const matrix& M) const;   //v M, v taken as a row vector.
    double norm() const;                       //Euclidean norm.

    // ---- Comparison ---------------------------------------------------
    bool operator==(const vector& w) const;
    //Lexicographical order in which vectors of smaller dimension come
    //first; lets vectors be used in STL sets and as map keys.
    bool operator<(const vector& w) const;

  private:
    gsl_vector* x;
    static size_t precsn;
  };



  //A vector built from a slice of a vector or of a matrix row/column;
  //presumably it refers to the parent's storage rather than copying it.
  //All constructors are private, so only vector and matrix create views.
  class vector_view : public vector{
    friend class vector;
    friend class matrix;

    vector_view();
    vector_view(gsl_vector* y, const slice& s);
    //Presumably row i over the columns in b / rows in a of column j,
    //mirroring matrix::operator().
    vector_view(gsl_matrix* A, size_t i, const slice& b);
    vector_view(gsl_matrix* A, const slice& a, size_t j);

  public:
    vector_view& operator=(const vector_view& w);
    vector_view& operator=(const vector& w);
    //NOTE(review): vector's destructor is not virtual, so a vector_view must
    //not be deleted through a vector*. No copy constructor is declared
    //either; copies use the implicit one, which calls vector(const vector&)
    //-- confirm that pairs correctly with this destructor.
    ~vector_view();
  };


  //Strided index range, the analogue of a GNU Octave range; indices start
  //from 1. Arguments are (begin, end, stride), so Octave's 2:2:8 is
  //presumably slice(2, 8, 2) -- the stride comes last here.
  class slice{
  public:
    slice();
    slice(size_t a, size_t b, size_t k = 1);

    //Set the slice parameters anew.
    slice set(size_t a, size_t b, size_t k = 1);
    slice operator()(size_t a, size_t b, size_t k = 1);

    size_t begin() const  { return beg; }
    size_t end() const    { return fin; }
    size_t stride() const { return str; }

  private:
    size_t beg;
    size_t fin;
    size_t str;
  };
  
  //point is simply another name for vector, for readability where a
  //vector is used as a point.
  typedef vector point; //Useful alias.
}

namespace linalg{ //Non-member functions.
  
  // I/O
  std::ostream& operator<<(std::ostream& os, const vector &v);
  vector operator>>(std::istream& is, vector& v);
  std::ostream& operator<<(std::ostream& os, const matrix &M);
  matrix operator>>(std::istream& is, matrix& v);

  //Some arithmetic functions for comfortable syntax.
  vector operator*(double a, const vector& v);
  double norm(const vector& v);
  matrix operator*(double a, const matrix& M);
  matrix inv(const matrix& A);
  matrix T(const matrix& A);          
  double tr(const matrix& A);         
  double det(matrix& A);        
  double cond(matrix& A);

}

#endif //__LINALG_HPP__
