package jazz.circuit.esterel;

//////////////////////////////////////////////////////////////////////////////
//
//                  Compilation of Pure Esterel programs
//
//////////////////////////////////////////////////////////////////////////////

import jazz.circuit.*;
import jazz.circuit.expr.*;
import jazz.circuit.expr.BoolExpr.*;
import jazz.unsafe.Hashtable;

require jazz.util;

//////////////////////////////////////////////////////////////////////////////
//
//                      Compilation environment
//
//////////////////////////////////////////////////////////////////////////////

// Compilation environment threaded through every compile@...Stmt rule.
// The wires c, a and i are read by a statement; the wires s, c0, c1 and C[k]
// are outputs that statements OR their contributions into (setOr).
final class Env {
  c:  BoolExpr;              // Start wire: the statement is started in this instant
  s:  BoolExpr;              // Selection wire: statements OR their own selection into it
  a:  BoolExpr;              // Activation wire: ANDed with a selection to resume paused code
  i:  BoolExpr;              // Inhibition wire: masks (& ~i) the inputs of halt registers
  c0: BoolExpr;              // Termination wire (completion code 0)
  c1: BoolExpr;              // Halting wire (completion code 1)
  n:  int;                   // Number of exit wires (one per enclosing trap)
  C:  BoolExpr[n];           // Exit wires, innermost enclosing trap first
  S:  fun(String): BoolExpr; // Maps a signal name to its wire
  T:  fun(String): int;      // Maps a trap name to its index in C
}

//////////////////////////////////////////////////////////////////////////////
//
//                     Compilation of a module
//
//////////////////////////////////////////////////////////////////////////////

// Compiles a whole module into a Net circuit.
//  1. Binds every input signal name to a fresh Net input variable and every
//     output signal name to a fresh local variable (via sigtovar).
//  2. Compiles the body statement in the initial environment `env`.
//  3. Instantiates the resulting equations, exposing the output signal wires,
//     with `invalues` supplying the values of the Net input variables.
// `innets` are the module's input nets; they become inputs 3.. of the circuit.
compile@Module()(innets) = ( sigtovar.assoc(in)(insig, invars)
                           ; sigtovar.assoc(out)(outsig, outvars)
                           ; stmt.compile(env)
                           ; instanciate([i -> S(outsig[i])])
                                        (invalues, Net.reg, Net.constant))
{
  // Input signals
  in = inputs.length();
  insig = [i -> inputs.nth(i)];

  // Output signals
  out = outputs.length();
  outsig = [i -> outputs.nth(i)];

  // Boot: start wire of the whole program (circuit input 0)
  boot = newInput(0);

  // Zero: constant-false wire (circuit input 1)
  zero = newInput(1);

  // One: constant-true wire (circuit input 2)
  one = newInput(2);

  // Dummy local wire for the initial environment; the top-level termination
  // and halting wires are collected here and not used further
  dummy = newLocal();

  // Association of input signals names to Net expression input variables
  // (inputs 0..2 are reserved for boot/zero/one above)
  invars = [i -> newInput(i + 3)];

  // Association of output signal names to local Net expression variables
  outvars = [i -> newLocal()];

  // Hashtable mapping signal names to Net variables
  sigtovar = Hashtable.create();
  // NOTE(review): relies on the error(...) default being evaluated only when
  // `name` is missing (lazy argument) -- confirm against Hashtable.get.
  fun S(name) =
  (BoolExpr) sigtovar.get(name, error("undefined signal: %s", name));

  // Empty trap environment: any exit at top level is an error
  fun T(sig) = error("trap signal \"%s\" undefined", sig);

  // Input values of the circuit
  // presumably true only in the first cycle, assuming Net.reg starts at 0 -- confirm
  invalues[0] = ~Net.reg(Net.constant(#(1)));
  invalues[1] = Net.constant(#(0));
  invalues[2] = Net.constant(#(1));
  invalues[3..] = innets;

  // Initial environment: started by boot, always active and selected,
  // never inhibited, no enclosing traps
  env = new Env(c = boot, s = one,
                a = one, i = zero,
                c0 = dummy, c1 = dummy,
                n = 0, C = [],
                S = S, T = T);
}

//////////////////////////////////////////////////////////////////////////////
//
//                    Compilation rules for statements
//
//////////////////////////////////////////////////////////////////////////////

// Nothing: terminates in the instant it is started (start wire ORed into c0).
compile@NothingStmt(env) = env.c0.setOr(env.c);

// Halt: pauses forever.
// The halt owns one register, s', which is its own selection bit. The halt is
// live in an instant when it is started (env.c), or when it was selected in
// the previous instant and its context is active (env.a & s'). A live halt
// reports halting on c1 and, unless inhibited, stays selected.
// It must test its OWN register s', not env.s: env.s is the OR of every
// selection in the enclosing scope, so testing it would let unrelated selected
// statements resume this halt.
compile@HaltStmt(env) = ( env.c1.setOr(x)
                        ; env.s.setOr(s')
                        ; s'.setEq(reg(x & ~env.i)) )
{
  s' = newLocal();
  x = env.c | (env.a & s');
}

// Sequence  p1; p2.
// One fresh wire links the two halves: it is p1's termination wire and p2's
// start wire, so p2 starts in the very instant p1 terminates.
compile@SequenceStmt(env) = ( stmt1.compile(firstEnv)
                            ; stmt2.compile(secondEnv) )
{
  handoff = newLocal();
  firstEnv = env.clone(c0 = handoff);
  secondEnv = env.clone(c = handoff);
}

// Emit S: when started, drives the signal's wire and terminates instantly.
compile@EmitStmt(env) = ( sigWire.setOr(env.c)
                        ; env.c0.setOr(env.c) )
{
  sigWire = env.S(sig);
}
  
// Loop  loop p end.
// A single wire serves as both the body's start wire and its termination
// wire. It is defined from the loop's own start wire, and the body ORs its
// termination into it, so the body restarts in the instant it terminates.
// The loop itself never reports termination.
compile@LoopStmt(env) = ( restart.setEq(env.c)
                        ; stmt.compile(bodyEnv))
{
  restart = newLocal();
  bodyEnv = env.clone(c = restart, c0 = restart);
}
  
// Present S then p1 else p2.
// Splits the start wire on the current value of signal S: p1 is started when
// S is present and p2 when it is absent. Both branches share every other wire.
compile@PresentStmt(env) = ( thenGo.setEq(env.c & sigWire)
                           ; elseGo.setEq(env.c & ~sigWire)
                           ; stmt1.compile(thenEnv)
                           ; stmt2.compile(elseEnv) )
{
  sigWire = env.S(sig);

  thenGo = newLocal();
  elseGo = newLocal();

  thenEnv = env.clone(c = thenGo);
  elseEnv = env.clone(c = elseGo);
}
  
// Do p watching S (non-immediate strong abort).
// s' collects the selection of the body. In any later instant where the body
// is selected and the context is active:
//   - if S is present, the body is not resumed and the watch terminates;
//   - otherwise the body is resumed through its activation wire a'.
// The test must use the watch's own selection s', not env.s: env.s is the OR
// of every selection in the enclosing scope, so using it would abort or
// resume this watch while some unrelated statement is selected.
// In the starting instant s' is still low, so S is not tested then.
compile@WatchStmt(env) = ( env.s.setOr(s')
                         ; a'.setEq(env.a & s' & ~S)
                         ; env.c0.setOr(env.a & s' & S)
                         ; stmt.compile(env') )
{
  S = env.S(sig);

  s' = newLocal();   // selection of the watched body
  a' = newLocal();   // body activation: active, selected and not aborted

  env' = env.clone(s = s', a = a');
}
  
// Trap T in p end.
// Pushes a new exit wire in front of the enclosing ones: exiting T fires the
// trap's own termination wire (env.c0). The trap environment T maps this
// trap's name to index 0 and shifts every outer trap's index by one.
// The lookup parameter must not be named `sig`: that would shadow the trap's
// own name, make the comparison always true, and send every exit to this trap.
compile@TrapStmt(env) = stmt.compile(env')
{
  env' = env.clone(n = env.n + 1,
                   C = [env.c0, i -> env.C[i-1]],
                   T = T);
  fun T(name) = (name == sig) ? 0 : 1 + env.T(name);
}

// Exit T: when started, fires the exit wire of trap T. env.T(sig) gives the
// trap's index in env.C (0 is intended to be the innermost enclosing trap).
// The statement neither terminates nor halts, so c0 and c1 are not driven.
compile@ExitStmt(env) = env.C[env.T(sig)].setOr(env.c);
  
// Local signal declaration: compiles the body in the unchanged environment.
// NOTE(review): the declared signal is not bound in env.S here, so an emit or
// present of it in the body goes through the enclosing env.S lookup (which at
// top level raises "undefined signal") -- confirm local signals are intended
// to be supported.
compile@SignalStmt(env) = stmt.compile(env);

// Parallel composition  p1 || p2.
// Both branches are started by the same wire and OR their completion codes
// into shared local wires c'[k]:
//   c'[0] = terminate, c'[1] = halt, c'[k >= 2] = exit to enclosing trap k-2.
// The synchronizer forwards only the highest code present in an instant:
//   p[k]  = some code strictly greater than k is present
//   c[k] |= c'[k] & ~p[k]
// Any trap exit (code >= 2, i.e. p[1]) inhibits the registers of both
// branches, so a branch paused next to an exiting branch is not left selected.
compile@ParallelStmt(env) = ( env.s.setOr(s')
                            ; i'.setEq(n <= 2 ? env.i : env.i | p[1])
                            ; eval(n)([k -> c[k].setOr(cdef[k])])
                            ; eval(n - 1)([k -> p[k].setEq(pdef[k])])
                            ; stmt1.compile(env')
                            ; stmt2.compile(env') )
{
  // Locals
  s' = newLocal();              // selection of both branches
  i' = newLocal();              // inhibition seen by both branches
  c' = [i -> newLocal()];       // completion codes reported by the branches
  p = [i -> newLocal()];        // p[k]: a code greater than k is present

  // Completion wires of the context: c0, c1, then one exit wire per trap.
  // p[1] exists only when n > 2, i.e. when at least one trap encloses us.
  n = env.n + 2;
  c = [env.c0, env.c1, i -> env.C[i - 2]];

  // Forward code k only when no higher code is present
  cdef[n - 1] = c'[n - 1];
  for (k < n - 1) {
    cdef[k] = c'[k] & ~p[k];
  }

  // p[k] = c'[k+1] | c'[k+2] | ... | c'[n-1], built from the top down.
  // This must be an OR chain: with `& ~p[k+1]` a lower code would leak
  // through whenever the next code is absent but a still-higher one is present.
  pdef[n - 2] = c'[n - 1];
  for (k < n - 2) {
    pdef[k] = c'[k + 1] | p[k + 1];
  }

  // New environment
  env' = env.clone(s = s', i = i', c0 = c'[0], c1 = c'[1], C = c'[2..]);

  // Sequential evaluation of x[0], ..., x[n-1]
  fun eval(n)(x) = eval(0)
  {
    fun eval(k) = (k == n) ? () : (x[k]; eval(k+1));
  }
}