package jazz.util;

///////////////////////////////////////////////////////////////////////////////
//
//                          Generic linked lists
//
///////////////////////////////////////////////////////////////////////////////

public abstract class List<+T> {
  // The empty list: a single static Nil instance. The {U} binder appears to
  // make it usable as a Nil<U> for any element type U.
  public static nil: Nil<U> {U} = new Nil();

  // List constructor: returns a new list whose first element is head and
  // whose tail is this list (this list is shared, not copied).
  public cons(head: T): Cons<T>;

  // List reverse: the same elements in the opposite order.
  public reverse(): alike<T>;
  
  // List concatenation: the elements of this list followed by those of l.
  // l itself becomes the tail of the result; it is not copied.
  public concat(l: alike<T>): alike<T>;

  // List mapping (not an iterator): the list of f(x) for each element x,
  // in the original order.
  public map<U>(f: fun(T): U): alike<U>;

  // List folding, from the right, with a distinct case for the last element:
  //   ()          -> f0
  //   (x)         -> f1(x)
  //   (x1 ... xn) -> f2(x1, f2(x2, ... f2(xn-1, f1(xn))))
  // so f0 is only used when the whole list is empty.
  public fold<U>(f0: U, f1: fun(T): U, f2: fun(T, U): U): U;

  // Length of a list: the number of elements.
  public length(): int;
  
  // Element access: the element at zero-based index n. An index past the
  // end of the list, or a negative one, ends in error("nth(nil)").
  public nth(n: int): T;

  // List formatting parameters, read by toString. By default, lists are
  // displayed in a Lisp-like fashion, i.e., "(1 2 3 4)".
  // NOTE(review): declared `dynamic`, presumably so callers can rebind them
  // for the extent of a computation -- confirm against the language docs.
  public static dynamic lparen: String = "(";
  public static dynamic separator: String = " ";
  public static dynamic rparen: String = ")";
}

///////////////////////////////////////////////////////////////////////////////
//
//                               Implementation
//
///////////////////////////////////////////////////////////////////////////////

// Textual form of a list: lparen, then the elements separated by separator,
// then rparen -- "(1 2 3 4)" with the default settings, "()" for the empty
// list (fold yields "" there).
// NOTE(review): the last element is rendered with toString() while every
// other element goes through the "%a" directive; confirm the two agree.
toString@List() = format("%s%s%s", lparen, fold("", lastItem, joinItem), rparen)
{
  fun lastItem(x) = x.toString();
  fun joinItem(x, rest) = format("%a%s%s", x, separator, rest);
}

// Prepend: a fresh Cons cell whose tail is this list (shared, not copied).
cons@List(head) = new Cons(tail = this, head = head);

// map: the Nil case yields the shared empty list; the Cons case maps the
// tail and then prepends f(head), so element order is preserved.
// NOTE(review): if arguments are evaluated after the receiver (as in Java),
// f is applied to the last element first -- this only matters when f has
// side effects.
map@Nil(f) = nil;
map@Cons(f) = tail.map(f).cons(f(head));

// concat: copies the cells of this list and ends them with l itself, which
// is shared rather than copied. The recursive call is not in tail position
// (its result is then cons'ed), so recursion depth grows with this list.
concat@Nil(l) = l;
concat@Cons(l) = tail.concat(l).cons(head);

// reverse: the empty list is its own reverse. A non-empty list is reversed
// by moving cells one at a time onto an accumulator seeded with the first
// element; revAppend(rest, acc) returns the reverse of rest followed by acc.
reverse@Nil<T>() = this;
reverse@Cons<T>() = revAppend(tail, nil.cons(head))
{
  fun revAppend(rest: List<T>, acc: Cons<T>): Cons<T>;
  revAppend(rest@Nil, acc) = acc;
  revAppend(rest@Cons, acc) = revAppend(rest.tail, acc.cons(rest.head));
}

// Right fold with a separate case for the last element:
//   ()          -> f0
//   (x)         -> f1(x)
//   (x1 ... xn) -> f2(x1, f2(x2, ... f2(xn-1, f1(xn))))
// f0 is therefore only reached when the whole list is empty.
fold<U>@List<T>(f0, f1, f2) = walk(this)
{
  fun walk(xs: List<T>): U;
  walk(xs@Nil) = f0;
  walk(xs@Cons) = (xs.tail instanceof Nil
                   ? f1(xs.head)
                   : f2(xs.head, walk(xs.tail)));
}

// Length of a list: the number of elements.
//
// The Cons case counts with an accumulator, in the same style as
// reverse@Cons, so the recursive call is in tail position. Stack depth then
// need not grow with the list length, instead of one pending "1 + ..." frame
// per element.
// NOTE(review): the stack benefit assumes the implementation eliminates tail
// calls -- confirm for this compiler.
length@Nil() = 0;
length@Cons<T>() = count(tail, 1)
{
  fun count(l: List<T>, n: int): int;
  count(l@Nil, n) = n;
  count(l@Cons, n) = count(l.tail, n + 1);
}

// Element access, zero-based: each step drops one cell and decrements n
// until n reaches 0. Running off the end of the list -- n >= length, or any
// negative n, which never reaches 0 -- ends in error("nth(nil)").
nth@Nil(n) = error("nth(nil)");
nth@Cons(n) = (n != 0 ? tail.nth(n - 1) : head);