open StdLabels

(** Kind of a window function, parameterised by the type ['a] of its
    argument. Every kind except [Counter] carries exactly one argument. The
    comments below give the textual name associated with each kind by
    [name_of_window] / [window_of_name]. *)
type 'a window =
  | Min of 'a  (* named "min" *)
  | Max of 'a  (* named "max" *)
  | Counter  (* named "counter"; the only kind without an argument *)
  | Previous of 'a  (* named "previous" *)
  | Sum of 'a  (* named "sum" *)
[@@deriving show, eq, ord]

(** Expression tree, parameterised by the type ['a] of the values stored in
    [Path] nodes. [fold_values] and [map] below respectively visit and
    transform those values. *)
type 'a t =
  | Empty
  | Expr of 'a t  (* wraps a single sub-expression *)
  | Literal of string
  | Integer of string  (* the integer is kept in its textual form *)
  | Path of 'a  (* the only constructor that directly holds an ['a] *)
  | Concat of 'a t list
  | Function of string * 'a t list  (* function name and its arguments *)
  | Nvl of 'a t list
  | Join of string * 'a t list  (* separator and the expressions to join *)
  | Window of ('a t window * 'a t list * 'a t list)
      (* window kind followed by two expression lists.
         NOTE(review): the two lists look like grouping and ordering
         expressions — confirm against the callers. *)
  | BOperator of binary_operator * 'a t * 'a t
      (* operator with its left and right operands *)
  | GEquality of binary_operator * 'a t * 'a t list
      (* operator between one expression and a list of expressions.
         NOTE(review): presumably an "in" / "not in" style test — confirm. *)
  | Function' of funct * 'a t list  (* call to one of the [funct] built-ins *)
[@@deriving show, eq, ord]

(** Binary operators. Their textual form is given by [name_of_operator]. *)
and binary_operator =
  | Equal
  | Different
  | Add
  | Minus
  | Division
  | LT
  | GT
  | And
  | Or
[@@deriving ord]

(** Built-in functions carried by [Function']. Their textual name is given by
    [name_of_function]. *)
and funct =
  | Cmp
  | Trim
  | Upper

(** [name_of_function fn] is the upper-case textual name of the built-in
    function [fn]. *)
let name_of_function : funct -> string =
 fun fn ->
  match fn with
  | Cmp -> "CMP"
  | Trim -> "TRIM"
  | Upper -> "UPPER"

(** [function_of_name param f] builds the expression for a call to the
    function named [f] with the arguments [param]. The name is matched
    case-insensitively (ASCII), and generic [Function] nodes store the
    lower-cased name.

    - ["nvl"] gives [Nvl], ["upper"], ["trim"] and ["cmp"] give [Function'];
    - ["join"] requires a first argument that is either a [Literal] (used as
      separator) or [Empty] (empty separator);
    - ["if"] requires exactly three arguments;
    - ["abs"], ["int"], ["concat"], ["match"], ["substring"], ["date"] and
      ["year"] accept any arguments.

    @raise ImportErrors.UnknowFunction
      with the name as given in [f] when the name is not recognised, or when
      the arguments of ["join"] or ["if"] do not have the expected shape. *)
let function_of_name : 'a t list -> string -> 'a t =
 fun param f ->
  let unknown () = raise (ImportErrors.UnknowFunction f) in
  match String.lowercase_ascii f with
  | "nvl" -> Nvl param
  | "join" -> (
      (* The first argument is the separator, the rest is what gets joined. *)
      match param with
      | Literal sep :: rest -> Join (sep, rest)
      | Empty :: rest -> Join ("", rest)
      | _ -> unknown ())
  | "upper" -> Function' (Upper, param)
  | "trim" -> Function' (Trim, param)
  | "cmp" -> Function' (Cmp, param)
  | "if" as fn -> (
      (* Branch function: condition, then-value, else-value. *)
      match param with
      | [ _; _; _ ] -> Function (fn, param)
      | _ -> unknown ())
  | ( "abs" | "int" (* integer functions *)
    | "concat" | "match" | "substring" (* string functions *)
    | "date" | "year" (* date functions *) ) as fn ->
      Function (fn, param)
  | _ -> unknown ()

(** [name_of_operator op] is the textual form of the binary operator [op].
    The boolean operators are surrounded by spaces, the others are not. *)
let name_of_operator op =
  match op with
  | And -> " and "
  | Or -> " or "
  | Equal -> "="
  | Different -> "<>"
  | LT -> "<"
  | GT -> ">"
  | Add -> "+"
  | Minus -> "-"
  | Division -> "/"

(** [name_of_window w] is the lower-case name of the window function [w];
    the argument carried by the window, if any, is ignored. *)
let name_of_window w =
  match w with
  | Counter -> "counter"
  | Min _ -> "min"
  | Max _ -> "max"
  | Sum _ -> "sum"
  | Previous _ -> "previous"

(** [map_window ~f w] applies [f] to the argument carried by [w], keeping
    the same kind of window. [Counter] carries nothing and is returned as
    is. *)
let map_window : f:('a -> 'b) -> 'a window -> 'b window =
 fun ~f w ->
  match w with
  | Counter -> Counter
  | Min arg -> Min (f arg)
  | Max arg -> Max (f arg)
  | Previous arg -> Previous (f arg)
  | Sum arg -> Sum (f arg)

(** [window_of_name name opt] builds the window function called [name]
    (matched exactly, in lower case) from its optional argument [opt].

    ["counter"] must come without an argument; ["min"], ["max"],
    ["previous"] and ["sum"] must come with one.

    @raise Not_found
      when [name] is unknown or when the presence of the argument does not
      match what the named window expects. *)
let window_of_name name opt =
  match opt with
  | None -> (
      match name with
      | "counter" -> Counter
      | _ -> raise Not_found)
  | Some arg -> (
      match name with
      | "min" -> Min arg
      | "max" -> Max arg
      | "previous" -> Previous arg
      | "sum" -> Sum arg
      | _ -> raise Not_found)

(** [cmp f e1 e2] compares the expressions [e1] and [e2], using [f] for
    every pair of [Path] values met during the traversal. The result is
    negative, zero or positive, with the usual meaning.

    Two expressions built with the same constructor are compared component
    by component, in declaration order: names and separators with
    [String.compare], operators with [compare_binary_operator], windows with
    [compare_window], and sub-expressions recursively with [cmp f]. Every
    constructor that can hold a path is handled explicitly, so that path
    values are only ever compared through [f]. *)
let rec cmp : ('a -> 'a -> int) -> 'a t -> 'a t -> int =
 fun f e1 e2 ->
  let cmp_list l1 l2 = List.compare ~cmp:(cmp f) l1 l2 in
  (* Lexicographic chaining: [next] is only evaluated on a tie. *)
  let ( <?> ) first next = if first = 0 then next () else first in
  match (e1, e2) with
  | Empty, Empty -> 0
  | Expr inner1, Expr inner2 -> cmp f inner1 inner2
  | Literal l1, Literal l2 -> String.compare l1 l2
  | Integer l1, Integer l2 -> String.compare l1 l2
  | Path p1, Path p2 -> f p1 p2
  | Concat elems1, Concat elems2 | Nvl elems1, Nvl elems2 ->
      cmp_list elems1 elems2
  | Function (n1, elems1), Function (n2, elems2)
  | Join (n1, elems1), Join (n2, elems2) ->
      String.compare n1 n2 <?> fun () -> cmp_list elems1 elems2
  | Function' (fn1, elems1), Function' (fn2, elems2) ->
      (* [funct] only has constant constructors: the structural comparison
         is safe here and orders them by declaration. *)
      Stdlib.compare fn1 fn2 <?> fun () -> cmp_list elems1 elems2
  | Window (s1, l11, l12), Window (s2, l21, l22) ->
      compare_window (cmp f) s1 s2 <?> fun () ->
      cmp_list l11 l21 <?> fun () -> cmp_list l12 l22
  | BOperator (n1, arg11, arg12), BOperator (n2, arg21, arg22) ->
      compare_binary_operator n1 n2 <?> fun () ->
      cmp f arg11 arg21 <?> fun () -> cmp f arg12 arg22
  | GEquality (n1, arg1, args1), GEquality (n2, arg2, args2) ->
      compare_binary_operator n1 n2 <?> fun () ->
      cmp f arg1 arg2 <?> fun () -> cmp_list args1 args2
  (* Only pairs built with two different constructors reach this point. The
     structural comparison then settles on the constructors alone and does
     not look at their arguments, so no path is compared without [f]. *)
  | other1, other2 -> Stdlib.compare other1 other2

(** [fold_values ~f ~init expression] folds [f] over every [Path] value
    found in [expression], starting from [init].

    Sub-expressions are visited left to right, with two particularities:
    - for a [Window], the argument of the window function (if any) is visited
      first, then the first list, then the second one;
    - for a [GEquality], the list of expressions is visited {e before} the
      single expression. *)
let fold_values : f:('b -> 'a -> 'b) -> init:'b -> 'a t -> 'b =
 fun ~f ~init expression ->
  let rec visit acc expr =
    match expr with
    | Path p -> f acc p
    | Empty | Literal _ | Integer _ -> acc
    | Expr inner -> visit acc inner
    | Concat args
    | Nvl args
    | Join (_, args)
    | Function (_, args)
    | Function' (_, args) ->
        visit_all acc args
    | BOperator (_, left, right) ->
        let acc = visit acc left in
        visit acc right
    | GEquality (_, single, several) ->
        (* The list is folded first, the single expression comes last. *)
        let acc = visit_all acc several in
        visit acc single
    | Window (kind, first, second) ->
        (* Each window function can have a distinct parameter first. *)
        let acc =
          match kind with
          | Counter -> acc
          | Min key | Max key | Previous key | Sum key -> visit acc key
        in
        let acc = visit_all acc first in
        visit_all acc second
  and visit_all acc exprs = List.fold_left ~f:visit ~init:acc exprs in
  visit init expression

(** [map ~f expression] rebuilds [expression] with [f] applied to the value
    of every [Path] node. The shape of the tree, the literals, the names and
    the operators are kept unchanged. *)
let map : type a b. f:(a -> b) -> a t -> b t =
 fun ~f expression ->
  let rec go : a t -> b t = function
    (* Leaves *)
    | Empty -> Empty
    | Literal text -> Literal text
    | Integer digits -> Integer digits
    | Path p -> Path (f p)
    (* Single sub-expression *)
    | Expr inner -> Expr (go inner)
    (* Nodes holding one list of sub-expressions *)
    | Concat args -> Concat (go_all args)
    | Nvl args -> Nvl (go_all args)
    | Join (sep, args) -> Join (sep, go_all args)
    | Function (name, args) -> Function (name, go_all args)
    | Function' (name, args) -> Function' (name, go_all args)
    (* Operators *)
    | BOperator (op, left, right) -> BOperator (op, go left, go right)
    | GEquality (op, single, several) ->
        GEquality (op, go single, go_all several)
    (* Window functions: the window argument is converted first. *)
    | Window (kind, first, second) ->
        let kind = map_window ~f:go kind in
        Window (kind, go_all first, go_all second)
  and go_all (exprs : a t list) : b t list = List.map ~f:go exprs in
  go expression