/* -*- C++ -*- */


#include "libprofile.h"





/* Name of an SD3 output file.  Not referenced by the code visible in this
   part of the file.  */
static const char* sd3_outfile_name = "sd3.info";

/* Hash map from int to int.  */
typedef  std::tr1::unordered_map < int, int > SD3_MEMOP_HASH;

/* Not referenced by the code visible in this part of the file.  */
SD3_MEMOP_HASH  sd3_memop_hash;

/* Head of the hash table of recorded dependences, keyed by memop_id.
   sd3_record_dependence looks entries up and adds them through the
   HASH_FIND_INT / HASH_ADD_INT macros.  */
static struct dep_hash_entry *sd3hashdep=NULL;

/* Largest number of point-table entries released by a single call to
   sd3_clear_loop; reported by __sd3_print_exit.  */
static int max_hash_size = 0;

/* Not referenced by the code visible in this part of the file.  */
const int stride_tolerate = 10;

/* Not referenced by the code visible in this part of the file.  */
FILE *depfp;
FILE *errfp;


/* Forward declaration; the definition is not in this part of the file.  */
static void sd3_point_check_dependence(lentry *loop, 
                                        MEMORY_OPERATION_LIST &stmts, MEMORY_OPERATION &new_stmt);

/* Record dependence a depends on b */
bool 
sd3_record_dependence( const DEPENDENCY &new_dep )
{
	if (new_dep.type() == dep_RAR )
		return false;

	if (new_dep.type() == dep_NO )
		return false;

	struct dep_hash_entry *entry;

	HASH_FIND_INT (sd3hashdep, &(new_dep._a), entry );

	if (entry) 
	{
		DEPENDENCY_VEC &dep_set = entry->dependencies;
		if ( !dep_set.Find(new_dep) )
		{
			new_dep.print(depname);
			dep_set.push_back(new_dep);
      return true;
		}		
    return false;
	}
	else
	{	
		entry = new dep_hash_entry;
		entry->memop_id= new_dep.a();
		entry->dependencies.push_back (new_dep);
		HASH_ADD_INT (sd3hashdep, memop_id, entry );	
		new_dep.print(depname);
    return true;
	}
}






/* Greatest common divisor of |a| and |b| (Euclid's algorithm).

   Returns the other operand's magnitude when one operand is zero, and 0
   when both are zero.

   The magnitudes are computed by hand rather than with an unqualified
   abs(): depending on which headers are visible, abs() may resolve to
   the int overload and silently truncate a 64-bit ptrdiff_t.  */
static inline ptrdiff_t
sd3_gcd (ptrdiff_t a, ptrdiff_t b)
{
  ptrdiff_t x = (a < 0) ? -a : a;
  ptrdiff_t y = (b < 0) ? -b : b;

  while (x > 0)
    {
      const ptrdiff_t remainder = y % x;
      y = x;
      x = remainder;
    }

  return y;
}



/* Classify the dependence between two accesses from their read flags.
   SD3_STRIDE::compute_dependence passes the later access as A and the
   earlier one as B.

     A writes, B writes -> dep_WAW
     A writes, B reads  -> dep_WAR
     A reads,  B writes -> dep_RAW
     A reads,  B reads  -> dep_RAR  */
static inline dep_type
get_dep_type (const SD3_STRIDE& a, const SD3_STRIDE& b)
{
  if (a.read)
    return b.read ? dep_RAR : dep_RAW;

  return b.read ? dep_WAR : dep_WAW;
}


/* Initialise every field of the stride in one call.

   LOW / HIGH  bounds of the address range
   DISTANCE    step between consecutive addresses (0 for a single point)
   SIZE        access size
   STATUS      state of the stride
   READ        stored as the read flag
   STMT        id of the statement that issued the accesses

   WRITE, LOOP_ID and REF are accepted but not used.  The stride starts
   un-killed with no jump distance, and num_strides is derived from the
   range and the step.  */
inline void
SD3_STRIDE::init (PTR low, PTR high, ptrdiff_t distance, int size,
                  SD3_STATUS status, int read, int write,
                  int loop_id, int stmt, int ref)
{
  /* Bookkeeping fields.  */
  this->killed = 0;
  this->jump_distance = 0;
  this->stmt = stmt;
  this->read = read;
  this->status = status;

  /* Geometry of the stride.  */
  this->size = size;
  this->distance = distance;
  this->high = high;
  this->low = low;

  /* A zero step describes a single point; otherwise count both ends.  */
  num_strides = (distance == 0) ? 1
                                : abs (high - low)/ abs (distance) + 1;
}


/* Tell whether ADDR, accessed with SIZE bytes, continues this stride.

   Returns false when SIZE differs from the stride's access size.  A
   stride in state SD3_START or SD3_FIRST accepts any address; in every
   other state ADDR must be exactly one step past the current upper bound
   (high + distance).  STMT and REF are not used.

   The original switch had no `break' after the SD3_START / SD3_FIRST
   case, so it fell into the default branch and the `true' result was
   always overwritten.  */
bool
SD3_STRIDE::check_addr (PTR addr, int size, int stmt, int ref) const
{
  assert (status >= SD3_START && status < SD3_LAST);

  if (this->size != size)
    return false;

  switch (status) {
  case SD3_START:
  case SD3_FIRST:
    return true;
  default:
    return high + distance == addr;
  }
}
  

/* Turn an empty stride (state SD3_START) into a single-point stride at
   ADDR issued by statement STMT.  A non-zero WRITE marks the access as a
   write, zero marks it as a read.  Always returns true.

   NOTE(review): SIZE is accepted but not stored in this->size; confirm
   whether that is intentional.  */
bool
SD3_STRIDE::add_addr (PTR addr, int size, int stmt, int write)
{
  assert (status == SD3_START);

  read = write ? 0 : 1;
  this->stmt = stmt;

  /* A single point: both bounds are the address itself.  */
  high = addr;
  low = high;
  status = SD3_FIRST;

  return true;
}


bool
SD3_STRIDE::merge (const SD3_STRIDE& x)
{
  assert (status > SD3_START);
  

  if (status == SD3_FIRST) 
  {
    if (x.status == SD3_FIRST)
    {  
      status = SD3_LEARNED;
      if ( x.high < low )
        low = x.low;
      else
        high = x.high;
      distance = high - low;
      num_strides += x.num_strides;
      return true;
    }
    else
      return false;    
  }
  else if (status == SD3_LEARNED) 
  {    
    if (x.status == SD3_FIRST) 
    {
      if ( high < x.low )
      {
        int new_distance = x.low - high; 
        if ( distance != new_distance )
          return false;
        high = x.high;
        num_strides += x.num_strides;
        status = SD3_WEAK;
        return true;  
      }
      else if ( low > x.high )
      {
        int new_distance = low - x.high; 
        if ( distance != new_distance )
          return false;
        low = x.low;
        num_strides += x.num_strides;
        status = SD3_WEAK;
        return true;  
      }    
      else if ( low <= x.low && high >= x.high )
      {
        status == SD3_WEAK;
        return true;  
      }
        
      return false;

    }
    else if (x.status == SD3_LEARNED) 
    {
      if ( distance != x.distance )
        return false;
      /* overlapped */
      if ( high < x.low )
      {
        int new_distance = x.low - high; 
        if ( distance != new_distance )
          return false;
        high = x.high;
        num_strides += x.num_strides;
        status = SD3_WEAK;
        return true;  
      }
      else if ( low > x.high )
      {
        int new_distance = low - x.high; 
        if ( distance != new_distance )
          return false;
        low = x.low;
        num_strides += x.num_strides;
        status = SD3_WEAK;
        return true;  
      }    
      
      return false;
    }
    return false;

  }

  else if (status == SD3_WEAK) 
  {    
    if (x.status == SD3_FIRST) 
    {
      if ( high < x.low )
      {
        int new_distance = x.low - high; 
        if ( distance != new_distance )
        {
          if ( jump_distance == 0 || jump_distance == new_distance)
          {
            jump_distance = new_distance;
            high = x.high;
            num_strides += x.num_strides;
            return true;         
          }
          return false;
        }
        high = x.high;
        num_strides += x.num_strides;
        return true;  
      }
      else if ( low > x.high )
      {
        int new_distance = low - x.high; 
        if ( distance != new_distance )
        {
          if ( jump_distance == 0 || jump_distance == new_distance)
          {
            jump_distance = new_distance;
            low = x.low;
            num_strides += x.num_strides;
            return true;         
          }
          return false;
        }      
        low = x.low;
        num_strides += x.num_strides;
        return true;  
      }    
      else if ( low <= x.low && high >= x.high )
      {
        return true;  
      }
        
      return false;

    }
    else if (x.status == SD3_LEARNED )
    {
      SD3_STRIDE x_low;
      x_low.add_addr (x.low, 0, x.stmt, x.read);
      SD3_STRIDE x_high;
      x_high.add_addr (x.high, 0, x.stmt, x.read);
      if ( this->merge(x_low) && this->merge(x_high) )
        return true;
      return false;
    }

    else
    {
      if ( distance != x.distance )
        return false;
      /* overlapped */
      if ( low > x.low )
        low = x.low;
      if ( high < x.high )
        high = x.high;
      num_strides += x.num_strides;
      return true;
    }

  }
  
}



/* TODO: implement GCD-based dependence check!

   distance_1 * x + low_1 = distance_2 * y + low_2
   0 <= x <= iter_no
   0 <= y <= iter_no
*/

/* Test whether the later stride X depends on this (earlier) stride and,
   if so, record the dependence.

   DKIND and DLOOP are passed on to the recorded DEPENDENCY.  Returns the
   result of sd3_record_dependence when the strides may touch a common
   address, and false when X is already killed, the pair is read/read, or
   the address test rules a dependence out.

   When DKIND is LOOP_INDEPENDENT and the dependence is a RAW, X is marked
   killed in the cases where this stride covers it.

   TODO: negative distances are not handled.  */
bool
SD3_STRIDE::compute_dependence (SD3_STRIDE& x, dep_kind dkind, int dloop) const
{
  if (x.killed)
    return false;

  int dstmt_before = stmt;
  int dstmt_after = x.stmt;
  dep_type dtype = get_dep_type (x, *this);

  if ( dtype == dep_RAR )
    return false;

  const bool kills_x = (dkind == LOOP_INDEPENDENT && dtype == dep_RAW);

  if (distance == 0 && x.distance == 0)
  {
    /* Point against point: the addresses must match.  */
    if (low != x.low)
      return false;

    if (kills_x)
      x.killed = 1;
  }
  else if (distance == 0)
  {
    /* This is a point, X a stride: the point must lie in X's range and
       on one of X's steps.  */
    if (low < x.low || low > x.high)
      return false;

    if (abs (low - x.low) % x.distance)
      return false;
  }
  else if (x.distance == 0)
  {
    /* X is a point, this a stride: the mirrored test.  */
    if (x.low < low || x.low > high)
      return false;

    if (abs (x.low - low) % distance)
      return false;

    if (kills_x)
      x.killed = 1;
  }
  else
  {
    /* Both are real strides: the start offset must be a multiple of the
       gcd of the two steps.  */
    const ptrdiff_t gcd_num = sd3_gcd (distance, x.distance);
    assert (gcd_num);

    if (abs (x.low - low) % gcd_num)
      return false;

    /* X is killed only when this stride's range contains it.  */
    if (x.low >= low && x.high <= high && kills_x)
      x.killed = 1;
  }

  /* The later statement depends on the earlier one.  */
  DEPENDENCY dep(dstmt_after, dstmt_before, dkind, dtype, dloop);
  return sd3_record_dependence (dep);
}


void
lentry::insert_stride (const SD3_STRIDE& stride)
{
  SD3_STRIDE_SET& stride_set = pending_stride_table [stride.stmt];
  SD3_STRIDE_SET strides;
  strides.insert(stride);
  merge_stride (stride_set, strides);
}

/* Check STRIDE against every statement's strides in the pending table
   and record the dependences found as LOOP_INDEPENDENT.

   For each statement only the strides in equal_range (stride) are
   examined, and the scan of that range stops at the first stride that
   does not overlap STRIDE.  */
void
lentry::check_loop_independent_dependence (SD3_STRIDE& stride)
{
  typedef std::pair<SD3_STRIDE_SET::iterator, SD3_STRIDE_SET::iterator>
    STRIDE_RANGE;

  STRIDE_MAP::iterator table_it = pending_stride_table.begin();
  while ( table_it != pending_stride_table.end() )
  {
    STRIDE_RANGE candidates = table_it->second.equal_range(stride);

    SD3_STRIDE_SET::iterator cur = candidates.first;
    while ( cur != candidates.second )
    {
      const SD3_STRIDE &earlier = *cur;
      if ( !earlier.overlaps (stride) )
        break;

      earlier.compute_dependence (stride, LOOP_INDEPENDENT, loop_id);
      ++cur;
    }

    ++table_it;
  }
}


/* Check STRIDE against every statement's strides in the history table
   and record the dependences found as LOOP_CARRIED.

   For each statement only the strides in equal_range (stride) are
   examined, and the scan of that range stops at the first stride that
   does not overlap STRIDE.  */
void
lentry::check_loop_carried_dependence (SD3_STRIDE& stride)
{
  typedef std::pair<SD3_STRIDE_SET::iterator, SD3_STRIDE_SET::iterator>
    STRIDE_RANGE;

  STRIDE_MAP::iterator table_it = history_stride_table.begin();
  while ( table_it != history_stride_table.end() )
  {
    STRIDE_RANGE candidates = table_it->second.equal_range(stride);

    SD3_STRIDE_SET::iterator cur = candidates.first;
    while ( cur != candidates.second )
    {
      const SD3_STRIDE &earlier = *cur;
      if ( !earlier.overlaps (stride) )
        break;

      earlier.compute_dependence (stride, LOOP_CARRIED, loop_id);
      ++cur;
    }

    ++table_it;
  }
}



/* Merge every stride of PENDING into HISTORY.

   An empty HISTORY simply becomes a copy of PENDING.  Otherwise each
   pending stride is tried against the history strides in order: first
   the history stride tries to absorb it, then the other way round.  The
   first successful merge replaces that history stride.  A pending stride
   that merges with nothing is inserted unchanged.

   Always returns true, as before: the earlier code computed an
   "all merged" flag but returned true ahead of it.

   Every merge attempt now runs on a fresh copy.  SD3_STRIDE::merge may
   modify its object before reporting failure, and the old code let that
   half-updated state carry over into the next attempt and into the
   stride finally stored.  */
bool 
lentry::merge_stride (SD3_STRIDE_SET &history, SD3_STRIDE_SET &pending)
{
  if (history.empty())
  {
    history = pending;
    return true;
  }

  for (SD3_STRIDE_SET::iterator iter = pending.begin ();
       iter != pending.end (); ++iter)
  {
    const SD3_STRIDE &incoming = *iter;
    bool merged = false;

    for (SD3_STRIDE_SET::iterator hiter = history.begin ();
         hiter != history.end (); ++hiter)
    {
      /* Let the history stride absorb the incoming one ...  */
      SD3_STRIDE grown = *hiter;
      if (grown.merge (incoming))
      {
        history.erase (hiter);
        history.insert (grown);
        merged = true;
        break;
      }

      /* ... or the incoming one absorb the untouched history stride.  */
      SD3_STRIDE reversed = incoming;
      if (reversed.merge (*hiter))
      {
        history.erase (hiter);
        history.insert (reversed);
        merged = true;
        break;
      }
    }

    /* Nothing matched: keep the stride on its own.  */
    if (!merged)
      history.insert (incoming);
  }

  return true;
}




void
lentry::merge_stride_table (void)
{
  std::map <int, SD3_STRIDE_SET>::iterator iter =  pending_stride_table.begin ();
  for (; iter != pending_stride_table.end (); ++iter)
  {
    SD3_STRIDE_SET  &history = history_stride_table [iter->first];
    merge_stride (history, iter->second);
  }
  
  /* Clear pending table */
  pending_stride_table.clear();
}



void
lentry::compute_dependence (void)
{

  /* Check stride table */
  for (STRIDE_MAP::iterator iter = pending_stride_table.begin ();
       iter != pending_stride_table.end (); ++iter) 
  {    
    for ( SD3_STRIDE_SET::iterator siter = iter->second.begin(); siter != iter->second.end(); ++siter)
    {
      SD3_STRIDE &pending_stride = const_cast<SD3_STRIDE &>(*siter); 
      check_loop_carried_dependence (pending_stride);
    }
  }
      
}


void
lentry::merge_pending_table (void)
{
  merge_stride_table ();
}


static void sd3_profile_initialize (void) __attribute__ ((noinline));

extern bool time_start_p;

static void
sd3_profile_initialize (void)
{
  initialized = 1;
  loop_stack_capacity = SIZE - 1;
  loop_stack = (lentry **) malloc (sizeof (lentry*) * loop_stack_capacity);
  loop_stack[0] = new lentry;
  loop_stack[0]->kind = pk_not_profile;
  sp = 0;
  time_start_p = false;

  fp = stdout;

  clktck = sysconf (_SC_CLK_TCK);

//  generate_loop_table ();

//  generate_not_profile_loops ();

//  generate_memory_op_table ();

}





/* Insert *STRIDE into LOOP's pending stride table.  STMT is not used;
   the stride carries its own statement id.  */
void
sd3_stride_insert (lentry *loop, int stmt, const SD3_STRIDE* stride)
{
  const SD3_STRIDE &incoming = *stride;
  loop->insert_stride (incoming);
}



/* Hand LOOP's history strides over to its enclosing loop FATHER.

   For every statement in LOOP's history table, each stride is first
   checked for loop-independent dependences against FATHER's pending
   table; then the statement's whole stride set is merged into FATHER's
   pending entry for that statement.  */
void 
sd3_tranfer_loop(lentry* father, lentry* loop)
{
  STRIDE_MAP &inner_history = loop->history_stride_table;

  STRIDE_MAP::iterator entry = inner_history.begin ();
  while (entry != inner_history.end ())
  {
    SD3_STRIDE_SET &inner_strides = entry->second;

    /* Dependences between the inner strides and what FATHER has seen.  */
    SD3_STRIDE_SET::iterator cur = inner_strides.begin();
    while (cur != inner_strides.end())
    {
      SD3_STRIDE &inner_stride = const_cast<SD3_STRIDE &>(*cur);
      father->check_loop_independent_dependence (inner_stride);
      ++cur;
    }

    /* The inner strides become pending strides of the outer loop.  */
    father->merge_stride (father->pending_stride_table[entry->first],
                          inner_strides);
    ++entry;
  }
}


/* Release everything LOOP has accumulated: the point history hash table
   and both stride tables.  Also tracks, in max_hash_size, the largest
   number of point entries released by one call.

   The hash table used to be freed through a local copy of its head
   pointer, which left loop->history_point_table pointing at freed
   memory.  The (now NULL) head is written back after the sweep.  */
static void
sd3_clear_loop (lentry * loop)
{
  struct pair_hash_entry *head = loop->history_point_table;
  int released = 0;

  while (head)
  {
    struct pair_hash_entry *victim = head;
    victim->memops.clear();
    HASH_DEL (head, victim);   /* unlinks victim and advances head */
    delete victim;
    ++released;
  }

  /* Do not leave the loop holding a dangling table pointer.  */
  loop->history_point_table = head;

  if ( max_hash_size < released )
    max_hash_size = released;

  /* Clearing the maps destroys the stride sets they contain.  */
  loop->history_stride_table.clear();
  loop->pending_stride_table.clear();
}



/* Instrumentation hook called for a memory access.

   addr - address accessed
   size - size of the access
   flag - 0 when the address is read, 1 when it is written
   id   - memory operation id or slice id

   Does nothing while the innermost loop is marked pk_not_profile.
   Otherwise the access is wrapped in a single-point stride, checked for
   loop-independent dependences against the innermost loop's pending
   strides, and then added to that pending table.  */
extern "C" void
__sd3_print_address (PTR addr, int size, char flag, int id)
{
  if (!initialized)
    sd3_profile_initialize ();

  assert (Top ());

  if (Top ()->kind != pk_not_profile)
  {
    SD3_STRIDE access;
    access.add_addr (addr, size, id, flag);

    Top()->check_loop_independent_dependence (access);

    sd3_stride_insert (Top (), id, &access);
  }
}



/* entry-loop enter/exit; loc - loop global id; s -  */

extern "C" void
__sd3_print_edge (int loop_id, int entry)
{
  if (!initialized)
    sd3_profile_initialize ();

  if (entry) {
    if (loop_stack_capacity = Length () + 2) {
      loop_stack_capacity = (loop_stack_capacity << 1) - 1;
      loop_stack =
        (lentry **) realloc (loop_stack,
                             sizeof (lentry*) * loop_stack_capacity);
    }

    if (Top ()->loop_id != loop_id) {
      lentry *loop = new lentry;
      loop->loop_id = loop_id;
      loop->iter_no = 1;
      if (Empty ()) {
        if (not_profile_loops.find (loop_id) !=
            not_profile_loops.end ())
          loop->kind = pk_not_profile;

        stmt_count = 0;
        ave_dep_num = 0;
 //       fprintf (fp,
 //                "\n\n\n ======= Entering LOOP NEST %d ===\n",
 //                loop_id);
      }
      else {
        if (Top ()->kind == pk_not_profile)
          loop->kind = pk_not_profile;
        else if (not_profile_loops.find (loop_id) !=
                 not_profile_loops.end ()) {
          if (Top ()->kind == pk_not_profile)
            loop->kind = pk_not_profile;
          // profile but not need checking dependence
          else
            loop->kind = pk_profile;
        }
      }

      Push (loop);
    }

    // New iteration in same loop_id
    else {

      Top ()->compute_dependence ();
      Top ()->merge_pending_table ();
      Top ()->iter_no++;
    }
  }
  else {

    lentry *loop = Top ();

    // when loop->loop_id != loop_id ???
    if (loop->loop_id == loop_id) {
      // loop_infos[loop_id].itercount += loop->iter_no;

      /* check for the last iteration */
      Top ()->compute_dependence ();
      Top ()->merge_pending_table ();
      
      /* transfer hash table */
      if (Length () > 1)
        sd3_tranfer_loop (Elem (Length () - 1), loop);

      sd3_clear_loop (loop);
      Pop ();


      if (0 && Empty ()) {

        fprintf (fp,
                 "\n ========================Exiting of LOOP NEST %d==========================\n",
                 loop_id);
        fprintf (fp, "max_dep_num = %d\n", max_dep_num);
        fprintf (fp, "ave_dep_num = %d\n", ave_dep_num);
        fprintf (fp, "addr count = %lld\n", addr_count);
        fprintf (fp, "stmt count = %lld\n", stmt_count);
        
        
        fprintf (fp, "profile_count=%lld\n", profile_count);
        fprintf (fp, "time for print range is : %7.4f sec\n",
                 Time_Profile[0] / (double) clktck);
        fprintf (fp, "time for insert_and_merge is : %7.4f sec\n",
                 Time_Profile[1] / (double) clktck);
        fprintf (fp, "time for transfer loop is : %7.4f sec\n",
                 Time_Profile[2] / (double) clktck);
        fprintf (fp, "time for check dependence is : %7.4f sec\n",
                 Time_Profile[3] / (double) clktck);
      }
    }
  }

  return;
}


/* Instrumentation hook called at program exit: prints the recorded
   dependences, reports the function name and the largest point-table
   size on stderr, then terminates the process with status 0.

   The first message used to pass __func__ to a format string that had
   no conversion for it (and misspelt "function"); it now prints the
   name.  */
extern "C" void
__sd3_print_exit () 
{
  print_dependencies ();

  fprintf (stderr, "exit function %s\n", __func__);

  fprintf (stderr, "max_hash_size = %d\n", max_hash_size);

  exit(0);
}

