aboutsummaryrefslogtreecommitdiff
path: root/lib/dream_handler/dream_handler.ml
blob: 0454b32ba668a0f6ed4db58ecf2293d4c240783b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
open Lwt_result.Syntax

(** Extract the content from the body of the request.

    The module given in argument is the definition of the service. For [GET]
    and [HEAD] services the body is not read and [`Null] is decoded instead,
    which assumes the service request type is [unit].

    Resolves to [Error] with an empty [400 Bad Request] response when the body
    is not valid JSON, or when it does not match the request type of the
    service. Previously the decoding exception escaped the handler even
    though the signature promises a [result]. *)
let read_body :
    (module Services.JsonServerHandler with type request = 'request) ->
    Dream.request ->
    ('request, Dream.response) result Lwt.t =
 fun (type request)
     (module S : Services.JsonServerHandler with type request = request)
     request ->
  let%lwt body =
    match S.method_ with
    | GET | HEAD ->
        (* GET and HEAD requests don't carry a body: [`Null] is decoded
           instead, assuming the request type is [unit]. *)
        Lwt.return_none
    | _ ->
        let%lwt raw = Dream.body request in
        Lwt.return_some raw
  in
  (* Parsing and decoding are grouped so that both failure modes are turned
     into the same client error below. *)
  let decode () =
    let json =
      match body with None -> `Null | Some raw -> Yojson.Safe.from_string raw
    in
    S.request_of_yojson json
  in
  match decode () with
  | content -> Lwt.return_ok content
  | exception
      ( Yojson.Json_error _
      | Ppx_yojson_conv_lib.Yojson_conv.Of_yojson_error _ ) ->
      Lwt.return_error (Dream.response ~status:`Bad_Request "")

(** Turn the outcome of a service into the Dream response to send.

    An [Ok] value is converted with the service's [yojson_of_response],
    printed with [Yojson.Safe.to_string] and sent through [Dream.json]. An
    [Error] already holds the response to send and is returned unchanged. *)
let create_response :
    (module Services.JsonServerHandler with type response = 'response) ->
    ('response, Dream.response) result Lwt.t ->
    Dream.response Dream.promise =
 fun (type response)
     (module Handler : Services.JsonServerHandler
       with type response = response)
     response_content ->
  let%lwt outcome = response_content in
  match outcome with
  | Error ready_made -> Lwt.return ready_made
  | Ok content ->
      content
      |> Handler.yojson_of_response
      |> Yojson.Safe.to_string
      |> Dream.json

(** Build the Dream handler of the service [S].

    The request body is decoded with [read_body], then [f] is called with the
    placeholders [args] and the decoded body. Its result is encoded with
    [create_response]. An error message returned by [f] becomes a
    [500 Internal Server Error] response whose body is that message. *)
let handle :
    (module Services.JsonServerHandler
       with type placeholders = 'placeholders
        and type request = 'request
        and type response = 'response) ->
    ('placeholders -> 'request -> ('response, string) Lwt_result.t) ->
    'placeholders ->
    Dream.handler =
 fun (type placeholders request response)
     (module S : Services.JsonServerHandler
       with type placeholders = placeholders
        and type response = response
        and type request = request) f args request ->
  let internal_error message =
    Dream.response ~status:`Internal_Server_Error message
  in
  let outcome =
    let* decoded = read_body (module S) request in
    f args decoded |> Lwt_result.map_error internal_error
  in
  create_response (module S) outcome

(** Variant of [handle] for services protected by a CSRF token.

    The JSON body sent by the client wraps the request of [S] together with
    the token (see [Service.request]). The token is verified with
    [Dream.verify_csrf_token] before the service function is called. *)
module MakeChecked (S : Services.JsonServerHandler) = struct
  (** Raised while applying the functor to a [GET] or [HEAD] service: the
      CSRF token is read from the request body, which those methods do not
      have. *)
  exception Invalid_method

  (** Derive the handler from the standard one by adding a new field [token] in
      the request *)
  module Service = struct
    include S
    open Ppx_yojson_conv_lib.Yojson_conv.Primitives

    (* Re-declares [Result.t] so that yojson converters are derived for it. *)
    type ('a, 'b) result = ('a, 'b) Result.t = Ok of 'a | Error of 'b
    [@@deriving yojson]

    type request = { content : S.request; token : string }
    [@@deriving of_yojson]
    (** This type adds the validation token to the body message: the original
        request is carried in [content]. *)

    let method_ : (request, response) Services.method_ =
      match S.method_ with
      (* We can't add the CSRF token with those methods because they do not
         have a body. *)
      | GET | HEAD -> raise Invalid_method
      | POST -> POST
      | PUT -> PUT
      | DELETE -> DELETE
      | CONNECT -> CONNECT
      | OPTIONS -> OPTIONS
      | TRACE -> TRACE
      | PATCH -> PATCH
  end

  (** Verify the CSRF [token] for [request].

      Resolves to [Ok ()] when Dream accepts the token. Every other verdict
      is rejected with an empty [401 Unauthorized] response. The verdicts are
      listed explicitly rather than with a wildcard, so that a new kind of
      verdict forces this decision to be revisited. *)
  let check_token :
      Dream.request -> string -> (unit, Dream.response) Lwt_result.t =
   fun request token ->
    match%lwt Dream.verify_csrf_token request token with
    | `Ok -> Lwt.return_ok ()
    | `Expired _ | `Wrong_session | `Invalid ->
        Lwt_result.fail (Dream.response ~status:`Unauthorized "")

  (** Override the handle function by checking the token validity before
      calling [f]. An error message returned by [f] becomes a
      [500 Internal Server Error] response whose body is that message. *)
  let handle :
      (S.placeholders -> S.request -> (S.response, string) Lwt_result.t) ->
      S.placeholders ->
      Dream.handler =
   fun f args request ->
    let response =
      let* content = read_body (module Service) request in

      (* Extract the token from the body and check the validity *)
      let* () = check_token request content.token in

      Lwt_result.map_error
        (fun e -> Dream.response ~status:`Internal_Server_Error e)
        (f args content.content)
    in
    create_response (module Service) response
end

(** Read the path parameter [name] of [request] and pass it through
    [Dream.from_percent_encoded]. *)
let extract_param request name =
  let raw = Dream.param request name in
  Dream.from_percent_encoded raw

(** Map a service method to the Dream function registering a route for the
    same HTTP verb. The result still expects the route pattern and the
    handler. *)
let method' : type a b.
    (a, b) Services.method_ -> string -> Dream.handler -> Dream.route =
 fun verb ->
  match verb with
  | GET -> Dream.get
  | HEAD -> Dream.head
  | POST -> Dream.post
  | PUT -> Dream.put
  | PATCH -> Dream.patch
  | DELETE -> Dream.delete
  | OPTIONS -> Dream.options
  | CONNECT -> Dream.connect
  | TRACE -> Dream.trace

(** [register ?path (module S) f] builds the Dream route of the service [S].

    The route uses the HTTP method [S.method_] and the URL rendered from
    [S.path], or from [path] when it is given. This allows the same service
    to be registered under a different URL, for instance inside a scope.

    On each request the placeholders are extracted from the path parameters
    following [S.path] (even when [path] is given, so presumably [path] must
    use the same placeholder names). They are then passed to [f] together
    with the request.

    The types tie the placeholders used to build the route to the ones
    extracted for the handler, and the extracted values to the argument
    expected by [f]. *)
let register :
    ?path:'placeholders Path.t ->
    (module Services.JsonServerHandler with type placeholders = 'placeholders) ->
    ('a -> Dream.handler) ->
    Dream.route =
 fun (type placeholders) ?path
     (module S : Services.JsonServerHandler
       with type placeholders = placeholders) f ->
  (* The types of [p'] and [S.path] cannot be unified when both are
     available, so the URL is rendered right here, in each branch, to remove
     any abstraction as soon as possible. *)
  let url =
    match path with
    | Some p' -> Path.(repr' p')
    | None -> Path.(repr' S.path)
  in
  let make_route = method' S.method_ url in
  let handler request =
    let placeholders = Path.unzip S.path (extract_param request) in
    f placeholders request
  in
  make_route handler