From e82962fe44c35b5ae6e6a68e8719e5d77aaf9e55 Mon Sep 17 00:00:00 2001 From: Sébastien Dailly Date: Mon, 6 Nov 2017 17:39:53 +0100 Subject: Simplify type deduction --- evaluator.ml | 110 ++++++++++++++++------------------------------------------- scTypes.ml | 83 ++++++++++++++++++++++++-------------------- scTypes.mli | 8 +++++ 3 files changed, 83 insertions(+), 118 deletions(-) diff --git a/evaluator.ml b/evaluator.ml index 8862b3a..a5f3380 100755 --- a/evaluator.ml +++ b/evaluator.ml @@ -50,24 +50,12 @@ fun printer typ -> match typ with type 'a returnType = 'a ScTypes.returnType -let specialize_result: type a. a ScTypes.returnType -> a ScTypes.dataFormat -> a ScTypes.returnType = - begin fun a b -> match a, b with - | ScTypes.Num (Some ScTypes.Date) as _1, _ -> _1 - | _, ScTypes.Date -> ScTypes.Num (Some ScTypes.Date) - | x, y -> x -end - -let typ_of_result: type a. a ScTypes.returnType -> a typ = function - | ScTypes.Num _ -> Num - | ScTypes.Bool -> Bool - | ScTypes.Str -> String - (*** Values definitions *) type 'a value = | Bool: D.Bool.t -> D.Bool.t value | Num: D.Num.t ScTypes.dataFormat * D.Num.t -> D.Num.t value - | String: UTF8.t -> UTF8.t value + | String: UTF8.t -> UTF8.t value | List: 'a ScTypes.dataFormat * 'a list -> 'a list value | Matrix: 'a ScTypes.dataFormat * 'a list list -> 'a list list value @@ -87,45 +75,6 @@ let type_of_value: type a. a value -> a typ = function | List (t, l) -> List (typ_of_format t) | Matrix (t, l) -> List (List (typ_of_format t)) -let inject': - type a. a ScTypes.returnType -> (unit -> a ScTypes.dataFormat) -> a -> a value = - fun resultFormat f res -> begin match resultFormat, res with - | ScTypes.Bool, x -> Bool x - | ScTypes.Str, s -> String s - | ScTypes.Num None, x -> Num (f (), x) - | ScTypes.Num (Some v), x -> Num(v, x) - end - -let compare_format: type a b. a typ -> a ScTypes.returnType -> b value -> a ScTypes.returnType = begin -fun init_typ currentResult value -> - - (* If the argument as the same type as the result format, just select the most specialized *) - match compare_typ init_typ (type_of_value value) with - | T.Eq -> begin match value with - | Bool b -> ScTypes.Bool - | String s -> ScTypes.Str - | Num (f, v) -> specialize_result currentResult f - (* There is no possibility to get init_typ as List typ *) - | List (f, v) -> raise Errors.TypeError - | Matrix (f, v) -> raise Errors.TypeError - end - (* The types differ, handle the special cases for Lists *) - | _ -> - begin match value with - | List (f, v) -> - begin match compare_typ init_typ (typ_of_format f) with - | T.Eq -> specialize_result currentResult f - | _ -> currentResult - end - | Matrix (f, v) -> - begin match compare_typ init_typ (typ_of_format f) with - | T.Eq -> specialize_result currentResult f - | _ -> currentResult - end - | _ -> currentResult - end - end - end module C = Catalog.Make(Data) @@ -142,34 +91,29 @@ let repr = C.repr type existencialResult = | Result : 'a Data.value -> existencialResult [@@unboxed] -(** Guess the format to use for the result function from the arguments given. - The most specialized format take over the others. -*) -let guess_format_result: -type a. a ScTypes.returnType -> existencialResult list -> unit -> a Data.dataFormat = -begin fun init_value values () -> - - let init_typ:a Data.typ = Data.typ_of_result init_value in - - (* fold over the arguments, and check if they have the same format *) - let compare_format: a ScTypes.returnType -> existencialResult -> a ScTypes.returnType = - fun currentResult (Result value) -> - Data.compare_format init_typ currentResult value in - - begin match List.fold_left compare_format init_value values with - | ScTypes.Str -> ScTypes.String - | ScTypes.Bool -> ScTypes.Bool - | ScTypes.Num None-> ScTypes.Number - | ScTypes.Num (Some x)-> x +let inject: +type a. a Data.dataFormat -> a -> existencialResult = fun resultFormat res -> + begin match resultFormat with + | ScTypes.Bool -> Result (Data.Bool res) + | ScTypes.String -> Result (Data.String res) + | ScTypes.Number -> Result (Data.Num (resultFormat, res)) + | ScTypes.Date -> Result (Data.Num (resultFormat, res)) end -end -let inject: -type a. a Data.returnType -> (unit -> a Data.dataFormat) -> a -> existencialResult = -fun resultFormat f res -> - let (x:a Data.value) = Data.inject' resultFormat f res in - Result x +(** Extract the format from a list of results *) +let build_format_list ll () = + + List.map (fun (Result x) -> + begin match x with + | Data.Bool _ -> ScTypes.DataFormat.F (ScTypes.Bool) + | Data.Num (x, _) -> ScTypes.DataFormat.F x + | Data.String _ -> ScTypes.DataFormat.F (ScTypes.String) + | Data.List (f, _) -> ScTypes.DataFormat.F f + | Data.Matrix (f, _) -> ScTypes.DataFormat.F f + end + ) ll + let register0 name returnType f = catalog := C.register !catalog name (C.T1(Data.Unit)) (C.Fn1 (returnType, f)) @@ -188,22 +132,26 @@ let call name args = begin begin try match args with | [] -> let C.Fn1(ret, f) = C.find_function !catalog name' (C.T1 Data.Unit) in - inject ret (fun () -> raise Errors.TypeError) (f ()) + let returnType = ScTypes.DataFormat.guess_format_result ret (fun () -> raise Errors.TypeError) in + inject returnType (f ()) | (Result p1)::[] -> let C.Fn1(ret, f) = C.find_function !catalog name' (C.T1 (Data.type_of_value p1)) in - inject ret (guess_format_result ret args) (f (Data.get_value_content p1)) + let returnType = ScTypes.DataFormat.guess_format_result ret (build_format_list args) in + inject returnType (f (Data.get_value_content p1)) | (Result p1)::(Result p2)::[] -> let C.Fn2(ret, f) = C.find_function !catalog name' (C.T2 (Data.type_of_value p1, Data.type_of_value p2)) in - inject ret (guess_format_result ret args) (f (Data.get_value_content p1) (Data.get_value_content p2)) + let returnType = ScTypes.DataFormat.guess_format_result ret (build_format_list args) in + inject returnType (f (Data.get_value_content p1) (Data.get_value_content p2)) | (Result p1)::(Result p2)::(Result p3)::[] -> let C.Fn3(ret, f) = C.find_function !catalog name' (C.T3 (Data.type_of_value p1, Data.type_of_value p2, Data.type_of_value p3)) in - inject ret (guess_format_result ret args) (f (Data.get_value_content p1) (Data.get_value_content p2) (Data.get_value_content p3)) + let returnType = ScTypes.DataFormat.guess_format_result ret (build_format_list args) in + inject returnType (f (Data.get_value_content p1) (Data.get_value_content p2) (Data.get_value_content p3)) | _ -> raise Not_found with Not_found -> diff --git a/scTypes.ml b/scTypes.ml index 81af61c..ca2b32f 100755 --- a/scTypes.ml +++ b/scTypes.ml @@ -22,12 +22,6 @@ let get_numeric_type: DataType.Num.t dataFormat -> numericType = function | Date -> Date | Number -> Number -let priority: type a. a dataFormat -> int = function - | Date -> 1 - | Number -> 0 - | String -> 0 - | Bool -> 0 - type 'a types = | Num : DataType.Num.t dataFormat * DataType.Num.t -> DataType.Num.t types (** A number *) | Str : DataType.String.t -> DataType.String.t types (** A string *) @@ -66,6 +60,49 @@ type result = | Result : 'a types -> result | Error : exn -> result +module DataFormat = struct + + type formats = F : 'a dataFormat -> formats [@@unboxed] + + let priority: type a. a dataFormat -> int = function + | Date -> 1 + | Number -> 0 + | String -> 0 + | Bool -> 0 + + let collect_format: DataType.Num.t dataFormat -> formats -> DataType.Num.t dataFormat = begin + fun dataFormat -> function + | F Date -> Date + | _ -> dataFormat + end + + let guess_format_result: type a. a returnType -> (unit -> formats list) -> a dataFormat = + fun return params -> begin match return with + | Str -> String + | Bool -> Bool + | Num (Some x) -> x + | Num None -> List.fold_left collect_format Number (params ()) + end + + let default_value_for: type a. a dataFormat -> a = function + | Date -> DataType.Num.nan + | Number -> DataType.Num.nan + | Bool -> false + | String -> UTF8.empty + + let compare_format: type a b. a dataFormat -> b dataFormat -> (a, b) equality = + fun a b -> begin match a, b with + | Date, Date -> Eq + | String, String -> Eq + | Number, Number -> Eq + | Date, Number -> Eq + | Number, Date -> Eq + | Bool, Bool -> Eq + | _, _ -> raise Errors.TypeError + end + +end + module Type = struct (* Required because Num.Big_int cannot be compared with Pervasives.(=) *) let (=) : type a b. a types -> b types -> bool = fun t1 t2 -> @@ -113,23 +150,6 @@ module Type = struct | Bool b -> Value (Bool, b) end - (* - let guess_format_result: - type a. a returnType -> t list -> (a -> a types) = - fun return params -> begin match return with - | Str -> fun value -> Str value - | Bool -> fun value -> Bool value - | Num (Some x) -> fun value -> Num (x, value) - | Num None -> fun value -> Num (Number, value) - end - *) - - let default_value_for: type a. a dataFormat -> a = function - | Date -> DataType.Num.nan - | Number -> DataType.Num.nan - | Bool -> false - | String -> UTF8.empty - end module Refs = struct @@ -200,17 +220,6 @@ module Refs = struct type ('a, 'b) equality = Eq : ('a, 'a) equality - let compare_format: type a b. a dataFormat -> b dataFormat -> (a, b) equality = - fun a b -> begin match a, b with - | Date, Date -> Eq - | String, String -> Eq - | Number, Number -> Eq - | Date, Number -> Eq - | Number, Date -> Eq - | Bool, Bool -> Eq - | _, _ -> raise Errors.TypeError - end - (** Add one element in a typed list. The function will raise Error.TypeError if the elements does not match @@ -219,12 +228,12 @@ module Refs = struct let add_elem: type a b. a dataFormat * a list -> result option -> a dataFormat * a list = fun (format, elements) result -> begin match result with - | None -> format, (Type.default_value_for format)::elements + | None -> format, (DataFormat.default_value_for format)::elements | Some (Error x) -> raise x | Some (Result r) -> let Type.Value (format', element) = Type.get_content r in - let Eq = compare_format format format' in - let new_format = if (priority format) < (priority format') then + let Eq = DataFormat.compare_format format format' in + let new_format = if (DataFormat.priority format) > (DataFormat.priority format') then format else format' in diff --git a/scTypes.mli b/scTypes.mli index ad0d0ee..d147d92 100755 --- a/scTypes.mli +++ b/scTypes.mli @@ -65,6 +65,14 @@ type result = | Result : 'a types -> result | Error : exn -> result +module DataFormat : sig + + type formats = F : 'a dataFormat -> formats [@@unboxed] + + val guess_format_result: 'a returnType -> (unit -> formats list) -> 'a dataFormat + +end + module Type : sig type t = Value: 'a dataFormat * 'a -> t -- cgit v1.2.3