diff --git a/CHANGES.md b/CHANGES.md index 4a35f0218..66b25e19f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +### v4.5.0 (2020-04-23) + +* client: add timeout for DNS requests (defaults to 5 seconds, as in resolv.h). +* dns-client-mirage functor requires a Mirage_time.S implementation (changes API). + Update your code as in this commit: + https://github.com/roburio/unikernels/commit/201e980f458ebb515298392227294e7b508a1009 + #223 @linse @hannesm, review by @cfcs + ### v4.4.1 (2020-03-29) * client: treat '*.localhost' and '*.invalid' special, as specified in RFC 6761 diff --git a/app/odns.ml b/app/odns.ml index 90eac775a..ea6883085 100644 --- a/app/odns.ml +++ b/app/odns.ml @@ -28,9 +28,13 @@ let pp_zone_tlsa ppf (domain,ttl,(tlsa:Dns.Tlsa.t)) = | n -> loop ((String.sub hex n 56)::acc) (n+56) in loop [] 0) -let do_a nameserver domains _ = - let clock = Mtime_clock.elapsed_ns in - let t = Dns_client_lwt.create ?nameserver ~clock () in +let ns ip is_udp = match ip with + | None -> None + | Some ip -> if is_udp then Some (`UDP, ip) else Some (`TCP, ip) + +let do_a nameserver is_udp domains _ = + let nameserver = ns nameserver is_udp in + let t = Dns_client_lwt.create ?nameserver () in let (_, (ns_ip, _)) = Dns_client_lwt.nameserver t in Logs.info (fun m -> m "querying NS %s for A records of %a" (Unix.string_of_inet_addr ns_ip) @@ -56,12 +60,12 @@ let do_a nameserver domains _ = match Lwt_main.run job with | () -> Ok () (* TODO handle errors *) -let for_all_domains nameserver ~domains typ f = +let for_all_domains nameserver is_udp ~domains typ f = (* [for_all_domains] is a utility function that lets us avoid duplicating this block of code in all the subcommands. We leave {!do_a} simple to provide a more readable example. *) - let clock = Mtime_clock.elapsed_ns in - let t = Dns_client_lwt.create ?nameserver ~clock () in + let nameserver = ns nameserver is_udp in + let t = Dns_client_lwt.create ?nameserver () in let _, (ns_ip, _) = Dns_client_lwt.nameserver t in Logs.info (fun m -> m "NS: %s" @@ Unix.string_of_inet_addr ns_ip); let open Lwt in @@ -84,16 +88,16 @@ let pp_response typ domain = function | Error _ -> () | Ok resp -> Logs.app (fun m -> m "%a" pp_zone (domain, typ, resp)) -let do_aaaa nameserver domains _ = - for_all_domains nameserver ~domains Dns.Rr_map.Aaaa +let do_aaaa nameserver is_udp domains _ = + for_all_domains nameserver is_udp ~domains Dns.Rr_map.Aaaa (pp_response Dns.Rr_map.Aaaa) -let do_mx nameserver domains _ = - for_all_domains nameserver ~domains Dns.Rr_map.Mx +let do_mx nameserver is_udp domains _ = + for_all_domains nameserver is_udp ~domains Dns.Rr_map.Mx (pp_response Dns.Rr_map.Mx) -let do_tlsa nameserver domains _ = - for_all_domains nameserver ~domains Dns.Rr_map.Tlsa +let do_tlsa nameserver is_udp domains _ = + for_all_domains nameserver is_udp ~domains Dns.Rr_map.Tlsa (fun domain -> function | Ok (ttl, tlsa_resp) -> Dns.Rr_map.Tlsa_set.iter (fun tlsa -> @@ -102,8 +106,8 @@ let do_tlsa nameserver domains _ = | Error _ -> () ) -let do_txt nameserver domains _ = - for_all_domains nameserver ~domains Dns.Rr_map.Txt +let do_txt nameserver is_udp domains _ = + for_all_domains nameserver is_udp ~domains Dns.Rr_map.Txt (fun _domain -> function | Ok (ttl, txtset) -> Dns.Rr_map.Txt_set.iter (fun txtrr -> @@ -112,17 +116,17 @@ let do_txt nameserver domains _ = | Error _ -> () ) -let do_any _nameserver _domains _ = +let do_any _nameserver _is_udp _domains _ = (* TODO *) Error (`Msg "ANY functionality is not present atm due to refactorings, come back later") -let do_dkim nameserver (selector:string) domains _ = +let do_dkim nameserver is_udp (selector:string) domains _ = let domains = List.map (fun original_domain -> Domain_name.prepend_label_exn (Domain_name.prepend_label_exn (original_domain) "_domainkey") selector ) domains in - for_all_domains nameserver ~domains Dns.Rr_map.Txt + for_all_domains nameserver is_udp ~domains Dns.Rr_map.Txt (fun _domain -> function | Ok (_ttl, txtset) -> Dns.Rr_map.Txt_set.iter (fun txt -> @@ -144,18 +148,25 @@ let setup_log = Term.(const _setup_log $ Fmt_cli.style_renderer ~docs:sdocs () $ Logs_cli.level ~docs:sdocs ()) -let parse_ns : ('a * (Lwt_unix.inet_addr * int)) Arg.conv = +let parse_ns : (Lwt_unix.inet_addr * int) Arg.conv = ( fun ns -> - try `Ok (`TCP, (Unix.inet_addr_of_string ns, 53)) with + try match String.split_on_char ':' ns with + | [ ns ] -> `Ok (Unix.inet_addr_of_string ns, 53) + | [ ns ; port ] -> `Ok (Unix.inet_addr_of_string ns, int_of_string port) + | _ -> `Error "bad name server" + with | _ -> `Error "NS must be an IPv4 address"), - ( fun ppf (typ, (ns, port)) -> - Fmt.pf ppf "%s:%d(%s)" (Unix.string_of_inet_addr ns) port - (match typ with `UDP -> "udp" | `TCP -> "tcp")) + ( fun ppf (ns, port) -> + Fmt.pf ppf "%s:%d" (Unix.string_of_inet_addr ns) port ) let arg_ns : 'a Term.t = let doc = "IP of nameserver to use" in Arg.(value & opt (some parse_ns) None & info ~docv:"NS-IP" ~doc ["ns"]) +let arg_udp = + let doc = "Connect via UDP to resolver" in + Arg.(value & flag & info [ "udp" ] ~doc) + let parse_domain : [ `raw ] Domain_name.t Arg.conv = ( fun name -> Domain_name.of_string name @@ -179,7 +190,7 @@ let cmd_a : unit Term.t * Term.info = let man = [ `P {| Output mimics that of $(b,dig A )$(i,DOMAIN)|} ] in - Term.(term_result (const do_a $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_a $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "a" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_aaaa : unit Term.t * Term.info = @@ -187,7 +198,7 @@ let cmd_aaaa : unit Term.t * Term.info = let man = [ `P {| Output mimics that of $(b,dig AAAA )$(i,DOMAIN)|} ] in - Term.(term_result (const do_aaaa $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_aaaa $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "aaaa" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_mx : unit Term.t * Term.info = @@ -195,7 +206,7 @@ let cmd_mx : unit Term.t * Term.info = let man = [ `P {| Output mimics that of $(b,dig MX )$(i,DOMAIN)|} ] in - Term.(term_result (const do_mx $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_mx $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "mx" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_tlsa : unit Term.t * Term.info = @@ -217,7 +228,7 @@ let cmd_tlsa : unit Term.t * Term.info = `P {| $(b,_993._tcp) (IMAP) |} ; `S Manpage.s_options ; ] in - Term.(term_result (const do_tlsa $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_tlsa $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "tlsa" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_txt : unit Term.t * Term.info = @@ -229,7 +240,7 @@ let cmd_txt : unit Term.t * Term.info = It would be nice to mirror `dig` output here.|} ; `S Manpage.s_options ; ] in - Term.(term_result (const do_txt $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_txt $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "txt" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_any : unit Term.t * Term.info = @@ -240,7 +251,7 @@ let cmd_any : unit Term.t * Term.info = `P {| The output will be fairly similar to $(b,dig ANY )$(i,example.com)|} ; `S Manpage.s_options ; ] in - Term.(term_result (const do_any $ arg_ns $ arg_domains $ setup_log)), + Term.(term_result (const do_any $ arg_ns $ arg_udp $ arg_domains $ setup_log)), Term.info "any" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs let cmd_dkim : unit Term.t * Term.info = @@ -256,7 +267,7 @@ let cmd_dkim : unit Term.t * Term.info = |} ; `S Manpage.s_options ; ] in - Term.(term_result (const do_dkim $ arg_ns $ arg_selector + Term.(term_result (const do_dkim $ arg_ns $ arg_udp $ arg_selector $ arg_domains $ setup_log)), Term.info "dkim" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs diff --git a/client/dns_client.ml b/client/dns_client.ml index 088c480be..786fc6f79 100644 --- a/client/dns_client.ml +++ b/client/dns_client.ml @@ -135,33 +135,27 @@ module Pure = struct end - -let stdlib_random n = - let b = Cstruct.create n in - for i = 0 to pred n do - Cstruct.set_uint8 b i (Random.int 256) - done; - b - +(* Anycast address of uncensoreddns.org *) let default_resolver = "91.239.100.100" module type S = sig - type flow + type context type +'a io type io_addr type ns_addr = ([`TCP | `UDP]) * io_addr type stack type t - val create : ?rng:(int -> Cstruct.t) -> ?nameserver:ns_addr -> stack -> t + val create : ?nameserver:ns_addr -> timeout:int64 -> stack -> t val nameserver : t -> ns_addr - val rng : t -> (int -> Cstruct.t) + val rng : int -> Cstruct.t + val clock : unit -> int64 - val connect : ?nameserver:ns_addr -> t -> (flow, [> `Msg of string ]) result io - val send : flow -> Cstruct.t -> (unit, [> `Msg of string ]) result io - val recv : flow -> (Cstruct.t, [> `Msg of string ]) result io - val close : flow -> unit io + val connect : ?nameserver:ns_addr -> t -> (context, [> `Msg of string ]) result io + val send : context -> Cstruct.t -> (unit, [> `Msg of string ]) result io + val recv : context -> (Cstruct.t, [> `Msg of string ]) result io + val close : context -> unit io val bind : 'a io -> ('a -> 'b io) -> 'b io val lift : 'a -> 'a io @@ -190,14 +184,12 @@ struct type t = { cache : Dns_cache.t ; - clock : unit -> int64 ; transport : Transport.t ; } - let create ?(size=32) ?rng ?nameserver ~clock stack = + let create ?(size=32) ?nameserver ?(timeout = Duration.of_sec 5) stack = { cache = Dns_cache.empty size ; - clock = clock ; - transport = Transport.create ?rng ?nameserver stack + transport = Transport.create ?nameserver ~timeout stack } let nameserver { transport; _ } = Transport.nameserver transport @@ -245,14 +237,14 @@ struct | Ok _ as ok -> Transport.lift ok | Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift | Error `Msg _ -> - match Dns_cache.get t.cache (t.clock ()) domain_name query_type |> lift_ok query_type with + match Dns_cache.get t.cache (Transport.clock ()) domain_name query_type |> lift_ok query_type with | Ok _ as ok -> Transport.lift ok | Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift | Error `Msg _ -> let proto, _ = match nameserver with | None -> Transport.nameserver t.transport | Some x -> x in let tx, state = - Pure.make_query (Transport.rng t.transport) + Pure.make_query Transport.rng (match proto with `UDP -> `Udp | `TCP -> `Tcp) name query_type in Transport.connect ?nameserver t.transport >>| fun socket -> @@ -261,7 +253,7 @@ struct Logs.debug (fun m -> m "Receiving from NS"); let update_cache entry = let rank = Dns_cache.NonAuthoritativeAnswer in - Dns_cache.set t.cache (t.clock ()) domain_name query_type rank entry + Dns_cache.set t.cache (Transport.clock ()) domain_name query_type rank entry in let rec recv_loop acc = Transport.recv socket >>| fun recv_buffer -> diff --git a/client/dns_client.mli b/client/dns_client.mli index 5cc978115..90dd9c521 100644 --- a/client/dns_client.mli +++ b/client/dns_client.mli @@ -4,35 +4,28 @@ better solution presents itself. *) -val stdlib_random : int -> Cstruct.t -(** [stdlib_random len] is a buffer of size [len], filled with random data. - This function is used by default (in the Unix and Lwt implementations) for - filling the ID field of the DNS packet. Internally, the {!Random} module - from the OCaml standard library is used, which is not cryptographically - secure. If desired {!Nocrypto.Rng.generate} can be passed to {!S.create}. *) - val default_resolver : string (** [default_resolver] is the IPv4 address in dotted-decimal form of the default resolver. Currently it is the IP address of the UncensoredDNS.org anycast service. *) module type S = sig - type flow - (** A flow is a network connection initialized by {!T.connect} *) + type context + (** A context is a network connection initialized by {!T.connect} *) type +'a io (** [io] is the type of an effect. ['err] is a polymorphic variant. *) type io_addr - (** An address for a given flow type, usually this will consist of - IP address + a TCP/IP or UDP/IP port number, but for some flow types + (** An address for a given context type, usually this will consist of + IP address + a TCP/IP or UDP/IP port number, but for some context types it can carry additional information for purposes of cryptographic verification. TODO at least that would be nice in the future. TODO *) type ns_addr = [ `TCP | `UDP] * io_addr (** TODO well this is kind of crude; it's a tuple to prevent having - to do endless amounts of currying things when implementing flow types, + to do endless amounts of currying things when implementing context types, and we need to know the protocol used so we can prefix packets for DNS-over-TCP and set correct socket options etc. therefore we can't just use the opaque [io_addr]. @@ -44,28 +37,33 @@ module type S = sig type t (** The abstract state of a DNS client. *) - val create : ?rng:(int -> Cstruct.t) -> ?nameserver:ns_addr -> stack -> t - (** [create ~rng ~nameserver stack] creates the state record of the DNS client. *) + val create : ?nameserver:ns_addr -> timeout:int64 -> stack -> t + (** [create ~nameserver ~timeout stack] creates the state record of the DNS + client. We use [timeout] (ns) as a cumulative time budget for connect + and request timeouts. *) val nameserver : t -> ns_addr (** The address of a nameserver that is supposed to work with - the underlying flow, can be used if the user does not want to + the underlying context, can be used if the user does not want to bother with configuring their own.*) - val rng : t -> (int -> Cstruct.t) + val rng : int -> Cstruct.t (** [rng t] is a random number generator. *) - val connect : ?nameserver:ns_addr -> t -> (flow, [> `Msg of string ]) result io - (** [connect addr] is a new connection ([flow]) to [addr], or an error. *) + val clock : unit -> int64 + (** [clock t] is the monotonic clock. *) + + val connect : ?nameserver:ns_addr -> t -> (context, [> `Msg of string ]) result io + (** [connect addr] is a new connection ([context]) to [addr], or an error. *) - val send : flow -> Cstruct.t -> (unit, [> `Msg of string ]) result io - (** [send flow buffer] sends [buffer] to the [flow] upstream.*) + val send : context -> Cstruct.t -> (unit, [> `Msg of string ]) result io + (** [send context buffer] sends [buffer] to the [context] upstream.*) - val recv : flow -> (Cstruct.t, [> `Msg of string ]) result io - (** [recv flow] tries to read a [buffer] from the [flow] downstream.*) + val recv : context -> (Cstruct.t, [> `Msg of string ]) result io + (** [recv context] tries to read a [buffer] from the [context] downstream.*) - val close : flow -> unit io - (** [close flow] closes the [flow], freeing up resources. *) + val close : context -> unit io + (** [close context] closes the [context], freeing up resources. *) val bind : 'a io -> ('a -> 'b io) -> 'b io (** a.k.a. [>>=] *) @@ -78,8 +76,11 @@ sig type t - val create : ?size:int -> ?rng:(int -> Cstruct.t) -> ?nameserver:T.ns_addr -> clock:(unit -> int64) -> T.stack -> t - (** [create ~size ~rng ~nameserver ~clock stack] creates the state of the DNS client. *) + val create : ?size:int -> ?nameserver:T.ns_addr -> ?timeout:int64 -> T.stack -> t + (** [create ~size ~nameserver ~timeout stack] creates the state of the DNS client. + We use [timeout] (ns, default 3s) as a time budget for connect and request timeouts. + To specify a timeout, use [create ~timeout:(Duration.of_sec 5)]. + *) val nameserver : t -> T.ns_addr (** [nameserver state] returns the default nameserver to be used. *) diff --git a/dns-client.opam b/dns-client.opam index 89f0412f3..110e3e16d 100644 --- a/dns-client.opam +++ b/dns-client.opam @@ -25,8 +25,10 @@ depends: [ "lwt" {>= "4.2.1"} "mirage-stack" {>= "2.0.0"} "mirage-random" {>= "2.0.0"} + "mirage-time" {>= "2.0.0"} "mirage-clock" {>= "3.0.0"} "mtime" {>= "1.2.0"} + "mirage-crypto-rng" ] synopsis: "Pure DNS resolver API" description: """ diff --git a/lwt/client/dns_client_lwt.ml b/lwt/client/dns_client_lwt.ml index 4fb4b9447..c916cc300 100644 --- a/lwt/client/dns_client_lwt.ml +++ b/lwt/client/dns_client_lwt.ml @@ -1,58 +1,64 @@ -(* {!Transport} provides the implementation of the underlying flow - that is in turn used by {!Dns_client.Make} to provide the - Lwt convenience module -*) - open Lwt.Infix module Transport : Dns_client.S - with type flow = Lwt_unix.file_descr - and type io_addr = Lwt_unix.inet_addr * int + with type io_addr = Lwt_unix.inet_addr * int and type +'a io = 'a Lwt.t and type stack = unit = struct type io_addr = Lwt_unix.inet_addr * int - type flow = Lwt_unix.file_descr type ns_addr = [`TCP | `UDP] * io_addr type +'a io = 'a Lwt.t type stack = unit type t = { - rng : (int -> Cstruct.t) ; nameserver : ns_addr ; + timeout_ns : int64 ; } + type context = { t : t ; fd : Lwt_unix.file_descr ; timeout_ns : int64 ref } let create - ?(rng = Dns_client.stdlib_random) - ?(nameserver = `TCP, (Unix.inet_addr_of_string Dns_client.default_resolver, 53)) () = - { rng ; nameserver } + ?(nameserver = `TCP, (Unix.inet_addr_of_string Dns_client.default_resolver, 53)) + ~timeout () = + Mirage_crypto_rng_unix.initialize (); + { nameserver ; timeout_ns = timeout } let nameserver { nameserver ; _ } = nameserver - let rng { rng ; _ } = rng + let rng = Mirage_crypto_rng.generate ?g:None + let clock = Mtime_clock.elapsed_ns + + let with_timeout ctx f = + let timeout = Lwt_unix.sleep (Duration.to_f !(ctx.timeout_ns)) >|= fun () -> Error (`Msg "DNS request timeout") in + let start = clock () in + Lwt.pick [ f ; timeout ] >|= fun result -> + let stop = clock () in + ctx.timeout_ns := Int64.sub !(ctx.timeout_ns) (Int64.sub stop start); + result - let close socket = - Lwt.catch (fun () -> Lwt_unix.close socket) (fun _ -> Lwt.return_unit) + let close { fd ; _ } = + Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) - let send socket tx = + let send ctx tx = let open Lwt in Lwt.catch (fun () -> - Lwt_unix.send socket (Cstruct.to_bytes tx) 0 + with_timeout ctx + (Lwt_unix.send ctx.fd (Cstruct.to_bytes tx) 0 (Cstruct.len tx) [] >>= fun res -> if res <> Cstruct.len tx then Lwt_result.fail (`Msg ("oops" ^ (string_of_int res))) else - Lwt_result.return ()) + Lwt_result.return ())) (fun e -> Lwt.return (Error (`Msg (Printexc.to_string e)))) - let recv socket = + let recv ctx = let open Lwt in let recv_buffer = Bytes.make 2048 '\000' in Lwt.catch (fun () -> - Lwt_unix.recv socket recv_buffer 0 (Bytes.length recv_buffer) [] - >>= fun read_len -> - if read_len > 0 then - Lwt_result.return (Cstruct.of_bytes ~len:read_len recv_buffer) - else - Lwt_result.fail (`Msg "Empty response")) + with_timeout ctx + (Lwt_unix.recv ctx.fd recv_buffer 0 (Bytes.length recv_buffer) [] + >>= fun read_len -> + if read_len > 0 then + Lwt_result.return (Cstruct.of_bytes ~len:read_len recv_buffer) + else + Lwt_result.fail (`Msg "Empty response"))) (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e))) let bind = Lwt.bind @@ -73,11 +79,16 @@ module Transport : Dns_client.S end >>= fun (proto_number, socket_type) -> let socket = Lwt_unix.socket PF_INET socket_type proto_number in let addr = Lwt_unix.ADDR_INET (server, port) in + let ctx = { t ; fd = socket ; timeout_ns = ref t.timeout_ns } in Lwt.catch (fun () -> - Lwt_unix.connect socket addr >|= fun () -> - Ok socket) + (* SO_RCVTIMEO does not work in Lwt: it results in an EAGAIN, which + is handled by re-queuing the event *) + with_timeout ctx + (Lwt_unix.connect socket addr >|= fun () -> Ok ()) >>= function + | Ok () -> Lwt_result.return ctx + | Error e -> close ctx >|= fun () -> Error e) (fun e -> - close socket >|= fun () -> + close ctx >|= fun () -> Error (`Msg (Printexc.to_string e)))) (fun e -> Lwt_result.fail (`Msg (Printexc.to_string e))) diff --git a/lwt/client/dns_client_lwt.mli b/lwt/client/dns_client_lwt.mli index af2951239..0d9544562 100644 --- a/lwt/client/dns_client_lwt.mli +++ b/lwt/client/dns_client_lwt.mli @@ -1,13 +1,18 @@ (** {!Lwt_unix} helper module for {!Dns_client}. For more information see the {!Dns_client.Make} functor. + + The {!Dns_client} is available as Dns_client_lwt after + linking to dns-client.lwt in your dune file. + + The [create] function embeds the side effect of initializing the RNG + by calling {!Mirage_crypto_rng_unix.initialize}. *) (** A flow module based on non-blocking I/O on top of the Lwt_unix socket API. *) module Transport : Dns_client.S - with type flow = Lwt_unix.file_descr - and type io_addr = Lwt_unix.inet_addr * int + with type io_addr = Lwt_unix.inet_addr * int and type +'a io = 'a Lwt.t and type stack = unit diff --git a/lwt/client/dune b/lwt/client/dune index 3d0a4b5fa..94bc01158 100644 --- a/lwt/client/dune +++ b/lwt/client/dune @@ -2,5 +2,5 @@ (name dns_client_lwt) (modules dns_client_lwt) (public_name dns-client.lwt) - (libraries lwt lwt.unix dns dns-client) + (libraries lwt lwt.unix dns dns-client mtime.clock.os mirage-crypto-rng.unix) (wrapped false)) diff --git a/mirage/client/dns_client_mirage.ml b/mirage/client/dns_client_mirage.ml index b5ae79a3d..87536b1cc 100644 --- a/mirage/client/dns_client_mirage.ml +++ b/mirage/client/dns_client_mirage.ml @@ -3,64 +3,69 @@ open Lwt.Infix let src = Logs.Src.create "dns_client_mirage" ~doc:"effectful DNS client layer" module Log = (val Logs.src_log src : Logs.LOG) -module Make (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) = struct +module Make (R : Mirage_random.S) (T : Mirage_time.S) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) = struct module Transport : Dns_client.S - with type flow = S.TCPV4.flow - and type stack = S.t + with type stack = S.t and type +'a io = 'a Lwt.t and type io_addr = Ipaddr.V4.t * int = struct - type flow = S.TCPV4.flow type stack = S.t type io_addr = Ipaddr.V4.t * int type ns_addr = [`TCP | `UDP] * io_addr type +'a io = 'a Lwt.t type t = { - rng : (int -> Cstruct.t) ; nameserver : ns_addr ; + timeout_ns : int64 ; stack : stack ; } + type context = { t : t ; flow : S.TCPV4.flow ; timeout_ns : int64 ref } let create - ?rng ?(nameserver = `TCP, (Ipaddr.V4.of_string_exn Dns_client.default_resolver, 53)) + ~timeout stack = - let rng = match rng with None -> R.generate ?g:None | Some x -> x in - { rng ; nameserver ; stack } + { nameserver ; timeout_ns = timeout ; stack } let nameserver { nameserver ; _ } = nameserver - let rng { rng ; _ } = rng + let rng = R.generate ?g:None + let clock = C.elapsed_ns + + let with_timeout time_left f = + let timeout = T.sleep_ns !time_left >|= fun () -> Error (`Msg "DNS request timeout") in + let start = clock () in + Lwt.pick [ f ; timeout ] >|= fun result -> + let stop = clock () in + time_left := Int64.sub !time_left (Int64.sub stop start); + result let bind = Lwt.bind let lift = Lwt.return let connect ?nameserver:ns t = let _proto, addr = match ns with None -> nameserver t | Some x -> x in - S.TCPV4.create_connection (S.tcpv4 t.stack) addr >|= function + let time_left = ref t.timeout_ns in + with_timeout time_left (S.TCPV4.create_connection (S.tcpv4 t.stack) addr >|= function | Error e -> Log.err (fun m -> m "error connecting to nameserver %a" S.TCPV4.pp_error e) ; Error (`Msg "connect failure") - | Ok flow -> Ok flow + | Ok flow -> Ok { t ; flow ; timeout_ns = time_left }) - let close f = S.TCPV4.close f + let close { flow ; _ } = S.TCPV4.close flow - let recv flow = - S.TCPV4.read flow >|= function + let recv ctx = + with_timeout ctx.timeout_ns (S.TCPV4.read ctx.flow >|= function | Error e -> Error (`Msg (Fmt.to_to_string S.TCPV4.pp_error e)) | Ok (`Data cs) -> Ok cs - | Ok `Eof -> Ok Cstruct.empty + | Ok `Eof -> Ok Cstruct.empty) - let send flow s = - S.TCPV4.write flow s >|= function + let send ctx s = + with_timeout ctx.timeout_ns (S.TCPV4.write ctx.flow s >|= function | Error e -> Error (`Msg (Fmt.to_to_string S.TCPV4.pp_write_error e)) - | Ok () -> Ok () + | Ok () -> Ok ()) end include Dns_client.Make(Transport) - - let create ?size ?nameserver stack = - create ?size ~rng:R.generate ?nameserver ~clock:C.elapsed_ns stack end (* diff --git a/mirage/client/dns_client_mirage.mli b/mirage/client/dns_client_mirage.mli index 2acfdeb62..1a2535b6d 100644 --- a/mirage/client/dns_client_mirage.mli +++ b/mirage/client/dns_client_mirage.mli @@ -1,17 +1,16 @@ -module Make (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) : sig +module Make (R : Mirage_random.S) (T : Mirage_time.S) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) : sig module Transport : Dns_client.S - with type flow = S.TCPV4.flow - and type io_addr = Ipaddr.V4.t * int + with type io_addr = Ipaddr.V4.t * int and type +'a io = 'a Lwt.t and type stack = S.t include module type of Dns_client.Make(Transport) - val create : ?size:int -> ?nameserver:Transport.ns_addr -> S.t -> t + val create : ?size:int -> ?nameserver:Transport.ns_addr -> ?timeout:int64 -> S.t -> t (** [create ~size ~nameserver stack] uses [R.generate] and [C.elapsed_ns] as random number generator and timestamp source, and calls the generic - [Dns_client.Make.create]. *) + {!Dns_client.Make.create}. *) end (* diff --git a/mirage/client/dune b/mirage/client/dune index 4c5ae076b..7466444bc 100644 --- a/mirage/client/dune +++ b/mirage/client/dune @@ -1,5 +1,5 @@ (library (name dns_client_mirage) (public_name dns-client.mirage) - (libraries domain-name ipaddr mirage-random mirage-stack mirage-clock dns-client) + (libraries domain-name ipaddr mirage-random mirage-time mirage-stack mirage-clock dns-client) (wrapped false)) diff --git a/mirage/stub/dns_stub_mirage.ml b/mirage/stub/dns_stub_mirage.ml index 664cfb548..8902f3ff4 100644 --- a/mirage/stub/dns_stub_mirage.ml +++ b/mirage/stub/dns_stub_mirage.ml @@ -6,7 +6,7 @@ open Dns let src = Logs.Src.create "dns_stub_mirage" ~doc:"effectful DNS stub layer" module Log = (val Logs.src_log src : Logs.LOG) -module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) = struct +module Make (R : Mirage_random.S) (T : Mirage_time.S) (P : Mirage_clock.PCLOCK) (C : Mirage_clock.MCLOCK) (S : Mirage_stack.V4) = struct (* data in the wild: - a request comes in hdr, q @@ -77,23 +77,31 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (C : Mirage_clock.MC type +'a io = 'a Lwt.t type t = { - rng : (int -> Cstruct.t) ; nameserver : ns_addr ; + timeout_ns : int64 ; stack : stack ; mutable flow : S.TCPV4.flow option ; mutable requests : (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t IM.t ; } - type flow = { t : t ; mutable id : int } + type context = { t : t ; timeout_ns : int64 ref ; mutable id : int } let create - ?rng ?(nameserver = `TCP, (Ipaddr.V4.of_string_exn Dns_client.default_resolver, 53)) + ~timeout stack = - let rng = match rng with None -> R.generate ?g:None | Some x -> x in - { rng ; nameserver ; stack ; flow = None ; requests = IM.empty } + { nameserver ; timeout_ns = timeout ; stack ; flow = None ; requests = IM.empty } let nameserver { nameserver ; _ } = nameserver - let rng { rng ; _ } = rng + let rng = R.generate ?g:None + let clock = C.elapsed_ns + + let with_timeout time_left f = + let timeout = T.sleep_ns !time_left >|= fun () -> Error (`Msg "DNS request timeout") in + let start = clock () in + Lwt.pick [ f ; timeout ] >|= fun result -> + let stop = clock () in + time_left := Int64.sub !time_left (Int64.sub stop start); + result let bind = Lwt.bind let lift = Lwt.return @@ -142,10 +150,11 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (C : Mirage_clock.MC let connect ?nameserver:ns t = match t.flow with - | Some _ -> Lwt.return (Ok ({ t ; id = 0 })) + | Some _ -> Lwt.return (Ok ({ t ; timeout_ns = ref t.timeout_ns ; id = 0 })) | None -> let _proto, addr = match ns with None -> nameserver t | Some x -> x in - S.TCPV4.create_connection (S.tcpv4 t.stack) addr >|= function + let time_left = ref t.timeout_ns in + with_timeout time_left (S.TCPV4.create_connection (S.tcpv4 t.stack) addr >|= function | Error e -> Log.err (fun m -> m "error connecting to nameserver %a" S.TCPV4.pp_error e) ; @@ -154,40 +163,37 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (C : Mirage_clock.MC metrics `Recursive_connections; Lwt.async (fun () -> read_loop t flow); t.flow <- Some flow; - Ok ({ t ; id = 0 }) + Ok ({ t ; timeout_ns = time_left ; id = 0 })) let close _f = (* ignoring this here *) Lwt.return_unit - let recv { t ; id } = + let recv { t ; timeout_ns ; id } = let cond = Lwt_condition.create () in t.requests <- IM.add id cond t.requests; - Lwt_condition.wait cond >|= fun data -> + with_timeout timeout_ns (Lwt_condition.wait cond) >|= fun data -> t.requests <- IM.remove id t.requests; match data with | Ok cs -> Ok cs | Error `Msg m -> Error (`Msg m) - let send f s = - match f.t.flow with + let send ({ t ; timeout_ns ; _ } as f) s = + match t.flow with | None -> Lwt.return (Error (`Msg "no connection to resolver")) | Some flow -> let id = Cstruct.BE.get_uint16 s 2 in f.id <- id; - S.TCPV4.write flow s >>= function + with_timeout timeout_ns (S.TCPV4.write flow s >>= function | Error e -> - f.t.flow <- None; + t.flow <- None; Lwt.return (Error (`Msg (Fmt.to_to_string S.TCPV4.pp_write_error e))) | Ok () -> metrics `Recursive_queries; - Lwt.return (Ok ()) + Lwt.return (Ok ())) end include Dns_client.Make(Transport) - - let create ?size ?nameserver stack = - create ?size ~rng:R.generate ?nameserver ~clock:C.elapsed_ns stack end (* likely this should contain: diff --git a/test/client.ml b/test/client.ml index 9f975973f..24738b5fc 100644 --- a/test/client.ml +++ b/test/client.ml @@ -13,7 +13,7 @@ let p_cs = Alcotest.testable Cstruct.hexdump_pp Cstruct.equal module Make_query_tests = struct let produces_same_output () = - let rng = fun x -> Cstruct.create(x) in + let rng = Cstruct.create in let name:'a Domain_name.t = Domain_name.of_string_exn "example.com" in let actual, _state = Dns_client.Pure.make_query rng `Tcp name Dns.Rr_map.A in let expected = Cstruct.of_hex @@ -41,7 +41,7 @@ module Parse_response_tests = struct 02 40 8a 00 04 17 15 f3 77" in (* This `rng` generates zeros, used for the query ID above *) - let rng = fun x -> Cstruct.create(x) in + let rng = Cstruct.create in let name:'a Domain_name.t = Domain_name.of_string_exn "foo.com" in let _actual, state = Dns_client.Pure.make_query rng `Tcp name Dns.Rr_map.A in match Dns_client.Pure.parse_response state ipv4_buf with @@ -63,7 +63,7 @@ module Parse_response_tests = struct 02 40 8a 00 04 17 15 f3 77" in (* This `rng` generates zeros, used for the query ID above *) - let rng = fun x -> Cstruct.create(x) in + let rng = Cstruct.create in let name:'a Domain_name.t = Domain_name.of_string_exn "foo.com" in let _actual, state = Dns_client.Pure.make_query rng `Tcp name Dns.Rr_map.A in match Dns_client.Pure.parse_response state ipv4_buf with @@ -78,7 +78,7 @@ module Parse_response_tests = struct end (* {!Transport} provides a mock implementation of the transport used by - Dns_client.Make. The mock data is passed as type flow and io_addr in + Dns_client.Make. The mock data is passed as type context and io_addr in connect/recv/send by supplying the optional ?nameserver argument. *) @@ -86,26 +86,25 @@ type debug_info = Cstruct.t list ref let default_debug_info = ref [] -module Transport : Dns_client.S - with type flow = debug_info - and type io_addr = debug_info +module Transport (*: Dns_client.S + with type io_addr = debug_info and type stack = unit - and type +'a io = 'a + and type +'a io = 'a *) = struct type io_addr = debug_info type ns_addr = [`TCP | `UDP] * io_addr type stack = unit - type flow = debug_info - type t = int -> Cstruct.t + type context = debug_info + type t = unit type +'a io = 'a let create - ?(rng = Cstruct.create) - ?nameserver:_ () = - rng + ?nameserver:_ ~timeout:_ () = + () let nameserver _ = `TCP, default_debug_info - let rng x = x + let rng = Cstruct.create + let clock () = 0L let bind a b = b a let lift v = v @@ -117,12 +116,12 @@ module Transport : Dns_client.S let connect ?nameserver:ns _ = match ns with | None -> Ok default_debug_info - | Some(_, mock_responses) -> Ok mock_responses + | Some (_, mock_responses) -> Ok mock_responses let send _ _ = Ok () - let recv (mock_responses:flow) = + let recv (mock_responses : context) = match !mock_responses with | [] -> failwith("nothing to recv from the wire") | hd::tail -> mock_responses := tail; Ok hd @@ -132,6 +131,19 @@ end that goes on top of it: *) include Dns_client.Make(Transport) +module Transport_with_time_machine = struct + include Transport + + (* the timestamps are for: cache lookup1, cache upate, cache lookup2, cache update2 *) + let timestamps = ref [0L; 0L; Duration.of_sec 601 ; Duration.of_sec 601] + let clock () = + match !timestamps with + | [] -> assert false + | head::tail -> timestamps := tail; head +end + +module Dns_client_with_time_machine = Dns_client.Make(Transport_with_time_machine) + module Gethostbyname_tests = struct let foo_com_is_valid () = let domain_name = Domain_name.(of_string_exn "foo.com" |> host_exn) in @@ -145,8 +157,7 @@ module Gethostbyname_tests = struct ae 00 06 03 6e 73 32 c0 39 c0 35 00 01 00 01 00 02 40 8a 00 04 17 15 f2 58 c0 51 00 01 00 01 00 02 40 8a 00 04 17 15 f3 77" in - let clock () = 0L in - let t = create ~clock () in + let t = create () in let ns = `TCP, ref [ipv4_buf] in match gethostbyname t domain_name ~nameserver:ns with | Ok _ip -> () @@ -164,8 +175,7 @@ module Gethostbyname_tests = struct ae 00 06 03 6e 73 32 c0 39 c0 35 00 01 00 01 00 02 40 8a 00 04 17 15 f2 58 c0 51 00 01 00 01 00 02 40 8a 00 04 17 15 f3 77" in - let clock () = 0L in - let t = create ~clock () in + let t = create () in let ns = `TCP, ref [ipv4_buf] in match gethostbyname t domain_name ~nameserver:ns with | Error _ -> failwith "foo.com should have been returned" @@ -187,20 +197,13 @@ module Gethostbyname_tests = struct ae 00 06 03 6e 73 32 c0 39 c0 35 00 01 00 01 00 02 40 8a 00 04 17 15 f2 58 c0 51 00 01 00 01 00 02 40 8a 00 04 17 15 f3 77" in - (* the timestamps are for: cache lookup1, cache upate, cache lookup2, cache update2 *) - let timestamps = ref [0L; 0L; Duration.of_sec 601 ; Duration.of_sec 601] in - let time_machine () = - match !timestamps with - | [] -> assert false - | head::tail -> timestamps := tail; head - in - let t = create ~clock:time_machine () in + let t = Dns_client_with_time_machine.create () in let ns = `TCP, ref [ipv4_buf] in - match gethostbyname t domain_name ~nameserver:ns with + match Dns_client_with_time_machine.gethostbyname t domain_name ~nameserver:ns with | Error _ -> failwith "foo.com should have been returned" | Ok _ip -> let mock_ns_responses = `TCP, ref [ipv4_buf] in - match gethostbyname t domain_name ~nameserver:mock_ns_responses with + match Dns_client_with_time_machine.gethostbyname t domain_name ~nameserver:mock_ns_responses with | Error _ -> failwith "should have been cached" | Ok _ -> (* we returned content, AND the wire was used *) assert (!(snd mock_ns_responses) = []) @@ -252,8 +255,7 @@ module Getaddrinfo_tests = struct 00 00 00 0a c0 a6 00 1c 00 01 00 01 d1 ba 00 10 20 01 48 60 48 02 00 38 00 00 00 00 00 00 00 0a" in - let clock () = 0L in - let mock_state = create ~clock () in + let mock_state = create () in let ns = `TCP, ref [ipv4_buf] in match getaddrinfo mock_state Dns.Rr_map.Mx domain_name ~nameserver:ns with | Ok (_ttl, mx_set) -> @@ -283,8 +285,7 @@ module Getaddrinfo_tests = struct let udp_buf = Cstruct.of_hex " 00 00 81 80 00 01 00 05 00 04 00 0f 06 67 6f 6f 67 6c 65 03 63 6f " in - let clock () = 0L in - let mock_state = create ~clock () in + let mock_state = create () in let ns = `UDP, ref [udp_buf] in match getaddrinfo mock_state Dns.Rr_map.Mx domain_name ~nameserver:ns with | Error `Msg actual -> @@ -312,8 +313,7 @@ c0 42 0a 68 6f 73 74 6d 61 73 74 65 72 06 66 61 73 74 6c 79 c0 22 78 39 c6 29 00 00 0e 10 00 00 02 58 00 09 3a 80 00 00 00 1e|} in - let clock () = 0L in - let mock_state = create ~clock () in + let mock_state = create () in let ns = `UDP, ref [udp_buf] in match getaddrinfo mock_state Dns.Rr_map.Aaaa domain_name ~nameserver:ns with | Error `Msg actual -> diff --git a/unix/client/dns_client_unix.ml b/unix/client/dns_client_unix.ml index 8fc28edf8..c541c436e 100644 --- a/unix/client/dns_client_unix.ml +++ b/unix/client/dns_client_unix.ml @@ -4,36 +4,48 @@ *) module Transport : Dns_client.S - with type flow = Unix.file_descr - and type io_addr = Unix.inet_addr * int + with type io_addr = Unix.inet_addr * int and type stack = unit and type +'a io = 'a = struct type io_addr = Unix.inet_addr * int type ns_addr = [`TCP | `UDP] * io_addr type stack = unit - type flow = Unix.file_descr type t = { - rng : int -> Cstruct.t ; nameserver : ns_addr ; + timeout_ns : int64 ; } + type context = { t : t ; fd : Unix.file_descr ; timeout_ns : int64 ref } type +'a io = 'a let create - ?(rng = Dns_client.stdlib_random) - ?(nameserver = `TCP, (Unix.inet_addr_of_string Dns_client.default_resolver, 53)) () = - { rng ; nameserver } + ?(nameserver = `TCP, (Unix.inet_addr_of_string Dns_client.default_resolver, 53)) ~timeout () = + Mirage_crypto_rng_unix.initialize (); + { nameserver ; timeout_ns = timeout } let nameserver { nameserver ; _ } = nameserver - let rng { rng ; _ } = rng + let clock = Mtime_clock.elapsed_ns + let rng = Mirage_crypto_rng.generate ?g:None + + open Rresult let bind a b = b a let lift v = v - open Rresult + let close { fd ; _ } = try Unix.close fd with _ -> () - let close socket = try Unix.close socket with _ -> () + let with_timeout ctx f = + let start = clock () in + (* TODO cancel execution of f when time_left is 0 *) + let r = f ctx.fd in + let stop = clock () in + ctx.timeout_ns := Int64.sub !(ctx.timeout_ns) (Int64.sub stop start); + if !(ctx.timeout_ns) <= 0L then + Error (`Msg "DNS resolution timed out.") + else + r + (* there is no connect timeouts, just a request timeout (unix: receive timeout) *) let connect ?nameserver:ns t = let proto, (server, port) = match ns with None -> nameserver t | Some x -> x @@ -44,35 +56,42 @@ module Transport : Dns_client.S | `TCP -> Ok Unix.((getprotobyname "tcp").p_proto) end >>= fun proto_number -> let socket = Unix.socket PF_INET SOCK_STREAM proto_number in + let time_left = ref t.timeout_ns in let addr = Unix.ADDR_INET (server, port) in + let ctx = { t ; fd = socket ; timeout_ns = time_left } in try - Unix.connect socket addr ; - Ok socket + with_timeout ctx (fun fd -> + Unix.connect fd addr; + Ok ctx) with e -> - close socket ; + close ctx; Error (`Msg (Printexc.to_string e)) with e -> Error (`Msg (Printexc.to_string e)) - let send (socket:flow) (tx:Cstruct.t) = + let send ctx (tx : Cstruct.t) = let str = Cstruct.to_string tx in try - let res = Unix.send_substring socket str 0 (String.length str) [] in - if res <> String.length str - then - Error (`Msg ("Broken write to upstream NS" ^ (string_of_int res))) - else Ok () + with_timeout ctx (fun fd -> + Unix.setsockopt_float fd Unix.SO_SNDTIMEO (Duration.to_f !(ctx.timeout_ns)); + let res = Unix.send_substring fd str 0 (String.length str) [] in + if res <> String.length str + then + Error (`Msg ("Broken write to upstream NS" ^ (string_of_int res))) + else Ok ()) with e -> Error (`Msg (Printexc.to_string e)) - let recv (socket:flow) = + let recv ctx = let buffer = Bytes.make 2048 '\000' in try - let x = Unix.recv socket buffer 0 (Bytes.length buffer) [] in - if x > 0 && x <= Bytes.length buffer then - Ok (Cstruct.of_bytes buffer ~len:x) - else - Error (`Msg "Reading from NS socket failed") + with_timeout ctx (fun fd -> + Unix.setsockopt_float fd Unix.SO_RCVTIMEO (Duration.to_f !(ctx.timeout_ns)); + let x = Unix.recv fd buffer 0 (Bytes.length buffer) [] in + if x > 0 && x <= Bytes.length buffer then + Ok (Cstruct.of_bytes buffer ~len:x) + else + Error (`Msg "Reading from NS socket failed")) with e -> Error (`Msg (Printexc.to_string e)) end diff --git a/unix/client/dns_client_unix.mli b/unix/client/dns_client_unix.mli index 492e75876..7e1552ff0 100644 --- a/unix/client/dns_client_unix.mli +++ b/unix/client/dns_client_unix.mli @@ -3,10 +3,12 @@ *) -(** A flow module based on blocking I/O on top of the Unix socket API. *) +(** A flow module based on blocking I/O on top of the Unix socket API. + + TODO: Implement the connect timeout. +*) module Transport : Dns_client.S - with type flow = Unix.file_descr - and type io_addr = Unix.inet_addr * int + with type io_addr = Unix.inet_addr * int and type stack = unit and type +'a io = 'a diff --git a/unix/client/dune b/unix/client/dune index 724b7e1f6..d1c479e31 100644 --- a/unix/client/dune +++ b/unix/client/dune @@ -2,7 +2,7 @@ (name dns_client_unix) (modules dns_client_unix) (public_name dns-client.unix) - (libraries domain-name ipaddr dns-client rresult unix) + (libraries domain-name ipaddr dns-client rresult unix mtime.clock.os mirage-crypto-rng.unix) (wrapped false)) (executable diff --git a/unix/client/ohost.ml b/unix/client/ohost.ml index 351ada44e..3963e5c31 100644 --- a/unix/client/ohost.ml +++ b/unix/client/ohost.ml @@ -1,6 +1,5 @@ let () = - let clock = Mtime_clock.elapsed_ns in - let t = Dns_client_unix.create ~clock () in + let t = Dns_client_unix.create () in let domain = Domain_name.(host_exn (of_string_exn Sys.argv.(1))) in let ipv4 = match Dns_client_unix.gethostbyname t domain with