/*
*********************************************************************************
Subject: Solution of an m x n linear Diophantine system
         A*x = b using Hermite and Smith normal forms.
Author : Sjoerd.J.Schaper
Date   : 05-03-2007
Code   : ANSI C, C89
Links to QBasic and big integer-BASIC versions, plus a sample input file.
*********************************************************************************
This program is copyright (c) 2007 by the author. It is made available as is,
and no warranty - about the program, its performance, or its conformity to any
specification - is given or implied. It may be used, modified, and distributed
freely, as long as the original author is credited as part of the final work.
********************************************************************************/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>

/* NoZ(x): nonzero x?  True when |x| >= 1, i.e. when x would truncate to a
   nonzero integer.  The test is done in floating point because converting
   a double to long is undefined once the value does not fit in a long
   (entries may grow up to Mxd = 2^52, which exceeds a 32-bit long). */
#define NoZ(x) (fabs((double) (x)) >= 1.0)

const int Verb = 1;                       /* print transition matrices V, U */
const double Mxd = 4503599627370496.;     /* full mantissa 2^52 */

/* errsw: overflow switch, set by Rmult when a product exceeds Mxd;
   pt   : row pointers, pt[r] is the offset of row r in the flat matrices;
   m, n : rows & columns of A */
int errsw, *pt, m, n;

/* d : 2x2 multiplier applied by Rmult and Lmult;
   ab: copy of the input rows with b stored in column n;
   b : right-hand side, later overwritten with the solution;
   v : left transformation (starts as identity, updated by row steps);
   a : working copy of A that is reduced in place;
   u : right transformation (starts as identity, updated by column steps) */
double d[2][2], *ab, *b, *v, *a, *u;

double bezout (double *x, double *y)
{  /* Bezout's identity gcd(x,y) = ux + vy.
      Extended Euclid on the pair (*x, *y).  Returns g = gcd >= 0 and
      overwrites *x with u and *y with v, so that u*x + v*y = g holds for
      the values passed in. */
   double col0[3], col1[3];               /* {remainder, u, v} per column */
   double *rem = col0, *div = col1, *tmp;
   double quot;
   int k;

   col0[0] = *x; col0[1] = 1; col0[2] = 0;
   col1[0] = *y; col1[1] = 0; col1[2] = 1;

   while (NoZ(rem[0])) {
      tmp = rem; rem = div; div = tmp;    /* old remainder becomes divisor */
      quot = floor(rem[0] / div[0]);      /* Euclidean steps */
      for (k = 0; k < 3; k++)
         rem[k] -= quot * div[k];
   }

   if (div[0] < 0)                        /* make the gcd nonnegative */
      for (k = 0; k < 3; k++)
         div[k] = -div[k];

   *x = div[1];
   *y = div[2];
   return div[0];
}

void Rmult (double *x, int rx, int s, int t)
{  /* right multiply mtx X*mtx D.
      For each of the first rx rows, columns s and t are replaced by
         x_s' = x_s*d[0][0] + x_t*d[1][0]
         x_t' = x_t*d[1][1] - x_s*d[0][1].
      When a partial product exceeds Mxd in magnitude, errsw is set and the
      routine stops without writing that row.  It also stops at once when
      errsw is already set on entry to a row. */
   int row;

   for (row = 0; row < rx; row++)         /* columns s and t */
   {
      double *xs = x + pt[row] + s;
      double *xt = x + pt[row] + t;
      double ss = *xs * d[0][0], st = *xs * d[0][1];
      double ts = *xt * d[1][0], tt = *xt * d[1][1];

      if (fabs(ss) > Mxd || fabs(st) > Mxd ||
          fabs(ts) > Mxd || fabs(tt) > Mxd)
         errsw = 1;
      if (errsw) return;

      *xs = ts + ss;
      *xt = tt - st;
   }
}

void Colel (int r, int s, int t)
{  /* column elimination step.
      Builds the 2x2 multiplier d from A[r][t] and A[r][s] and applies it to
      columns t and s of A and of U.  In row r the gcd of the two entries
      ends up in column t and column s becomes zero. */
   const double *row = a + pt[r];
   const double pivot = row[t], entry = row[s];
   double gcd;

   d[0][0] = pivot;
   d[1][0] = entry;
   gcd = bezout(&d[0][0], &d[1][0]);      /* d[0][0], d[1][0] := u, v */

   /* cofactors: the two entries divided by their gcd */
   d[0][1] = (gcd > 1) ? entry / gcd : entry;
   d[1][1] = (gcd > 1) ? pivot / gcd : pivot;

   Rmult(a, m, t, s);
   Rmult(u, n, t, s);
}

void Lmult (double *x, int sx, int r, int t)
{  /* left multiply mtx D*mtx X.
      For each of the first sx columns, rows r and t are replaced by
         x_r' = d[0][1]*x_t + d[0][0]*x_r
         x_t' = d[1][1]*x_t - d[1][0]*x_r.
      Unlike Rmult, no overflow test against Mxd is made here. */
   double *upper = x + pt[r], *lower = x + pt[t];
   int col;

   for (col = 0; col < sx; col++)         /* rows r and t */
   {
      const double xr = upper[col], xt = lower[col];

      lower[col] = d[1][1] * xt - d[1][0] * xr;
      upper[col] = d[0][1] * xt + d[0][0] * xr;
   }
}

void Rowel (int r, int s)
{  /* row elimination step.
      Builds the 2x2 multiplier d from A[r][r] and A[s][r] and applies it to
      rows r and s of A and of V.  In column r the gcd of the two entries
      ends up in row r and row s becomes zero. */
   const double diag = a[pt[r] + r], below = a[pt[s] + r];
   double gcd;

   d[0][0] = diag;
   d[0][1] = below;
   gcd = bezout(&d[0][0], &d[0][1]);      /* d[0][0], d[0][1] := u, v */

   /* cofactors: the two entries divided by their gcd */
   d[1][0] = (gcd > 1) ? below / gcd : below;
   d[1][1] = (gcd > 1) ? diag / gcd : diag;

   Lmult(a, n, r, s);
   Lmult(v, m, r, s);
}

void Negcol (double *x, int rx, int s)
{  /* negate column x[][s] in the first rx rows */
   int row;

   for (row = 0; row < rx; row++)
   {
      double *cell = x + pt[row] + s;
      *cell = -*cell;
   }
}

void Redcol (int r, int t)
{  /* reduce mod A[r][t].
      For every column left of t, subtracts floor(A[r][col] / A[r][t]) times
      column t from that column, in A (m rows) and in U (n rows).  The
      caller must make sure A[r][t] is nonzero. */
   const double pivot = a[pt[r] + t];
   int col, row;

   for (col = 0; col < t; col++)          /* reduce left block */
   {
      const double mult = floor(a[pt[r] + col] / pivot);

      for (row = 0; row < m; row++)
         a[pt[row] + col] -= mult * a[pt[row] + t];
      for (row = 0; row < n; row++)
         u[pt[row] + col] -= mult * u[pt[row] + t];
   }
}

int dchck (int i, int k)
{  /* check divisibility A[i][i] */
   int r, s, t;
   double d = a[pt[i] + i];

   for (r = i + 1; r < m; r++)            /* scan lower right block */
      for (s = i + 1; s < k; s++)
      {
	 if (NoZ(fmod(a[pt[r] + s], d))) { /* nonzero remainder */
	    for (t = 0; t < k; t++)       /* add rows i and r */
	       a[pt[i] + t] += a[pt[r] + t];
	    for (t = 0; t < m; t++)
	       v[pt[i] + t] += v[pt[r] + t];
	    return 1;
	 }
      }
return 0;
}

int HNF (int i, int j)
{  /* Hermite normal form A, return rank(A) */
   int r, s, t = 0;

   for (r = 0; r < i; r++)
   {
      for (s = t + 1; s < j; s++)
	 if (NoZ(a[pt[r] + s]))           /* row reduction */
	    Colel(r, s, t);
      if (errsw) return 1;

      if (a[pt[r] + t] < 0) {             /* column At = -At */
         Negcol(a, m, t);
         Negcol(u, n, t);
      }

      if (NoZ(a[pt[r] + t]))              /* final reductions */
	 Redcol(r, t++);
   }
return t;
}

double SNF (int i, int j)
{  /* Smith normal form square A, return |Det(A)|.
      For each r < i the pivot A[r][r] is worked on until row r and column
      r are clear beyond the diagonal (indices < j) and dchck finds nothing
      left to fix.  The return value is the product of the diagonal entries
      A[r][r], r < i; 1 is returned as soon as errsw is seen set. */
   double det = 1;
   int r, s, again;

   for (r = 0; r < i; r++)
   {
      for (;;)
      {
         for (s = r + 1; s < j; s++)
            if (NoZ(a[pt[r] + s]))        /* row reduction */
               Colel(r, s, r);
         if (errsw) return 1;

         again = 0;
         for (s = r + 1; s < j; s++)
            if (NoZ(a[pt[s] + r])) {      /* column reduction */
               Rowel(r, s);
               again = 1;
            }

         /* dchck is consulted only when no row step was needed */
         if (!again && !dchck(r, i)) break;
      }
      det *= a[pt[r] + r];
   }
   return det;
}

void MultC (double *x, int rx, int sx, double *y)
{  /* let y:= mtx X*col y.
      X is read as rx rows by sx columns through the row pointers pt; y
      supplies sx inputs and receives rx results, so it must have room for
      the larger of the two.  The product is built in a scratch column
      because y is both input and output.  Terminates the program with a
      message on stderr when the scratch column cannot be allocated. */
   int r, s;
   double *c;

   if (rx < 1) return;                    /* no rows: y stays untouched */

   c = calloc((size_t) rx, sizeof *c);
   if (c == NULL) {
      fprintf(stderr, "MultC: out of memory\n");
      exit(EXIT_FAILURE);
   }

   for (r = 0; r < rx; r++)
      for (s = 0; s < sx; s++)            /* multiply */
         c[r] += x[pt[r] + s] * y[s];

   for (r = 0; r < rx; r++) y[r] = c[r];  /* copy back */

   free(c);
}

void Swap (double *r, double *s)
{  /* exchange the doubles r and s point to */
   const double first = *r, second = *s;

   *r = second;
   *s = first;
}

void Redsol (double *x, int k)
{  /* find short solution x */
   int t, r, s, i, j;
   double p, z, xs = 0;

   for (s = 0; s < n - k; s++)
      for (r = 0; r < n; r++)
      {
         int rs = pt[r] + s;
	 Swap(&u[rs], &u[rs + k]);        /* left-align kernel */

	 if (s==0) xs += x[r] * x[r];     /* sizeof(solution x) */
      }

   do {
      for (i = -1, s = 0; s < n - k; s++)
         for (r = 0; r < n; r++)
         {
	    int rs = pt[r] + s;
            if (NoZ(u[rs])) {
	       double q = floor(x[r] / u[rs] + .5);

	       if (NoZ(q)) {              /* trial reductions */
	          for (z = 0, t = 0; t < n; t++)
                  {
                     p = x[t] - q * u[pt[t] + s];
		     z += p * p;
                  }
	          if (z < xs) { xs = z; i = r; j = s; }
               }
            }
	 }

      if (i > -1) {
	 double q = floor(x[i] / u[pt[i] + j] + .5);
	 for (r = 0; r < n; r++)          /* reduce solution */
	    x[r] -= q * u[pt[r] + j];
      }
   } while (i > -1);
}

int chcks (double *x, int k)
{  /* check if A*x = b.
      Adds the first n-k columns of U to x, replaces x by ab*x through
      MultC, and compares the global b[r] with column n of ab for each of
      the m rows.  Every mismatch is reported; returns 1 when any row
      failed and 0 otherwise. */
   const int free_cols = n - k;
   int row, col, failed = 0;

   for (row = 0; row < n; row++)          /* x:= kernel + x */
   {
      const double *ker = u + pt[row];

      for (col = 0; col < free_cols; col++)
         x[row] += ker[col];
   }

   MultC(ab, m, n, x);

   for (row = 0; row < m; row++)
   {
      if (!NoZ(ab[pt[row] + n] - b[row])) continue;
      printf("fail: b(%d) = %.f\n", row + 1, b[row]);
      failed = 1;
   }
   return failed;
}

void PrntC (char *g, double *x, int rx)
{  /* print column x[].
      Prints the heading g, then the rx entries one per line, right-aligned
      to the width of the largest magnitude.  Entries smaller than 1 in
      magnitude are set to 0 in x as a side effect. */
   double big = 1;
   int row, width = 2;

   for (row = 0; row < rx; row++)
   {
      const double mag = fabs(x[row]);

      if (mag > big)
         big = mag;                       /* largest entry in column */
      else if (!NoZ(x[row]))
         x[row] = 0;                      /* tidy up */
   }

   while (NoZ(big)) {                     /* one more per decimal digit */
      big = floor(big / 10);
      width++;
   }

   printf("%s", g);
   for (row = 0; row < rx; row++)
      printf("% *.f\n", width, x[row]);
}

void PrntM (char *g, double *x, int rx, int sx)
{  /* print matrix X[][].
      Prints the heading g, then rx rows of sx entries; each column is
      right-aligned to the width of its own largest magnitude.  Entries
      smaller than 1 in magnitude are set to 0 in x as a side effect.
      Terminates the program with a message on stderr when the column
      width table cannot be allocated. */
   int r, s, *w = NULL;

   if (sx > 0) {
      w = malloc((size_t) sx * sizeof *w);
      if (w == NULL) {
         fprintf(stderr, "PrntM: out of memory\n");
         exit(EXIT_FAILURE);
      }
   }

   for (s = 0; s < sx; s++)
   {
      double p = 1;                       /* largest entry in column s */

      for (r = 0; r < rx; r++)
      {
         int rs = pt[r] + s;
         if (fabs(x[rs]) > p)
            p = fabs(x[rs]);
         else if (!NoZ(x[rs]))
            x[rs] = 0;                    /* tidy up */
      }

      /* width: 2 plus one per decimal digit of p */
      for (w[s] = 2; NoZ(p); p = floor(p / 10)) w[s]++;
   }

   printf("%s", g);
   for (r = 0; r < rx; r++)
   {
      for (s = 0; s < sx; s++)
         printf("% *.f", w[s], x[pt[r] + s]);
      printf("\n");
   }

   free(w);
}

void Inpts (int m, int n)
{  /* input linear system.
      Reads m lines from stdin, one per row of A*x = b.  Each line holds up
      to n coefficients followed by the right-hand side, separated by '|'
      or blanks.  Coefficients go to both a and ab, the right-hand side to
      b[r] and to column n of ab.  A line with fewer than n+1 numbers gets
      right-hand side 0 and keeps its remaining coefficients unchanged.
      Finally V and U get ones on their diagonals.
      Lines are read with a bounded fgets; an over-long line is cut off with
      a warning on stderr, and end of input yields empty rows. */
   char gS[1024], *el; int r, s, ch;

   for (r = 0; r < m; r++)                /* input A*x = b */
   {
      printf("\r row A%d and b%d ", r + 1, r + 1);

      /* two bytes stay free for the "# " appended below */
      if (fgets(gS, (int) sizeof gS - 2, stdin) == NULL)
         gS[0] = '\0';                    /* end of input: empty row */
      else if (strchr(gS, '\n') == NULL && !feof(stdin)) {
         fprintf(stderr, "\n row %d: line too long, rest ignored\n", r + 1);
         while ((ch = getchar()) != '\n' && ch != EOF);
      }
      gS[strcspn(gS, "\r\n")] = '\0';     /* drop the line ending */

      /* atof() stops at the '#', so a number directly before it still parses */
      strcat(gS, "# ");

      el = strtok(gS, "| ");              /* valid separators */
      for (s = 0; s < n + 1; s++)
      {
         if (el==NULL) {
            b[r] = 0; break;
         }
         b[r] = atof(el);                 /* last number read ends up in b[r] */
         if (s < n) {
            ab[pt[r] + s] = b[r];
            a[pt[r] + s] = b[r];
         }
         el = strtok(NULL, "| ");
      }
      ab[pt[r] + n] = b[r];
   }

   /* identity matrices */
   for (r = 0; r < m; r++) v[pt[r] + r] = 1;
   for (r = 0; r < n; r++) u[pt[r] + r] = 1;
}

static int ReadLn (char *buf, int size)
{  /* Read one line from stdin into buf (at most size-1 characters), drop
      the line ending and discard whatever did not fit.  Returns 1 when a
      line was read, 0 at end of input or on a read error (buf is then ""). */
   int ch;

   if (fgets(buf, size, stdin) == NULL) {
      buf[0] = '\0'; return 0;
   }
   if (strchr(buf, '\n') == NULL)         /* over-long line: skip the rest */
      while ((ch = getchar()) != '\n' && ch != EOF);
   buf[strcspn(buf, "\r\n")] = '\0';
   return 1;
}

int main (void)
{  /* Reads systems from stdin until a row count below 1 (or end of input)
      is met.  Per system: row count, column count, then the rows.  Lines
      containing an apostrophe are skipped while waiting for the row count.
      Each system is reduced with HNF and SNF, solved, and printed; the
      elapsed clock() time is printed at the end. */
   double q; int k, r, sw = 0;
   char gS[72]; clock_t tim;

   for (tim = clock();;)
   {
      errsw = 0;
      printf("\n rows ");
      do {
         if (!ReadLn(gS, (int) sizeof gS)) break;   /* end of input */
      } while (strpbrk(gS, "'"));
      m = atoi(gS);
      if (m < 1) break;
      printf(" cols "); ReadLn(gS, (int) sizeof gS);
      n = atoi(gS);
      if (n < 1) {                        /* skip the rows of this system */
         for (r = 0; r < m; r++) {
            printf("?");
            if (!ReadLn(gS, (int) sizeof gS)) break;
         }
         continue;
      }

      /* common row stride: room for n columns plus b, and at least m */
      k = n + 1; if (m > k) k = m;
      ab = calloc((size_t) m * k, sizeof(double));
      v = calloc((size_t) m * k, sizeof(double));
      a = calloc((size_t) m * k, sizeof(double));
      u = calloc((size_t) n * k, sizeof(double));
      b = malloc((size_t) k * sizeof(double));
      pt = malloc((size_t) k * sizeof(int));
      if (!ab || !v || !a || !u || !b || !pt) {
         fprintf(stderr, "out of memory\n"); goto jump; }

      for (r = 0; r < k; r++)
         pt[r] = r * k;                   /* initialize row pointers */

      Inpts(m, n);
      PrntM("\n A:\n", a, m, n);
      PrntC(" b:\n", b, m);

      k = HNF(m, n);                      /* Hermite normal form */
      if (k < 1) goto jump;
      if (errsw) {
         printf("overflow\n"); goto jump; }
      PrntM(" HNF:\n", a, m, n);

      q = SNF(k, m);                      /* Smith normal form */
      if (errsw) {
         printf("overflow\n"); goto jump; }
      printf(" SNF:  d(L) = %.f", q);
      PrntM("\n", a, m, n);
      if (Verb) {
         PrntM(" V:\n", v, m, m);         /* unimodular manipulations */
         PrntM(" U:\n", u, n, n);
      }

      for (r = 0; r < m; r++) {           /* sw: is b nonzero? */
         sw = b[r] != 0; if (sw) break;
      }

      MultC(v, k, m, b);                  /* c = V*input vector b */

      for (r = 0; r < k; r++)             /* solve system SNF*y = c */
      {
         q = a[pt[r] + r];                /* divide by invariant factors */
         if (NoZ(fmod(b[r], q))) {        /* nonzero remainder */
            printf("inconsistent\n"); goto jump; }
         b[r] /= q;
      }

      MultC(u, n, k, b);                  /* solution x = U*y */
      if (k < n) Redsol(b, k);            /* rank < n */

      if (sw) PrntC(" x:\n", b, n);
      if (k < n) PrntM(" nullspace:\n", u, n, n - k);

      /* NOTE(review): the rank is printed only when chcks reports a
         failure - confirm that this is intended */
      if (chcks(b, k)) printf(" rank = %d\n", k);

   jump:;
      free(pt); free(b); free(u);
      free(a); free(v); free(ab); printf("\n");
   }
   printf("\nTimer: %f s\n", (double)(clock() - tim) / CLOCKS_PER_SEC);
   return 0;
}
/* */