/* File: mdimarray.h
 * -----------------
 *  Implementation of a multidimensional array class that has the same style
 *  of use as a native one-dimensional C array. 
 *
 *  Usage:
 *  {
 *  Array<int, 2> data = NewArray<int,2>(3,4);
 *  try{
 *    for(int i=0,v=0; i<3; ++i){
 *      for(int j=0; j<4; ++j, ++v){ data[i][j] = v; }
 *    }
 *  }catch(std::string& e) {
 *     cerr << "Exception Caught: " << e << endl;
 *     cin.get();
 *  }
 *  DelArray(data);}
 * 
 * A standard dynamic multidimensional array involves creating an array of ptrs to an
 * array of ptrs ... or using a vector<vector...>> 
 *  
 * The problems with those approaches are:
 * 1. space inefficiency because they store ptrs to ptrs (cost = size of matrix w/o last dim) 
 *  1a. in case of vector<>, also because it stores size of each vector
 * 2. require multiple indirections
 * 3. bad locality properties 
 * 4. are annoying to construct and destroy (although you can write helpers)
 * 5. lack easy bounds checking that can be turned off using a flag
 * 
 * They are convenient to use though, because you can easily access slices when
 * you provide indexes for earlier dimensions.
 *
 * The array class below provides the same ease of use, without the
 * inefficiency. Because of templates, it effectively performs loop unrolling.
 *
 * The implementation below stores a dimensionality array, a stride array and
 * the elms array. The only extra space is storing the stride array, which
 * should speed up indexing. (We should try a version without it to compare
 * speeds)
 *
 * created: 10/15/2007 author: Varun Ganapathi
 */

// uncomment the define below to turn on bounds checking in operator[]
//#define BOUNDS_CHECK 1

#include <iostream>
#include <cassert>
using namespace std;

/* Unsigned type used for every extent, stride and subscript in this file. */
typedef unsigned uint;

/* stringify anything that has ostream operator<< defined */
#include <sstream>
template<typename T>
string str(const T& x){
  // Format through a stream so any type with an operator<< is accepted.
  ostringstream formatted;
  formatted << x;
  return formatted.str();
}

/*
 * sets strides in row-major order, so stride[i] is the distance
 * to move to increment value of dimension i.
 * e.g. dims = [2,3,4] strides = [12,4,1] (strictly, we don't need the last stride)
 */
template<int K>
static void ComputeStrides(const uint* dims, uint* strides) {
  // The innermost dimension is contiguous (stride 1). Each outer stride is
  // the stride of the dimension just inside it times that dimension's extent.
  int d = K - 1;
  strides[d] = 1;
  while (d > 0) {
    strides[d - 1] = strides[d] * dims[d];
    --d;
  }
}

/*
 * Struct: ArrayBase
 * -----------------
 * Elements common to array of any dimensionality
 */
template<typename T>
struct ArrayBase {
  const uint* dims_;  // extents of this array/slice; dims_[0] is the outermost
  T* elms_;           // first element of this array/slice (not owned)

  /* Shallow: stores the two pointers, copies nothing. */
  ArrayBase( const uint* dims, T* elms):dims_(dims), elms_(elms){}

  /*
   * When BOUNDS_CHECK is defined, throws a std::string if i is not a valid
   * index into the outermost dimension (i >= dims_[0]); otherwise a no-op.
   * K is the dimensionality of this array and only feeds the error message.
   */
  inline void boundsCheck(uint i, int K) const {
#ifdef BOUNDS_CHECK
    if(i>=*dims_) throwOutOfBounds(i, K);
#else
    (void)i; (void)K;  // unused when bounds checking is compiled out
#endif
  }

  /*
   * Throws a std::string describing the bad index i together with the K
   * extents of this array. Kept separate from boundsCheck so the
   * (cold) formatting code stays off the indexing fast path.
   */
  void throwOutOfBounds(uint i, int K) const {
    ostringstream oss;
    oss << "Operator[]::Index Out Of Bounds: " << i << "!<" << *dims_;
    oss << " K=" << K << " dims=[";
    for(int d=0; d<K; ++d) { oss << " " << dims_[d]; }
    oss << "]";
    throw oss.str();
  }

  /* Extent of dimension i. */
  uint dims(uint i) const { return dims_[i]; }
};

/*
 * Struct: Array<T,K>
 * -------------
 * K-D array with K > 1.
 * Does not take ownership of passed in arrays
 * Copies,Slices are shallow and simply point into passed in arrays
 * operator[] returns a slice into this array of lower dimensionality
 */
template<typename T, int K>
struct Array : public ArrayBase<T> {
  const uint* strides_;  // strides_[0] = elements skipped per step in dimension 0

  /*
   * Wraps existing arrays without copying or owning them. dims must hold K
   * extents; strides must hold at least K-1 entries (NewArray supplies K).
   * The constructor is declared with the injected class name: writing the
   * template-id Array<T,K>(...) here is ill-formed since C++20.
   */
  Array(const uint* dims, const uint* strides, T* elms):
    ArrayBase<T>(dims,elms), strides_(strides) {}

  /*
   * Returns the (K-1)-dimensional slice at index i of dimension 0. The slice
   * shares storage and metadata with this array (its dims/strides pointers
   * are simply advanced by one).
   */
  inline Array<T,K-1> operator[](uint i) {
    this->boundsCheck(i,K);
    return Array<T,K-1>(this->dims_+1, strides_+1, this->elms_ + i*(*strides_));
  }

  /*
   * Extent of dimension 0 times its stride; this equals the total element
   * count for the row-major strides produced by ComputeStrides.
   */
  uint size() const { return this->dims(0)*strides_[0]; }
};

/*
 * Struct: Array<T,1>
 * -------------
 * 1-D array. Same as K dimensional except lacks stride ptr.
 * operator[] returns actual value of element
 */
template<typename T>
struct Array<T,1> : ArrayBase<T> {
  /* Wraps a run of dims[0] elements starting at elms (shallow, not owned). */
  Array(const uint* dims, T* elms) : ArrayBase<T>(dims, elms) {}

  /*
   * Same as above. The strides argument is ignored; it exists so that
   * Array<T,K>::operator[] (with K == 2) can build a 1-D slice with the same
   * three-argument call it uses for every dimensionality.
   */
  Array(const uint* dims, const uint* /* strides */, T* elms)
    : ArrayBase<T>(dims, elms) {}

  /* Reference to element i (range-checked only when BOUNDS_CHECK is defined). */
  T& operator[](uint i) {
    this->boundsCheck(i, 1);
    return *(this->elms_ + i);
  }

  /* Number of elements: the single extent dims_[0]. */
  uint size() const { return this->dims_[0]; }

  /* Pointer range [begin(), end()) over the elements. */
  T* begin() const { return this->elms_; }
  T* end() const { return begin() + size(); }
};


#include <stdarg.h>
/*
 * Allocates the dims, strides and elms arrays for a multidimensional array
 */
template<typename T, int K>
Array<T,K> NewArray( unsigned int dim0, ...) {
  // A single block holds the K extents (dims[0..K-1]) followed by the K
  // strides (dims[K..2K-1]); DelArray frees it through dims_.
  uint* dims = new unsigned int[2*K];
  dims[0] = dim0;

  // The remaining K-1 extents are read as unsigned int. Passing a different
  // type (size_t, long, double, ...) is undefined behavior; a non-negative
  // int is fine.
  va_list VL;
  va_start(VL, dim0);
  for(int i=1; i<K; ++i){ dims[i]=va_arg(VL,unsigned int); }
  va_end(VL);

  ComputeStrides<K>(dims,dims+K);

  // dims[K] is the stride of dimension 0 (= product of the other extents),
  // so multiplying by dims[0] gives the total element count.
  // NOTE: the product is computed in uint and can wrap for huge arrays.
  // If the allocation (or a T constructor) throws, free dims before
  // propagating so nothing leaks.
  T* elms;
  try {
    elms = new T[dims[K]*dims[0]];
  } catch(...) {
    delete[] dims;
    throw;
  }
  return Array<T,K>(dims,dims+K,elms);
}

/*
 * Warning: only call this on the result of NewArray (not on a slice from
 * operator[] or on caller-owned arrays); otherwise anything could happen....
 */
template<typename T, int K>
void DelArray( Array<T,K>& x ) {
  delete[] x.dims_;
  delete[] x.elms_;
  x.dims_ = NULL;
  //x.strides_ = NULL;
  x.elms_ = NULL;
}

/*
 * Functions for reading and writing an array from a file.
 * It's equivalent to construction, so this is why I call it 
 * NewArray
 */
template<typename T, int K>
Array<T,K> NewArray( istream& is ) {
  // Header: the dimensionality, as written by WriteArray. It is validated in
  // every build mode; an assert would vanish under NDEBUG and the rest of the
  // stream would then be misinterpreted. Errors are thrown as std::string,
  // the exception convention this file documents.
  int Kcheck = 0;
  is.read(reinterpret_cast<char*>(&Kcheck), sizeof(int));
  if(!is) throw string("NewArray(istream&): failed to read dimensionality");
  if(Kcheck != K) {
    throw string("NewArray(istream&): expected K=") + str(K)
        + ", stream has K=" + str(Kcheck);
  }

  // K extents, plus room for K strides, matching NewArray(dim0, ...).
  uint* dims = new unsigned int[2*K];
  is.read(reinterpret_cast<char*>(dims), K*sizeof(uint));
  if(!is) {
    delete[] dims;
    throw string("NewArray(istream&): failed to read dims");
  }

  // Set up memory and metadata; free dims if the element allocation throws.
  ComputeStrides<K>(dims,dims+K);
  uint size = dims[K]*dims[0];
  T* elms;
  try {
    elms = new T[size];
  } catch(...) {
    delete[] dims;
    throw;
  }

  // Raw byte copy: only meaningful for trivially copyable T written with the
  // same sizeof(T)/endianness. A short read means a truncated stream.
  is.read(reinterpret_cast<char*>(elms), sizeof(T)*size);
  if(!is) {
    delete[] elms;
    delete[] dims;
    throw string("NewArray(istream&): failed to read elements");
  }

  return Array<T,K>(dims,dims+K,elms);
}

template<typename T, int K>
void WriteArray( Array<T,K> A, ostream& os) {
  // Binary layout, matching what NewArray(istream&) reads back:
  //   int K | uint dims[K] | T elms[A.size()]
  const int rank = K;
  os.write(reinterpret_cast<const char*>(&rank), sizeof(int));
  os.write(reinterpret_cast<const char*>(A.dims_), sizeof(uint) * K);
  os.write(reinterpret_cast<const char*>(A.elms_), sizeof(T) * A.size());
}
