aboutsummaryrefslogtreecommitdiff
path: root/content/resources/catalog.ml
blob: e6243156f511cc90dc78ccbd0b9fb1ff6bd78d6e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
(** Raised by [eq_typ] and [C.eq_sig_typ] when two type witnesses do not
    denote the same type. *)
exception TypeError
(* Not raised anywhere in the visible code. NOTE(review): presumably meant to
   signal that a function is already registered; confirm before relying on
   it. *)
exception RegisteredFunction

(*** Type definitions *)

(** Type witnesses: a value of type ['a typ] is a runtime token standing for
    the OCaml type ['a]. Each constructor fixes the type parameter, so
    matching on a witness tells the type checker which concrete type is in
    play. *)
type _ typ =
  | Bool: bool typ (* witness for [bool] *)
  | Int: int typ   (* witness for [int] *)

(* Lower-case aliases for the type witnesses; the registration code at the
   end of the file uses these to describe function signatures. *)
let t_bool : bool typ = Bool
let t_int : int typ = Int

(* Encode type equality: the only constructor, [Eq], has type [('a, 'a) eq],
   so a value of type [(a, b) eq] can exist only when [a] and [b] are the
   same type. Matching on [Eq] makes that equation available to the type
   checker inside the branch. *)
type (_,_) eq = Eq : ('a,'a) eq

(** [eq_typ left right] compares two type witnesses and, when they denote the
    same type, returns the proof [Eq] of that equality.
    @raise TypeError when the two witnesses denote different types. *)
let eq_typ: type a b. a typ -> b typ -> (a, b) eq = fun left right ->
  match left, right with
  | Bool, Bool -> Eq
  | Int, Int -> Eq
  (* Mismatches are listed one by one so that adding a constructor to [typ]
     makes the compiler point at this function. *)
  | Bool, Int -> raise TypeError
  | Int, Bool -> raise TypeError

(** [print_typ printer typ] appends the name of the type denoted by [typ] to
    the buffer [printer]. The argument order makes it usable with the [%a]
    directive of [Printf.bprintf]. *)
let print_typ: type a. Buffer.t -> a typ -> unit = fun printer typ ->
  let label : string =
    match typ with
    | Bool -> "Bool"
    | Int -> "Int"
  in
  Buffer.add_string printer label

(*** Value definitions *)

(** A boxed value; the type parameter records the OCaml type of the payload.
    The constructors share their names with those of [typ]: from this point
    on, which of the two is meant is resolved from the expected type. *)
type 'a value =
  | Bool: bool -> bool value (* boxed boolean *)
  | Int: int   -> int  value (* boxed integer *)

(** [get_value_content boxed] unwraps [boxed] and returns its raw OCaml
    payload, whose type is given by the type parameter of the box. *)
let get_value_content: type a. a value -> a = fun boxed ->
  match boxed with
  | Bool flag -> flag
  | Int number -> number

(** [build_value witness raw] boxes the raw OCaml value [raw] using the
    constructor selected by the type witness [witness]. *)
let build_value: type a. a typ -> a -> a value = fun witness raw ->
  (* The patterns are [typ] constructors, the results are [value]
     constructors: the shared names are told apart by their expected types. *)
  match witness with
  | Bool -> Bool raw
  | Int -> Int raw

(** [type_of_value boxed] returns the type witness matching the constructor
    of [boxed]; the payload itself is ignored. *)
let type_of_value: type a. a value -> a typ = fun boxed ->
  match boxed with
  | Bool _ -> Bool
  | Int _ -> Int

(** Existential wrapper around a boxed value: the type parameter of the
    [value] is hidden, so results of different types share this single type.
    Matching on the inner [value] recovers the concrete type.
    Note that this declaration shadows the standard library type of the same
    name for the rest of the file. *)
type result =
  | Result : 'a value -> result

(** [inject witness raw] boxes the raw value [raw] according to the type
    witness [witness] and hides its type inside a [result]. *)
let inject: type a. a typ -> a -> result = fun witness raw ->
  let boxed = build_value witness raw in
  Result boxed

(** Catalog for all functions.

    Functions live in a global, mutable map. The key is the upper-cased
    function name paired with the type witnesses of the parameters, so the
    same name can be bound several times with different parameter types. *)
module C = struct

  (** Type witness for a parameter list: [T1] describes a single parameter,
      [T2] a pair of parameters. *)
  type _ sig_typ =
    | T1: 'a typ -> 'a sig_typ
    | T2: 'a typ * 'b typ -> ('a * 'b) sig_typ

  (** [eq_sig_typ a b] returns the proof [Eq] when both parameter lists have
      the same arity and the same parameter types.
      @raise TypeError when the arities or any parameter type differ. *)
  let eq_sig_typ: type a b. a sig_typ -> b sig_typ -> (a, b) eq = fun a b ->
    begin match a, b with
    | T1(a), T1(b) -> eq_typ a b
    | T2(a, b), T2(c, d) ->
        (* Both components must agree before the pair types can be equated. *)
        begin match (eq_typ a c), (eq_typ b d) with
          | Eq, Eq -> Eq
        end
    | T1 _, T2 _ -> raise TypeError
    | T2 _, T1 _ -> raise TypeError
    end

  (** [print_sig_typ printer typ] appends the parameter types, parenthesised
      and comma separated, to the buffer [printer]. *)
  let print_sig_typ: type a. Buffer.t -> a sig_typ -> unit = begin fun printer typ -> match typ with
    | T1 a -> Printf.bprintf printer "(%a)"
        print_typ a
    | T2 (a, b) -> Printf.bprintf printer "(%a, %a)"
        print_typ a
        print_typ b
  end

  (** A parameter-list witness with its type hidden, usable as part of a map
      key. *)
  type t_signature =
    | Sig : 'a sig_typ -> t_signature

  (** A registered function together with the witnesses of its parameter and
      return types: [Fn1] for unary functions, [Fn2] for binary ones. *)
  type t_function =
    | Fn1: 'a typ * 'b typ * ('a -> 'b) -> t_function
    | Fn2: ('a * 'b) sig_typ * 'c typ * ('a -> 'b -> 'c) -> t_function

  (** Map from (upper-cased name, parameter signature) to function. *)
  module Catalog = Map.Make(
    struct
      type t = string * t_signature
      (* The unqualified [compare] is the polymorphic comparison of the
         standard library. The former [Pervasives.compare] spelling no longer
         compiles: [Pervasives] was deprecated in OCaml 4.08 and removed in
         5.0. Keys hold only a string and witness constructors (no closures),
         so structural comparison is well defined on them. *)
      let compare (a : t) (b : t) = compare a b

    end
  )

  (** The global table of registered functions. *)
  let (catalog:t_function Catalog.t ref) = ref Catalog.empty

  (** [register name signature f] binds [f] to the upper-cased [name] and
      [signature] in the global catalog. A previous binding for the same key
      is replaced. Nothing checks that [signature] agrees with the witnesses
      stored inside [f]; callers must keep them consistent. *)
  let register name signature f = begin
    let name' = String.uppercase_ascii name in
    catalog := Catalog.add (name', signature) f !catalog
  end

  (** [find_function name signature] looks up the function registered under
      the upper-cased [name] with the parameter list [signature].
      @raise Not_found when no such function is registered. *)
  let find_function name signature = begin
    Catalog.find (String.uppercase_ascii name, (Sig signature)) !catalog
  end

  (** [print_error name signature] reports on standard output that no
      function [name] accepts the parameter list [signature]. *)
  let print_error: type a. string -> a sig_typ -> unit = begin fun name signature ->
    let buffer = Buffer.create 16 in
    print_sig_typ buffer signature;

    Printf.printf "There is no function '%s' with signature %s\n"
      name
      (Buffer.contents buffer)
  end

  (** [eval1 name arg] applies the unary function registered as [name] for
      the type of [arg] and returns the boxed outcome.
      @raise Not_found when no unary function matches; a message is printed
      on standard output first.
      @raise TypeError when the stored function disagrees with the key it was
      registered under. *)
  let eval1 name (Result p1) = begin
    let signature = type_of_value p1 in
    try
      begin match find_function name (T1 signature) with
      | Fn1 (fn_sig, returnType, f) ->
        (* We check the type equality between the function signature and the parameters type *)
        begin match eq_typ fn_sig signature with Eq ->
          inject returnType (f (get_value_content p1))
        end
      (* A binary function stored under a unary key is treated as absent. *)
      | Fn2 _ -> raise Not_found
      end
    with Not_found ->
      print_error name (T1 signature);
      raise Not_found
  end

  (** [eval2 name arg1 arg2] applies the binary function registered as [name]
      for the types of [arg1] and [arg2] and returns the boxed outcome.
      @raise Not_found when no binary function matches; a message is printed
      on standard output first.
      @raise TypeError when the stored function disagrees with the key it was
      registered under. *)
  let eval2 name (Result p1) (Result p2) = begin
    let signature = T2 ((type_of_value p1), (type_of_value p2)) in
    try
      begin match find_function name signature with
      | Fn2 (fn_sig, returnType, f) ->
          (* We check the type equality between the function signature and the parameters type *)
          begin match eq_sig_typ signature fn_sig with Eq ->
            inject returnType (
              f (get_value_content p1) (get_value_content p2)
            )
          end
      (* A unary function stored under a binary key is treated as absent. *)
      | Fn1 _ -> raise Not_found
      end
    with Not_found ->
      print_error name signature;
      raise Not_found
  end

end

(** [register1 name typ1 returnType f] adds the unary function [f] to the
    catalog under [name], taking a parameter described by [typ1] and
    returning a value described by [returnType]. *)
let register1: type a b. string -> a typ -> b typ -> (a -> b) -> unit =
  fun name typ1 returnType f ->
    let key = C.Sig (C.T1 typ1) in
    let entry = C.Fn1 (typ1, returnType, f) in
    C.register name key entry

(** [register2 name (typ1, typ2) result f] adds the binary function [f] to
    the catalog under [name], taking parameters described by [typ1] and
    [typ2] and returning a value described by [result]. *)
let register2: type a b c. string -> (a typ * b typ) -> c typ -> ( a -> b -> c) -> unit =
  fun name (typ1, typ2) result f ->
    (* The same witness is used both as lookup key and inside the entry. *)
    let params = C.T2 (typ1, typ2) in
    let key = C.Sig params in
    let entry = C.Fn2 (params, result, f) in
    C.register name key entry

(* Register the standard functions, then run a small demonstration that
   evaluates an expression through the catalog and prints the outcome. *)

let () = begin

  (* Integer comparisons. *)
  List.iter
    (fun (name, f) -> register2 name (t_int, t_int) t_bool f)
    [ "=", (=); "<>", (<>) ];

  (* Integer arithmetic. *)
  List.iter
    (fun (name, f) -> register2 name (t_int, t_int) t_int f)
    [ "+", (+); "*", ( * ); "/", (/); "-", (-) ];

  (* Boolean comparisons. *)
  List.iter
    (fun (name, f) -> register2 name (t_bool, t_bool) t_bool f)
    [ "=", (=); "<>", (<>) ];

  (* Boolean logic. *)
  register1 "not" t_bool t_bool not;
  List.iter
    (fun (name, f) -> register2 name (t_bool, t_bool) t_bool f)
    [ "and", (&&); "or", (||) ];

  (* Demonstration: compare 2 with 3, negate the answer, then compare it
     with true. *)
  let two = inject t_int 2 in
  let three = inject t_int 3 in
  let yes = inject t_bool true in
  let outcome =
    C.eval2 "=" two three
    |> C.eval1 "not"
    |> C.eval2 "=" yes
  in
  match outcome with
  | Result (Bool b) -> Printf.printf "%b\n" b
  | Result (Int n) -> Printf.printf "%d\n" n

end