package jazz.circuit.bdd;

///////////////////////////////////////////////////////////////////////////////
//
//                    Binary Decision Diagrams (BDDs)
//
// We use a class Bdd to wrap up the class _Bdd because, although it is true
// that "_Bdd abstracts BooleanAlgebra", it is not the case that "_Bdd
// implements BooleanAlgebra": the BooleanAlgebra interface requires all its
// operators (e.g., "&", "|", "^") to be typed as follows
//
//                <T: BooleanAlgebra>(x: T, y: T): T
//
// which is not the case for the _Bdd implementation below. This wrapper
// class will no longer be necessary once interfaces can be abstracted. Bdds
// implement equality and the BooleanAlgebra interface.
//
///////////////////////////////////////////////////////////////////////////////

import jazz.circuit.*;
import jazz.unsafe.Hashtable;
import jazz.unsafe.Counter;

public final class Bdd implements BooleanAlgebra {
  // Creates a new set of ordered variables: one variable per entry of
  // "names". Ids are allocated in array order, so vars[0], vars[1], ...
  // appear in that order in the variable ordering used by the operators.
  public static newVars(n: int)(names: String[n]): Bdd[n];
  
  // Creation of a bdd from a function computing the truth table. This method
  // evaluates "f" two to the power of "n" times. Element i of the boolean
  // array passed to "f" is the value assigned to vars[i].
  // NOTE(review): "vars" presumably has to be in increasing variable order
  // (e.g., as returned by newVars), since nodes are built directly from it
  // without reordering — confirm.
  public static fromFunction(n: int)
                            (vars: Bdd[n])
                            (f: fun(boolean[]): boolean): Bdd;

  // Returns a function computing the m bdds (with sharing) in a given boolean
  // algebra (e.g., two periodics, two nets, etc.), given the number of
  // arguments "n" of the function, an array giving the "n" variables over
  // which the bdd is computed (these variables are, in this order, the
  // arguments of the result function), as well as the proper values of "zero"
  // and "one" for the target boolean algebra. Sharing applies within one
  // call of the returned function: each call starts from a fresh table.
  public static toFunction<T: BooleanAlgebra>(n: int)
                                             (vars: Bdd[n])
                                             (m: int)
                                             (bdds: Bdd[m])
                                             (zero: T, one: T) : fun(T[]): T[];

  // Bdd constants (false and true)
  public static zero: Bdd;
  public static one: Bdd;

  // Private definition of the bdd: the wrapped internal node that all the
  // Bdd operators delegate to
  bdd: _Bdd;
}


///////////////////////////////////////////////////////////////////////////////
//
//                            Implementation
//
///////////////////////////////////////////////////////////////////////////////

// Internal representation of a bdd. Every internal bdd (constant, variable
// or node) carries an id taken from the global counter; equality of _Bdd
// values compares these ids.
abstract class _Bdd {
  id: int;
  // Textual definition of the bdd (used to display a Bdd)
  definition(): String;
}

// Constant zero (false); the shared instance is "_zero"
class _Zero extends _Bdd {
}

// Constant one (true); the shared instance is "_one"
class _One extends _Bdd {
}

// Variables. They are compared (and ordered) by id, which gives the
// variable ordering of the bdds.
class _Var extends _Bdd implements Comparable {
  // Display name of the variable
  name: String;
}

// Internal bdd nodes, representing "v ? tt : ff". Nodes are built by cond(),
// which looks them up in the global table so that identical nodes are shared.
class _Node extends _Bdd {
  // The variable used to differentiate the two branches
  v: _Var;

  // The two branches: "tt" when v is true, "ff" when v is false
  tt: _Bdd;
  ff: _Bdd;
}

// Class constants: the public Bdd.zero and Bdd.one wrap the shared internal
// constants _zero and _one.
Bdd.zero = new Bdd(bdd = _zero);
Bdd.one = new Bdd(bdd = _one);

// Global variables and constants
//
//   table   - nodes already built by cond(), used to share identical nodes
//   counter - source of the ids of all internal bdds (constants, variables
//             and nodes); _Bdd equality compares these ids
//   _zero, _one - the shared internal constants
var table = Hashtable.create();
var counter = Counter.create(0);
var _zero = new _Zero(id = counter.next());
var _one = new _One(id = counter.next());

// Two wrapped bdds are equal exactly when their underlying internal nodes
// are equal.
Builtin.(==)(lhs@Bdd, rhs@Bdd) = (lhs.bdd == rhs.bdd);

// A bdd is displayed as the definition of its underlying internal node.
toString@Bdd() = this.bdd.definition();

// Creates new variables, one per entry of "names".
//
// Element i first forces element i-1 (the "vars[i-1].bdd;" sequence) and
// only then calls newVar. The variable ids are therefore allocated in array
// order, so vars[0], vars[1], ... follow each other in the variable ordering.
// NOTE(review): the array literal always has an explicit first element
// (names[0]); n == 0 is presumably not supported — confirm.
Bdd.newVars(n)(names) = vars
{
  vars = [new Bdd(bdd = newVar(names[0])),
          i -> (vars[i-1].bdd; new Bdd(bdd = newVar(names[i])))];
}

// Creation of a bdd from a truth table represented by a function.
//
// The bdd is built by Shannon expansion over vars[0], ..., vars[n-1]: each
// level tests one variable and recurses once with it set to true and once
// with it set to false. "fn" is called once per complete assignment.
Bdd.fromFunction(n)(vars)(fn) = new Bdd(bdd = create(n)([]))
{
  // Leaf node: evaluate the truth table on a complete assignment
  fun leaf(args) = fn(args) ? _one : _zero;

  // Update the argument array with the value of one more variable: same
  // elements as "args", except that element k is "v"
  fun update(k)(args, v) = [i -> i == k ? v : args[i]];

  // Create the bdd. "k" is the number of variables still unassigned, so the
  // variable tested at this level is vars[n - k] and its value is stored at
  // position n - k of the argument array.
  // NOTE(review): nodes are built with the internal cond() directly, which
  // does not reorder variables; "vars" presumably must be in increasing id
  // order (as produced by newVars) — confirm.
  fun create(k)(args) = (k == 0 ? leaf(args) :
                         cond((_Var) vars[n - k].bdd,
                              create(k-1)(update(n - k)(args, true)),
                              create(k-1)(update(n - k)(args, false))));
}

// Boolean algebra operators. Each one combines the underlying internal
// nodes and wraps the result in a new Bdd.
Builtin.(&)(a@Bdd, b@Bdd) = new Bdd(bdd = and(a.bdd, b.bdd));
Builtin.(|)(a@Bdd, b@Bdd) = new Bdd(bdd = or(a.bdd, b.bdd));
Builtin.(~)(a@Bdd) = new Bdd(bdd = not(a.bdd));
// Exclusive or: (a & ~b) | (~a & b)
Builtin.(^)(a@Bdd, b@Bdd) = new Bdd(bdd = or(and(a.bdd, not(b.bdd)),
                                             and(not(a.bdd), b.bdd)));
// Multiplexer: (c & t) | (~c & e)
Builtin.cond(c@Bdd, t@Bdd, e@Bdd) = new Bdd(bdd = or(and(c.bdd, t.bdd),
                                                     and(not(c.bdd), e.bdd)));

// Return a function computing the bdds over some boolean algebra.
//
// Each call of the returned function "fn":
//   1. binds the variable ids to the given arguments in a fresh table "defs";
//   2. evaluates every bdd of "bdds", where "z" and "o" stand for zero and one.
// Internal nodes are recorded in "defs" by id, so a node shared between the
// bdds is translated once per call.
Bdd.toFunction<T>(n)(vars)(m)(bdds)(z, o) = fn
{
  // The function computing the bdd
  fun fn(args) = (defs.assoc(n)(ids, args) ; definitions) {
    // Hash table to track nodes and variables with known definitions
    defs = Hashtable.create();

    // Array of variable ids
    ids = [i -> vars[i].bdd.id];

    // Definitions of the bdds
    definitions = [i -> definition(bdds[i].bdd)];
    
    // Definition of a bdd in the target algebra. A variable that was not
    // bound by "assoc" (i.e., not in "vars") raises "undefined variable".
    // In the _Node case, "n" is the node and shadows the arity "n".
    fun definition(b: _Bdd): T;
    definition(b@_Zero) = z;
    definition(b@_One) = o;
    definition(v@_Var) = defs.get(v.id, error("undefined variable"));
    definition(n@_Node) = defs.get(n.id, Builtin.cond(definition(n.v),
                                                      definition(n.tt),
                                                      definition(n.ff)));
  }
}

// Equality of internal bdds: two of them are the same exactly when they
// carry the same id.
Builtin.(==)(a@_Bdd, b@_Bdd) = (a.id == b.id);

// Variable ordering: variables are ordered and compared by id, i.e. by the
// order in which newVar created them.
Builtin.(<=)(lhs@_Var, rhs@_Var) = (lhs.id <= rhs.id);
Builtin.(==)(lhs@_Var, rhs@_Var) = (lhs.id == rhs.id);

// Short display of an internal bdd: "N" followed by its id.
toString@_Bdd() = format("N%d", this.id);

// Textual definition of an internal bdd.
//
// Constants are displayed as "false"/"true" and variables by their name. A
// node is displayed as its renumbered name followed by a block of local
// definitions "nI = var ? then : else;", one per node reachable from it.
// Nodes are renumbered "n<k>" from a counter local to this call.
definition@_Bdd() = def(this)
{
  // Name after renumbering
  fun name(n: _Bdd): String;

  // Display of a bdd node
  fun str(b: _Bdd): String;
  
  // Local variables used to compute the bdd
  fun locals(b: _Bdd): String;

  // Definition of a bdd
  fun def(b: _Bdd): String;
  
  // Hash table to track bdds that have already been displayed, so that each
  // shared node gets a single local definition
  var displayed = Hashtable.create();

  // Counter used to renumber the nodes (shadows the global "counter", so the
  // renumbering does not consume global ids)
  var counter = Counter.create(0);

  // Renumbering: node id -> display number, allocated on first use
  var number = Hashtable.create();

  name(n@_Bdd) = format("n%d", number.get(n.id, counter.next()));
  
  def(b@_Bdd) = str(b);
  def(n@_Node) = format("%s {%s\n}", name(n), locals(n));

  str(c@_Zero) = "false";
  str(c@_One) = "true";
  str(v@_Var) = format("%s", v.name);
  str(n@_Node) = name(n);

  // Constants and variables need no local definition
  locals(b@_Bdd) = "";
  // A node emits its own definition, then the definitions of its "true" and
  // "false" branches; nothing if it has already been displayed
  locals(n@_Node) = disp ? (nn; ntt; ltt; nff; lff
                           ; format("\n %s = %s ? %s : %s;%s%s",
                                    nn, n.v.name, ntt,
                                    nff, ltt, lff)) : ""
  {
    // The value "disp" is true iff the key "n.id" has been put in the table.
    disp = displayed.check(n.id, 0);

    // Use these definitions to force the evaluation order (and the numbering)
    nn = name(n);
    ntt = str(n.tt);
    nff = str(n.ff);
    ltt = locals(n.tt);
    lff = locals(n.ff);
  }

}

// Creates a new variable (make sure the ids are assigned in sequence).
// The leading "id;" of the sequence forces "id", and therefore the call to
// the global counter, to be evaluated before the _Var is built.
fun newVar(name) = (id; new _Var(name = name, id = id))
{
  id = counter.next();
}

// Creates or retrieves the appropriate node for "v ? tt : ff":
//  - if both branches are equal, the test is useless and the branch itself
//    is returned;
//  - "v ? one : zero" is the variable v itself;
//  - otherwise, the node is looked up in the global "table", so that a given
//    (v, tt, ff) triple is always represented by the same _Node.
//
// The table is indexed in three levels (v.id, then tt.id, then ff.id). The
// three ids are not packed into one int key: a key such as
// "(v.id << 32) | (tt.id << 16) | ff.id" is not injective. If ints are
// 32 bits, the shift by 32 is a no-op and v.id overlaps ff.id. Ids of 2^16
// or more also overlap the neighbouring field. Either way, two distinct
// triples could share a key and the wrong node would be returned.
fun cond(v:_Var, tt: _Bdd, ff: _Bdd) =
  (ff == tt ? ff : (ff == _zero && tt == _one) ? v : nodes.get(ff.id, n))
{
  // Sub-table of the nodes testing v, then of the nodes testing v whose
  // "true" branch is tt (created on first use, like the nodes themselves)
  byVar = table.get(v.id, Hashtable.create());
  nodes = byVar.get(tt.id, Hashtable.create());
  
  // New node, recorded if no node with the same (v, tt, ff) exists yet
  n = new _Node(id = counter.next(), v = v, tt = tt, ff = ff);
}

// And.
//
// Defined by cases on the kinds of both arguments:
//  - a constant on the left is handled directly (0 & y = 0, 1 & y = y);
//  - the generic (_Var, _Bdd) and (_Node, _Bdd) cases swap their arguments.
//    This handles a constant on the right and the (_Var, _Node) pair. It
//    relies on dispatch picking the more specific (_Var, _Var),
//    (_Node, _Var) and (_Node, _Node) definitions when they apply;
//  - the remaining cases test the smallest variable first, which keeps the
//    variable ordering of the result.
// NOTE(review): recursive results are not cached, so sub-nodes shared
// inside the operands may be combined several times.
fun and(x: _Bdd, y: _Bdd): _Bdd;
and(x@_Zero, y@_Bdd) = x;
and(x@_One, y@_Bdd) = y;
and(x@_Var, y@_Bdd) = and(y, x);
// Two variables: x & x = x, otherwise "smaller ? larger : 0"
and(x@_Var, y@_Var) = (x == y ? x :
                       x < y ? cond(x, y, _zero) :
                       cond(y, x, _zero));
and(x@_Node, y@_Bdd) = and(y, x);
// Node and variable: same variable -> keep the "true" branch only; node
// variable first -> recurse into both branches; otherwise test y first
and(x@_Node, y@_Var) =
  (x.v == y ? cond(x.v, x.tt, _zero) :
   x.v < y ? cond(x.v, and(x.tt, y), and(x.ff, y)) :
   cond(y, x, _zero));
// Two nodes: combine branch by branch on the same variable, otherwise
// expand the node whose variable comes first
and(x@_Node, y@_Node) =
  (x.v == y.v ?
   cond(x.v, and(x.tt, y.tt), and(x.ff, y.ff)) :
   x.v < y.v ?
   cond(x.v, and(x.tt, y), and(x.ff, y)) :
   cond(y.v, and(x, y.tt), and(x, y.ff)));

// Or, obtained from and/not by De Morgan's law: a | b = ~(~a & ~b)
fun or(x: _Bdd, y: _Bdd): _Bdd;
or(a@_Bdd, b@_Bdd) = not(and(not(a), not(b)));

// Not: the constants are swapped, a variable v becomes "v ? 0 : 1", and a
// node is rebuilt on the same variable with both branches negated.
fun not(x: _Bdd): _Bdd;
not(c@_Zero) = _one;
not(c@_One) = _zero;
not(v@_Var) = cond(v, _zero, _one);
not(node@_Node) = cond(node.v, not(node.tt), not(node.ff));