aboutsummaryrefslogtreecommitdiff
path: root/src/catalog.ml
blob: 95f13cec62ddf7b0b1f50f6f51ad04123b97e938 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
module T = Tools

(** Types understood by a catalog: descriptors for function arguments and
    function results. *)
module type DATA_SIG = sig

  (** Descriptor of an argument type. In a catalog, a value of type ['a] is
      passed together with an ['a t] describing it. *)
  type 'a t

  (** Descriptor of a result type. *)
  type 'a returnType

  (** Order two argument descriptors. Catalogs rely on a [T.Eq] answer
      refining ['a] and ['b] to the same type (presumably [T.cmp] is a GADT
      whose [Eq] constructor carries that equality — see [Tools]). *)
  val compare_typ : 'a t -> 'b t -> ('a, 'b) T.cmp

  (** Pretty-print an argument descriptor. *)
  val repr : Format.formatter -> 'a t -> unit

end

module type CATALOG = sig

  (** Descriptor of an argument type. *)
  type 'a argument

  (** Descriptor of a result type. *)
  type 'a returnType

  (** A compiled catalog, used to look functions up and call them. *)
  type t

  (** The catalog builder used for registering all the functions. *)
  type catalog_builder

  (** Empty catalog builder, with no function registered. *)
  val empty: catalog_builder

  (** Register a one-argument function under the given name.
      In {!Make}, raises [RegisteredFunction] if a function with the same
      name and signature is already registered. *)
  val register1:
    string ->                     (* The function name *)
    'a argument ->                (* The signature *)
    'b returnType ->              (* The return type *)
    ('a -> 'b) ->                 (* The function to call *)
    catalog_builder -> catalog_builder

  (** Register a two-argument function under the given name.
      In {!Make}, raises [RegisteredFunction] if a function with the same
      name and signature is already registered. *)
  val register2:
    string ->                     (* The function name *)
    ('a argument * 'b argument) ->(* The signature *)
    'c returnType ->              (* The return type *)
    ( 'a -> 'b -> 'c) ->          (* The function to call *)
    catalog_builder -> catalog_builder

  (** Register a three-argument function under the given name.
      In {!Make}, raises [RegisteredFunction] if a function with the same
      name and signature is already registered. *)
  val register3:
    string ->                     (* The function name *)
    ('a argument * 'b argument * 'c argument) -> (* The signature *)
    'd returnType ->              (* The return type *)
    ( 'a -> 'b -> 'c -> 'd) ->    (* The function to call *)
    catalog_builder -> catalog_builder


  (** Compile the catalog builder into a catalog usable by the [evalX]
      functions. *)
  val compile: catalog_builder -> t


  (** Result of a call: the returned value packed with the descriptor of its
      type. *)
  type result =
    | R : 'a returnType * 'a -> result

  (** [eval1 catalog name (typ, value)] calls the one-argument function
      registered under [name] whose signature is [typ]. *)
  val eval1: t -> string -> ('a argument * 'a) -> result

  (** Same as [eval1] for a two-argument function. *)
  val eval2: t -> string -> ('a argument * 'a) -> ('b argument * 'b) -> result

  (** Same as [eval1] for a three-argument function. *)
  val eval3: t -> string -> ('a argument * 'a) -> ('b argument * 'b) -> ('c argument * 'c) -> result

end

(** An existing function cannot be replaced: the [registerX] functions of
    {!Make} raise [RegisteredFunction] when a function with the same name
    (compared after ASCII upper-casing) and the same signature is already
    registered in the catalog builder. Registering the same name with a
    different signature is allowed. *)
exception RegisteredFunction

(** Catalog for all functions.

    [Make (Data)] builds a catalog whose argument and return types are
    described by [Data]. Functions are first registered by name and signature
    in a [catalog_builder] ([register1], [register2], [register3]); the
    builder is then turned into a [t] by [compile], and [t] is queried with
    [eval1], [eval2] and [eval3]. *)
module Make(Data:DATA_SIG) = struct

  type 'a argument = 'a Data.t
  type 'a returnType = 'a Data.returnType

  (** This is the way a function is stored in the map.
      We just keep the return type and the function itself: the argument
      types are carried by the key (see [sig_typ]). The type index is the
      type of the arguments ([('a * 'b)] for two arguments, and so on). *)
  type _ t_function =
    | Fn1: 'b Data.returnType * ('a -> 'b) -> 'a t_function
    | Fn2: 'c Data.returnType * ('a -> 'b -> 'c) -> ('a * 'b) t_function
    | Fn3: 'd Data.returnType * ('a -> 'b -> 'c -> 'd) -> ('a * 'b * 'c) t_function

  (** This is the key for storing functions in the map: the descriptors of
      the argument types. A key of type [a t_function sig_typ] is associated
      with a value of type [a t_function] (see [find_function]). *)
  type _ sig_typ =
    | T1: 'a Data.t -> 'a t_function sig_typ
    | T2: 'a Data.t * 'b Data.t -> ('a * 'b) t_function sig_typ
    | T3: 'a Data.t * 'b Data.t * 'c Data.t -> ('a * 'b * 'c) t_function sig_typ


  (** Ordering and printing of signatures, given to [Splay.Make] to build
      [Functions]. *)
  module ComparableSignature = struct

    type 'a t = 'a sig_typ

    (* Witness of a type equality: matching on [Eq] tells the type checker
       that both type parameters are the same type. *)
    type (_, _) eq = Eq : ('a, 'a) eq

    (** Compare two signatures.
        Signatures are ordered by arity first ([T1] < [T2] < [T3]), then
        argument by argument, from left to right, with [Data.compare_typ]. *)
    let comp: type a b. a sig_typ -> b sig_typ -> (a, b) T.cmp = begin fun a b ->

      (* Compare two argument descriptors. When they are equal, the
         continuation [f] receives the equality witness, which refines the
         types so that it can go on with the next argument; otherwise the
         ordering is returned directly. *)
      let cmp: type c d. c Data.t -> d Data.t -> ((c, d) eq -> (a, b) T.cmp) -> (a, b) T.cmp =
      begin fun a b f -> match Data.compare_typ a b with
        | T.Eq -> f Eq
        | T.Lt -> T.Lt
        | T.Gt -> T.Gt
      end in

      match a, b with
      | T1(a), T1(b) -> cmp a b (fun Eq -> T.Eq)
      | T1(_), _ -> T.Lt
      | _, T1(_) -> T.Gt

      | T2(a, b), T2(c, d) ->
        cmp a c (fun Eq ->
          cmp b d (fun Eq -> T.Eq)
        )
      | T2(_), _ -> T.Lt
      | _, T2(_) -> T.Gt

      | T3(a, b, c), T3(d, e, f) ->
        cmp a d (fun Eq ->
          cmp b e (fun Eq ->
            cmp c f (fun Eq -> T.Eq)
          )
        )
      end

    (** Print a signature as the parenthesised, comma-separated list of its
        argument types. *)
    let repr : type a. Format.formatter -> a t -> unit = begin fun formatter -> function
      | T1 t -> Format.fprintf formatter "(%a)" Data.repr t
      | T2 (t1, t2) -> Format.fprintf formatter "(%a,%a)" Data.repr t1 Data.repr t2
      | T3 (t1, t2, t3) -> Format.fprintf formatter "(%a,%a,%a)" Data.repr t1 Data.repr t2 Data.repr t3
    end

  end

  (* Builder-side map from upper-cased function names to their overloads. *)
  module Catalog = Map.Make(String)
  (* Map from signatures to functions: all the overloads of one name. *)
  module Functions = Splay.Make(ComparableSignature)

  (* This is the map which contains all the registered functions.
     Each name is bound to another map which contains the function for each
     signature.
   *)
  type t = Functions.t Base.String_dict.t
  type catalog_builder = Functions.t Catalog.t

  let empty = Catalog.empty

  (**
     Register the function [f] under [name] and [signature] in the builder
     [t]. The name is upper-cased, so names are matched ignoring ASCII case.
     @raise RegisteredFunction if a function with the same name and
     signature is already registered.
   *)
  let register t name signature f = begin

    let name' = String.uppercase_ascii name in
    let map = begin match Catalog.find name' t with
    | exception Not_found ->
        (* First function registered under this name *)
        Functions.add signature f Functions.empty
    | x ->
      if Functions.member signature x then
        raise RegisteredFunction
      else
        Functions.add signature f x
    end in
    Catalog.add name' map t
  end

  let register1 name typ1 returnType f catalog =
    register catalog name (T1(typ1)) (Fn1 (returnType, f))

  let register2 name (typ1, typ2) result f catalog =
    register catalog name (T2(typ1, typ2)) (Fn2 (result, f))

  let register3 name (typ1, typ2, typ3) result f catalog =
    register catalog name (T3(typ1, typ2, typ3)) (Fn3 (result, f))

  (** Look in the catalog for a function with the given name and signature.
      The name is upper-cased before the lookup, as in [register].
      NOTE(review): fails with whatever [Base.String_dict.find_exn] raises
      for an unknown name, and whatever [Functions.find] raises for an
      unknown signature (presumably [Not_found]) — confirm. *)
  let find_function:
  type a. t -> string -> a t_function sig_typ -> a t_function =
  begin fun t name signature ->
       Base.String_dict.find_exn t (String.uppercase_ascii name)
    |> Functions.find signature
  end

  (** Freeze the builder into a catalog.
      The efficient [Base.String_dict] is used for lookups by name. Its
      [of_alist_exn] requires unique keys, which is guaranteed because the
      bindings come from a [Map]. *)
  let compile t =
    Base.String_dict.of_alist_exn (Catalog.bindings t)


  (** Result of a call: the returned value packed with the descriptor of its
      type. *)
  type result =
    | R : 'a returnType * 'a -> result

  (** Call the one-argument function registered under [name] for the
      argument type [t1]. Fails like [find_function] if there is none. *)
  let eval1 catalog name (t1, arg1) = begin
    (* [register1] always pairs a [T1] key with an [Fn1] value, so only this
       constructor is expected here. *)
    let Fn1(ret, f) = find_function catalog name (T1 t1) in
    R (ret, f arg1)
  end

  (** Call the two-argument function registered under [name] for the
      argument types [t1] and [t2]. Fails like [find_function] if there is
      none. *)
  let eval2 catalog name (t1, arg1) (t2, arg2) = begin
    (* [register2] always pairs a [T2] key with an [Fn2] value. *)
    let Fn2(ret, f) = find_function catalog name (T2 (t1, t2)) in
    R (ret, f arg1 arg2)
  end

  (** Call the three-argument function registered under [name] for the
      argument types [t1], [t2] and [t3]. Fails like [find_function] if
      there is none. *)
  let eval3 catalog name (t1, arg1) (t2, arg2) (t3, arg3) = begin
    (* [register3] always pairs a [T3] key with an [Fn3] value. *)
    let Fn3(ret, f) = find_function catalog name (T3 (t1, t2, t3)) in
    R (ret, f arg1 arg2 arg3)
  end
end