//*******************************************************************
// Linear time sorting algorithm
//*******************************************************************

#include <climits>
#include <cstdlib>
#include <iostream>
// #include <stdlib>

typedef unsigned long ulong;

// Value that marks the boundary between two numbers inside A and B.
const ulong SEPARATOR = 0;

ulong m;      // size of input: number of cells in A and in B (read in main)
int logm;     // logarithm of m; number of length groups (set by groupNodes)
ulong n;      // numbers in input
ulong *A;     // array storing the input
ulong *B;     // array storing the output

// statistics
ulong counter;  // meant to count execution steps; only reset in main in this file


// One number of the input, stored in place as a run of digits inside A.
class bignum
{
public:
  ulong start;   // index in A of the number's first digit
  ulong length;  // number of digits (ulong words)
  bignum *next;  // next number in the same singly linked list, or NULL

  bignum(ulong s, ulong l) : start(s), length(l), next(NULL) {}
};


class numgroup
{
public:
  ulong count;
  bignum *numlist;
  
  numgroup()
  {
    count=0;
    numlist=NULL;
  }
};

numgroup **group = NULL;  // one numgroup per length class; array built in groupNodes (logm entries)



// Fills A with a random test input and prints it.
// First every one of the m cells gets a digit from 1..4. Then n-1 cells at
// random interior positions (1..m-2) are overwritten with SEPARATOR, each
// chosen so that neither neighbour is already a SEPARATOR. Because of that
// rule no number is empty. The input is printed as one "( ... )" line per
// number.
// The placement loop retries until it finds a free position, so it only
// terminates if one keeps existing. Callers must keep n small enough
// relative to m.
void generateInput()
{
  ulong i, s, j;

  for (i=0; i<m; i++) A[i]=rand()%4+1;

  // Interior positions exist only for m >= 3: otherwise (m-2) below would be
  // 0 (division by zero) or wrap around. Counting s from 1 up to n also
  // avoids the wrap-around of n-1 when n == 0.
  if (m >= 3)
    for (s=1; s<n; s++)
    {
      // multiply in unsigned arithmetic: rand()*rand() overflows int (UB)
      do { j=(static_cast<ulong>(rand())*static_cast<ulong>(rand()))%(m-2)+1; }
      while (A[j-1]==SEPARATOR || A[j]==SEPARATOR || A[j+1]==SEPARATOR);
      A[j]=SEPARATOR;
    }

  i=0;
  std::cout << "Input:\n";
  while (i<m)
  {
    std::cout << "( ";
    while (i<m && A[i]!=SEPARATOR)
    {
      std::cout << A[i] << " ";
      i++;
    }
    std::cout << ")\n";
    i++;
  } 
  std::cout << "\n";
}
  
  
// Splits the numbers stored in A (runs of digits separated by SEPARATOR)
// into logm groups, then prints every group.
// A number of length L goes to group floor(log2 L), so group j holds the
// lengths in [2^j, 2^(j+1)). logm is 1 plus the number of doublings of 1
// needed to reach m, which means even a number of length m has a group.
// Each bignum and each numgroup, as well as the group array, is allocated
// with new.
// An empty number (SEPARATOR at A[0] or two adjacent SEPARATORs) has no
// floor(log2 0). It is put into group 0 instead of indexing group[-1].
void groupNodes()
{
  ulong i, start, length, size;
  int j;
  bignum *z;
  
  std::cout << "Grouping the nodes...\n";
  logm = 1;
  size = 1;
  while (size<m) { size<<=1; logm++; 
  std::cout << size << " " << logm << "\n"; }
  group = new numgroup*[logm];
  for (j=0; j<logm; j++) group[j]= new numgroup();
  
  for (i=0; i<m; i++)
  {
    start=i;
    j=-1;
    length=0;
    size=1;
    // j ends as floor(log2 length): it is bumped whenever length reaches a power of two
    while (i<m && A[i]!=SEPARATOR)
    {
      i++;
      length++;
      if (length==size) { size<<=1; j++; }
    }
    if (j<0) j=0;  // empty number: keep the index inside the group array
    z = new bignum(start, length);
    z->next = group[j]->numlist;
    group[j]->numlist = z;
    group[j]->count++;
  }
  
  for (j=0; j<logm; j++)
  {
    std::cout << "Group " << j << ": " << group[j]->count << " numbers\n";
    z = group[j]->numlist;
    while (z!=NULL)
    {
      std::cout << "( ";
      for (i=0; i<z->length; i++)
        std::cout << A[z->start+i] << " ";
      std::cout << ")\n";
      z=z->next;
    }
  }
}

//----------------------------------------------------------------

// Node of the Patricia tree built by Patriciasort.
// Nodes at one level form a doubly linked sibling list (prev/next). The
// node reached from the parent's `down` is the one whose `up` is set.
class pnode
{
public:
  bignum *num;  // number this node represents
  ulong dim;    // dimension (digit position) this level branches on
  ulong digit;  // value of num at dimension dim
  pnode *prev;  // smaller sibling
  pnode *next;  // larger sibling
  pnode *up;    // parent level (set only on the node the parent points down to)
  pnode *down;  // lower-dimension level

  int state;    // Euler tour progress: 0 prev pending, 1 down pending, 2 next pending, 3 leave

  pnode(bignum *z, ulong d, ulong dig)
    : num(z), dim(d), digit(dig),
      prev(NULL), next(NULL), up(NULL), down(NULL),
      state(0)
  {}
};
  

// Creates a node for number z with digit value ndig and links it directly
// after p in p's sibling list. The new node shares p's dimension.
void insertnext(pnode *p, bignum *z, ulong ndig)
{
  pnode *fresh = new pnode(z, p->dim, ndig);
  pnode *after = p->next;

  fresh->prev = p;
  fresh->next = after;
  p->next = fresh;
  if (after != NULL)
    after->prev = fresh;
}


// Creates a node for number z with digit value ndig and links it directly
// before p in p's sibling list. The new node shares p's dimension.
void insertprev(pnode *p, bignum *z, ulong ndig)
{
  pnode *fresh = new pnode(z, p->dim, ndig);
  pnode *before = p->prev;

  fresh->next = p;
  fresh->prev = before;
  p->prev = fresh;
  if (before != NULL)
    before->next = fresh;
}


// Inserts a new level under p that branches at dimension `dim`.
// The level starts with a copy of p's number (digit odig), placed between
// p and p's former down level. z (digit ndig) is then linked beside that
// copy: after it when odig <= ndig, before it otherwise.
void insertdown(pnode *p, bignum *z, ulong dim, ulong odig, ulong ndig)
{
  pnode *mid = new pnode(p->num, dim, odig);
  pnode *below = p->down;

  mid->up = p;
  mid->down = below;
  p->down = mid;
  if (below != NULL)
    below->up = mid;

  if (odig > ndig)
    insertprev(mid, z, ndig);
  else
    insertnext(mid, z, ndig);
}


// Sorts the numbers of group g by inserting them into a Patricia-style tree
// of pnodes. An Euler tour then relinks g->numlist from the leaves, left to
// right, and frees the tree.
//
// Dimension d of a number is A[start + length - d], so dimension 1 is its
// last digit. A number shorter than d reads as 0 there. The root gets
// dimension 1<<(level+1).
// NOTE(review): this assumes every number in g is shorter than
// 1<<(level+1), as groupNodes produces. Confirm for any other caller.
//
// g     group to sort; groups with fewer than 2 numbers are left untouched.
// level group index; sets the root dimension and appears in the log line.
void Patriciasort(numgroup *g, int level)
{
  bignum *oldz, *newz;
  ulong olddig, newdig;
  pnode *ptree, *pcurrent, *pold;
  ulong dim;

  std::cout << "Patriciasort on group " << level << "\n";
  if (g->count<2) return;  // nothing to do

  // build Patricia tree: walk each new number down from the root, comparing
  // it digit by digit (highest dimension first) with the representative oldz
  oldz=g->numlist;
  ptree = new pnode(oldz, 1<<(level+1), 0);
  newz = oldz->next;
  while (newz!=NULL)
  {
    pcurrent = ptree;
    oldz = ptree->num;
    if (oldz->length < newz->length) dim = newz->length;
                                else dim = oldz->length;    
    if (ptree->down!=NULL && ptree->down->dim >= dim) dim = ptree->down->dim+1;
    while (dim>0) 
    {
      if (oldz->length>=dim) olddig = A[oldz->start+oldz->length-dim]; else olddig = 0;
      if (newz->length>=dim) newdig = A[newz->start+newz->length-dim]; else newdig = 0;
      if (olddig == newdig) 
      {
        dim--;
        // reached the dimension the next level branches on: descend
        if (pcurrent->down!=NULL && dim<=pcurrent->down->dim) pcurrent=pcurrent->down;
      }  
      else
      {
        if (pcurrent->dim > dim) 
        { 
          // first difference lies above this node's branching level: new level
          insertdown(pcurrent, newz, dim, olddig, newdig); 
          break; 
        }
        else
        {        
          // same level: move along the digit-ordered siblings or insert here
          if (olddig < newdig) 
          {
            if (pcurrent->next != NULL && pcurrent->next->digit <= newdig)
              { pcurrent=pcurrent->next; oldz = pcurrent->num; }
            else
              { insertnext(pcurrent, newz, newdig); break; }
          }
          else
          {
            if (pcurrent->prev != NULL && pcurrent->prev->digit >= newdig)
              { pcurrent=pcurrent->prev; oldz = pcurrent->num; }
            else
              { insertprev(pcurrent, newz, newdig); break; }
          }
        }
      }
    }    
    // all digits equal (duplicate): add it on a dimension-0 level
    if (dim==0) insertdown(pcurrent, newz, 0, 0, 0);
    newz = newz->next;
  }

  // traverse Patricia tree through Euler tour. state counts the moves a node
  // has made: 0 -> prev, 1 -> down (or emit if leaf), 2 -> next,
  // 3 -> leave via up (node the parent points down to) or prev (others).
  pcurrent = ptree->down;
  oldz = NULL;
  while (pcurrent!=ptree)
  {
    pold = pcurrent;
    if (pcurrent->state == 0) // to prev
    {
      pcurrent->state++;
      if (pcurrent->prev!=NULL) 
      {
        pcurrent = pcurrent->prev;
        if (pcurrent->state > 0) delete(pold);  // will never be visited again      
      }
    }
    if (pcurrent->state == 1) // down
    {
      pcurrent->state++;
      if (pcurrent->down!=NULL) pcurrent = pcurrent->down;
      else
      {
        // leaf: append its number to the rebuilt list
        newz = pcurrent->num;
        newz->next = NULL;      
        if (oldz!=NULL) oldz->next = newz;
                   else g->numlist = newz;
        oldz = newz;
      }      
    }
    if (pcurrent->state == 2) // to next
    {
      pcurrent->state++;
      if (pcurrent->next!=NULL) 
      { 
        pcurrent = pcurrent->next;
        if (pcurrent->state==0) pcurrent->state++; // leave prev for last
                           else delete(pold);  // will never be visited again      
      }
    }
    if (pcurrent->state == 3) // up
    {
      pold = pcurrent;
      if (pcurrent->up != NULL)
      {
        // pold is the node its parent points down to and is the last of its
        // sibling list to finish (the others were freed on the way). The
        // parent has already advanced past state 1 and never descends again,
        // so pold can be freed here instead of leaking.
        pcurrent = pcurrent->up;
        delete(pold);
      }
      else
        if (pcurrent->prev != NULL) 
        { 
          pcurrent = pcurrent->prev;
          if (pcurrent->state > 0) delete(pold); // will never be visited again
        }
    }
  }
  delete(ptree);
}


//----------------------------------------------------------------

bignum **Rs = NULL;  // bucket heads for Radixsort (allocated per call)
bignum **Re = NULL;  // bucket tails for Radixsort (allocated per call)

// LSD radix sort of the numbers in group g; g->numlist is relinked in
// ascending order.
// Digit positions run from each number's last digit (position 1) up to the
// longest length, and a number shorter than a position contributes 0.
// Each digit is consumed logsize bits at a time through size == 2^logsize
// buckets, lowest bits first. Every bucket pass is stable, so the result is
// ordered by the whole digit sequence.
//
// g     group to sort; groups with fewer than 2 numbers are left untouched.
// level group index, only used in the log line.
void Radixsort(numgroup *g, int level)
{
  ulong base, size, k, maxlen, pos, shift, bits;
  int logsize;
  bignum *number, *oldn;
  
  std::cout << "Radixsort on group " << level << "\n";
  if (g->count<2) return;
  base = 1<<(logm/5);
  // here we assume that logm/5 <= 16
  size=2; logsize=1;
  while (size<base) { size <<= logsize; logsize <<=1; }
  Rs = new bignum*[size];
  Re = new bignum*[size];
  for (k=0; k<size; k++) { Rs[k]=NULL; Re[k]=NULL; }

  // One scan finds the longest number and every bit that occurs in any
  // digit. Passes over positions or bit chunks where all keys are 0 leave
  // the stable order unchanged, so they are skipped. The chunk loop also
  // covers the full ulong width: the old bound j<32/logsize stopped one
  // chunk short of 32 bits.
  maxlen = 0;
  bits = 0;
  for (number = g->numlist; number != NULL; number = number->next)
  {
    if (number->length > maxlen) maxlen = number->length;
    for (pos = 0; pos < number->length; pos++) bits |= A[number->start+pos];
  }

  // Bucketsort for all dimensions pos, lowest bit chunk first
  for (pos=1; pos<=maxlen; pos++)
    for (shift=0; shift < sizeof(ulong)*CHAR_BIT && (bits >> shift) != 0; shift += logsize)
    {
      number = g->numlist;
      while (number!=NULL)
      {
        if (pos > number->length) k=0;
        else
          k = (A[number->start+number->length-pos] >> shift) & (size-1);
        if (Re[k]==NULL) 
          { Rs[k]=number; Re[k]=number; }
        else 
          { Re[k]->next = number; Re[k] = number; }
        oldn = number;
        number = number->next;
        oldn->next = NULL;
      }
      // merge lists
      number=NULL;
      for (k=0; k<size; k++)
      {
        if (Rs[k]!=NULL)
        {
          if (number==NULL) g->numlist = Rs[k];
                       else number->next = Rs[k];
          number = Re[k];
          Rs[k]=NULL; Re[k]=NULL;
        }
      }
    }
  // the bucket arrays come from new[], so they need delete[]
  delete[] Rs;
  delete[] Re;
  Rs = NULL;
  Re = NULL;
}


//----------------------------------------------------------------
  
void combineGroups()
{
  int i, j;
  bignum *numlist, *current;

  // form sorted list  
  numlist = NULL;
  current = NULL;
  for (i=0; i<logm; i++)
  {
    if (group[i]->count>0)
    {
      if (numlist==NULL) numlist = group[i]->numlist;
                    else current->next = group[i]->numlist;
      current = group[i]->numlist;
      while (current->next!=NULL) current = current->next;
    }
  }

  // copy numbers to B  
  i=0;
  while (numlist!=NULL && i<m)
  {
    current = numlist;
    for (j=0; j<numlist->length; j++)
    {
      B[i] = A[numlist->start+j];
      i++;
    }
    if (i<m) { B[i]=SEPARATOR; i++; }
    numlist = numlist->next;
    delete(current);
  }
  delete(group);  
}  
  
  
//----------------------------------------------------------------
  
// Sorts the input in A into B. groupNodes splits the numbers into length
// groups, each group is sorted separately, and combineGroups concatenates
// the groups into B.
// Group i goes to Radixsort only when it holds at least 2^i numbers and
// more than limit = 2^(logm/5) numbers. Otherwise Patriciasort is used.
void PRMSort()
{
  int i;
  ulong limit;
  const int ulongbits = static_cast<int>(sizeof(ulong) * CHAR_BIT);

  std::cout << "Sorting...\n\n";
  groupNodes();
  
  limit = static_cast<ulong>(1) << (logm/5);
  std::cout << "Limit: " << limit << "\n";
  for (i=0; i<logm; i++)
  {
    // 2^i in ulong: the int shift 1<<i overflows once i reaches 31. A shift
    // by the full width is undefined, but then 2^i exceeds any count anyway.
    bool fewNumbers = (i >= ulongbits) ||
                      ((static_cast<ulong>(1) << i) > group[i]->count);
    if (fewNumbers || group[i]->count <= limit) Patriciasort(group[i],i);
    else Radixsort(group[i],i);
  }
  combineGroups();
}


// Prints B as one "( d1 d2 ... )" line per number, with SEPARATOR marking
// the boundary between numbers.
void printOutput()
{
  ulong i;  // same type as m: an int index would overflow for inputs above INT_MAX
     
  i=0;
  std::cout << "Output:\n";
  while (i<m)
  {
    std::cout << "( ";
    while (i<m && B[i]!=SEPARATOR)
    {
      std::cout << B[i] << " ";
      i++;
    }
    std::cout << ")\n";
    i++;  // skip the separator
  } 
}
  

int main()
{
  int i;
  int runs;
  
  std::cout << "Input size: ";
  std::cin >> m;
  std::cout << "Number of numbers in input: ";
  std::cin >> n;
  std::cout << "Number of runs: ";
  std::cin >> runs;
  
  A = new ulong[m];  
  B = new ulong[m];
  counter = 0; 
  for (i=0; i<runs; i++)
  {
    generateInput();
    PRMSort();
    printOutput();
  }

  std::cout << "press [space] to continue \n";
  char c;
  do { std::cin.get(c); } while (std::cin.good() && c != ' ');

}

