#include "cs.h"
/* L = chol (A, [pinv parent cp]), pinv is optional */
csn *cs_chol (const cs *A, const css *S)
{
    double d, lki, *Lx, *x, *Cx ;
    int top, i, p, k, n, *Li, *Lp, *cp, *pinv, *s, *c, *parent, *Cp, *Ci ;
    cs *L, *C, *E ;
    csn *N ;
    if (!CS_CSC (A) || !S || !S->cp || !S->parent) return (NULL) ;
    n = A->n ;
    N = cs_calloc (1, sizeof (csn)) ;	    /* allocate result */
    c = cs_malloc (2*n, sizeof (int)) ;	    /* get int workspace */
    x = cs_malloc (n, sizeof (double)) ;    /* get double workspace */
    cp = S->cp ; pinv = S->pinv ; parent = S->parent ;
    C = pinv ? cs_symperm (A, pinv, 1) : ((cs *) A) ;
    E = pinv ? C : NULL ;	    /* E is alias for A, or a copy E=A(p,p) */
    if (!N || !c || !x || !C) return (cs_ndone (N, E, c, x, 0)) ;
    s = c + n ;
    Cp = C->p ; Ci = C->i ; Cx = C->x ;
    N->L = L = cs_spalloc (n, n, cp [n], 1, 0) ;    /* allocate result */
    if (!L) return (cs_ndone (N, E, c, x, 0)) ;
    Lp = L->p ; Li = L->i ; Lx = L->x ;
    for (k = 0 ; k < n ; k++) Lp [k] = c [k] = cp [k] ;
    for (k = 0 ; k < n ; k++)	    /* compute L(:,k) for L*L' = C */
    {
	/* --- Nonzero pattern of L(k,:) ------------------------------------ */
	top = cs_ereach (C, k, parent, s, c) ;	    /* find pattern of L(k,:) */
	x [k] = 0 ;				    /* x (0:k) is now zero */
	for (p = Cp [k] ; p < Cp [k+1] ; p++)	    /* x = full(triu(C(:,k))) */
	{
	    if (Ci [p] <= k) x [Ci [p]] = Cx [p] ;
	}
	d = x [k] ;			/* d = C(k,k) */
	x [k] = 0 ;			/* clear x for k+1st iteration */
	/* --- Triangular solve --------------------------------------------- */
	for ( ; top < n ; top++)    /* solve L(0:k-1,0:k-1) * x = C(:,k) */
	{
	    i = s [top] ;		/* s [top..n-1] is pattern of L(k,:) */
	    lki = x [i] / Lx [Lp [i]] ; /* L(k,i) = x (i) / L(i,i) */
	    x [i] = 0 ;			/* clear x for k+1st iteration */
	    for (p = Lp [i] + 1 ; p < c [i] ; p++)
	    {
		x [Li [p]] -= Lx [p] * lki ;
	    }
	    d -= lki * lki ;		/* d = d - L(k,i)*L(k,i) */
	    p = c [i]++ ;
	    Li [p] = k ;		/* store L(k,i) in column i */
	    Lx [p] = lki ;
	}
	/* --- Compute L(k,k) ----------------------------------------------- */
	if (d <= 0) return (cs_ndone (N, E, c, x, 0)) ; /* not pos def */
	p = c [k]++ ;
	Li [p] = k ;		    /* store L(k,k) = sqrt (d) in column k */
	Lx [p] = sqrt (d) ;
    }
    Lp [n] = cp [n] ;		    /* finalize L */
    return (cs_ndone (N, E, c, x, 1)) ; /* success: free E,s,x; return N */
}