Skip to content

Commit

Permalink
Merge pull request #223 from mirage/timeout
Browse files Browse the repository at this point in the history
add timeout for dns-client
  • Loading branch information
hannesm authored Apr 23, 2020
2 parents 4254778 + ef55ab6 commit 6ed77c8
Show file tree
Hide file tree
Showing 17 changed files with 282 additions and 222 deletions.
8 changes: 8 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
### v4.5.0 (2020-04-23)

* client: add timeout for DNS requests (defaults to 5 seconds, as in resolv.h).
* dns-client-mirage functor requires a Mirage_time.S implementation (changes API).
Update your code as in this commit:
https://github.com/roburio/unikernels/commit/201e980f458ebb515298392227294e7b508a1009
#223 @linse @hannesm, review by @cfcs

### v4.4.1 (2020-03-29)

* client: treat '*.localhost' and '*.invalid' special, as specified in RFC 6761
Expand Down
69 changes: 40 additions & 29 deletions app/odns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ let pp_zone_tlsa ppf (domain,ttl,(tlsa:Dns.Tlsa.t)) =
| n -> loop ((String.sub hex n 56)::acc) (n+56)
in loop [] 0)

let do_a nameserver domains _ =
let clock = Mtime_clock.elapsed_ns in
let t = Dns_client_lwt.create ?nameserver ~clock () in
let ns ip is_udp = match ip with
| None -> None
| Some ip -> if is_udp then Some (`UDP, ip) else Some (`TCP, ip)

let do_a nameserver is_udp domains _ =
let nameserver = ns nameserver is_udp in
let t = Dns_client_lwt.create ?nameserver () in
let (_, (ns_ip, _)) = Dns_client_lwt.nameserver t in
Logs.info (fun m -> m "querying NS %s for A records of %a"
(Unix.string_of_inet_addr ns_ip)
Expand All @@ -56,12 +60,12 @@ let do_a nameserver domains _ =
match Lwt_main.run job with
| () -> Ok () (* TODO handle errors *)

let for_all_domains nameserver ~domains typ f =
let for_all_domains nameserver is_udp ~domains typ f =
(* [for_all_domains] is a utility function that lets us avoid duplicating
this block of code in all the subcommands.
We leave {!do_a} simple to provide a more readable example. *)
let clock = Mtime_clock.elapsed_ns in
let t = Dns_client_lwt.create ?nameserver ~clock () in
let nameserver = ns nameserver is_udp in
let t = Dns_client_lwt.create ?nameserver () in
let _, (ns_ip, _) = Dns_client_lwt.nameserver t in
Logs.info (fun m -> m "NS: %s" @@ Unix.string_of_inet_addr ns_ip);
let open Lwt in
Expand All @@ -84,16 +88,16 @@ let pp_response typ domain = function
| Error _ -> ()
| Ok resp -> Logs.app (fun m -> m "%a" pp_zone (domain, typ, resp))

let do_aaaa nameserver domains _ =
for_all_domains nameserver ~domains Dns.Rr_map.Aaaa
let do_aaaa nameserver is_udp domains _ =
for_all_domains nameserver is_udp ~domains Dns.Rr_map.Aaaa
(pp_response Dns.Rr_map.Aaaa)

let do_mx nameserver domains _ =
for_all_domains nameserver ~domains Dns.Rr_map.Mx
let do_mx nameserver is_udp domains _ =
for_all_domains nameserver is_udp ~domains Dns.Rr_map.Mx
(pp_response Dns.Rr_map.Mx)

let do_tlsa nameserver domains _ =
for_all_domains nameserver ~domains Dns.Rr_map.Tlsa
let do_tlsa nameserver is_udp domains _ =
for_all_domains nameserver is_udp ~domains Dns.Rr_map.Tlsa
(fun domain -> function
| Ok (ttl, tlsa_resp) ->
Dns.Rr_map.Tlsa_set.iter (fun tlsa ->
Expand All @@ -102,8 +106,8 @@ let do_tlsa nameserver domains _ =
| Error _ -> () )


let do_txt nameserver domains _ =
for_all_domains nameserver ~domains Dns.Rr_map.Txt
let do_txt nameserver is_udp domains _ =
for_all_domains nameserver is_udp ~domains Dns.Rr_map.Txt
(fun _domain -> function
| Ok (ttl, txtset) ->
Dns.Rr_map.Txt_set.iter (fun txtrr ->
Expand All @@ -112,17 +116,17 @@ let do_txt nameserver domains _ =
| Error _ -> () )


let do_any _nameserver _domains _ =
let do_any _nameserver _is_udp _domains _ =
(* TODO *)
Error (`Msg "ANY functionality is not present atm due to refactorings, come back later")

let do_dkim nameserver (selector:string) domains _ =
let do_dkim nameserver is_udp (selector:string) domains _ =
let domains = List.map (fun original_domain ->
Domain_name.prepend_label_exn
(Domain_name.prepend_label_exn
(original_domain) "_domainkey") selector
) domains in
for_all_domains nameserver ~domains Dns.Rr_map.Txt
for_all_domains nameserver is_udp ~domains Dns.Rr_map.Txt
(fun _domain -> function
| Ok (_ttl, txtset) ->
Dns.Rr_map.Txt_set.iter (fun txt ->
Expand All @@ -144,18 +148,25 @@ let setup_log =
Term.(const _setup_log $ Fmt_cli.style_renderer ~docs:sdocs ()
$ Logs_cli.level ~docs:sdocs ())

let parse_ns : ('a * (Lwt_unix.inet_addr * int)) Arg.conv =
let parse_ns : (Lwt_unix.inet_addr * int) Arg.conv =
( fun ns ->
try `Ok (`TCP, (Unix.inet_addr_of_string ns, 53)) with
try match String.split_on_char ':' ns with
| [ ns ] -> `Ok (Unix.inet_addr_of_string ns, 53)
| [ ns ; port ] -> `Ok (Unix.inet_addr_of_string ns, int_of_string port)
| _ -> `Error "bad name server"
with
| _ -> `Error "NS must be an IPv4 address"),
( fun ppf (typ, (ns, port)) ->
Fmt.pf ppf "%s:%d(%s)" (Unix.string_of_inet_addr ns) port
(match typ with `UDP -> "udp" | `TCP -> "tcp"))
( fun ppf (ns, port) ->
Fmt.pf ppf "%s:%d" (Unix.string_of_inet_addr ns) port )

let arg_ns : 'a Term.t =
let doc = "IP of nameserver to use" in
Arg.(value & opt (some parse_ns) None & info ~docv:"NS-IP" ~doc ["ns"])

let arg_udp =
let doc = "Connect via UDP to resolver" in
Arg.(value & flag & info [ "udp" ] ~doc)

let parse_domain : [ `raw ] Domain_name.t Arg.conv =
( fun name ->
Domain_name.of_string name
Expand All @@ -179,23 +190,23 @@ let cmd_a : unit Term.t * Term.info =
let man = [
`P {| Output mimics that of $(b,dig A )$(i,DOMAIN)|}
] in
Term.(term_result (const do_a $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_a $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "a" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_aaaa : unit Term.t * Term.info =
let doc = "Query a NS for AAAA records" in
let man = [
`P {| Output mimics that of $(b,dig AAAA )$(i,DOMAIN)|}
] in
Term.(term_result (const do_aaaa $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_aaaa $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "aaaa" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_mx : unit Term.t * Term.info =
let doc = "Query a NS for mailserver (MX) records" in
let man = [
`P {| Output mimics that of $(b,dig MX )$(i,DOMAIN)|}
] in
Term.(term_result (const do_mx $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_mx $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "mx" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_tlsa : unit Term.t * Term.info =
Expand All @@ -217,7 +228,7 @@ let cmd_tlsa : unit Term.t * Term.info =
`P {| $(b,_993._tcp) (IMAP) |} ;
`S Manpage.s_options ;
] in
Term.(term_result (const do_tlsa $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_tlsa $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "tlsa" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_txt : unit Term.t * Term.info =
Expand All @@ -229,7 +240,7 @@ let cmd_txt : unit Term.t * Term.info =
It would be nice to mirror `dig` output here.|} ;
`S Manpage.s_options ;
] in
Term.(term_result (const do_txt $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_txt $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "txt" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_any : unit Term.t * Term.info =
Expand All @@ -240,7 +251,7 @@ let cmd_any : unit Term.t * Term.info =
`P {| The output will be fairly similar to $(b,dig ANY )$(i,example.com)|} ;
`S Manpage.s_options ;
] in
Term.(term_result (const do_any $ arg_ns $ arg_domains $ setup_log)),
Term.(term_result (const do_any $ arg_ns $ arg_udp $ arg_domains $ setup_log)),
Term.info "any" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

let cmd_dkim : unit Term.t * Term.info =
Expand All @@ -256,7 +267,7 @@ let cmd_dkim : unit Term.t * Term.info =
|} ;
`S Manpage.s_options ;
] in
Term.(term_result (const do_dkim $ arg_ns $ arg_selector
Term.(term_result (const do_dkim $ arg_ns $ arg_udp $ arg_selector
$ arg_domains $ setup_log)),
Term.info "dkim" ~version:(Manpage.escape "%%VERSION%%") ~man ~doc ~sdocs

Expand Down
36 changes: 14 additions & 22 deletions client/dns_client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,33 +135,27 @@ module Pure = struct

end


let stdlib_random n =
let b = Cstruct.create n in
for i = 0 to pred n do
Cstruct.set_uint8 b i (Random.int 256)
done;
b

(* Anycast address of uncensoreddns.org *)
let default_resolver = "91.239.100.100"

module type S = sig
type flow
type context
type +'a io
type io_addr
type ns_addr = ([`TCP | `UDP]) * io_addr
type stack
type t

val create : ?rng:(int -> Cstruct.t) -> ?nameserver:ns_addr -> stack -> t
val create : ?nameserver:ns_addr -> timeout:int64 -> stack -> t

val nameserver : t -> ns_addr
val rng : t -> (int -> Cstruct.t)
val rng : int -> Cstruct.t
val clock : unit -> int64

val connect : ?nameserver:ns_addr -> t -> (flow, [> `Msg of string ]) result io
val send : flow -> Cstruct.t -> (unit, [> `Msg of string ]) result io
val recv : flow -> (Cstruct.t, [> `Msg of string ]) result io
val close : flow -> unit io
val connect : ?nameserver:ns_addr -> t -> (context, [> `Msg of string ]) result io
val send : context -> Cstruct.t -> (unit, [> `Msg of string ]) result io
val recv : context -> (Cstruct.t, [> `Msg of string ]) result io
val close : context -> unit io

val bind : 'a io -> ('a -> 'b io) -> 'b io
val lift : 'a -> 'a io
Expand Down Expand Up @@ -190,14 +184,12 @@ struct

type t = {
cache : Dns_cache.t ;
clock : unit -> int64 ;
transport : Transport.t ;
}

let create ?(size=32) ?rng ?nameserver ~clock stack =
let create ?(size=32) ?nameserver ?(timeout = Duration.of_sec 5) stack =
{ cache = Dns_cache.empty size ;
clock = clock ;
transport = Transport.create ?rng ?nameserver stack
transport = Transport.create ?nameserver ~timeout stack
}

let nameserver { transport; _ } = Transport.nameserver transport
Expand Down Expand Up @@ -245,14 +237,14 @@ struct
| Ok _ as ok -> Transport.lift ok
| Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
| Error `Msg _ ->
match Dns_cache.get t.cache (t.clock ()) domain_name query_type |> lift_ok query_type with
match Dns_cache.get t.cache (Transport.clock ()) domain_name query_type |> lift_ok query_type with
| Ok _ as ok -> Transport.lift ok
| Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
| Error `Msg _ ->
let proto, _ = match nameserver with
| None -> Transport.nameserver t.transport | Some x -> x in
let tx, state =
Pure.make_query (Transport.rng t.transport)
Pure.make_query Transport.rng
(match proto with `UDP -> `Udp | `TCP -> `Tcp) name query_type
in
Transport.connect ?nameserver t.transport >>| fun socket ->
Expand All @@ -261,7 +253,7 @@ struct
Logs.debug (fun m -> m "Receiving from NS");
let update_cache entry =
let rank = Dns_cache.NonAuthoritativeAnswer in
Dns_cache.set t.cache (t.clock ()) domain_name query_type rank entry
Dns_cache.set t.cache (Transport.clock ()) domain_name query_type rank entry
in
let rec recv_loop acc =
Transport.recv socket >>| fun recv_buffer ->
Expand Down
53 changes: 27 additions & 26 deletions client/dns_client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,28 @@
better solution presents itself.
*)

val stdlib_random : int -> Cstruct.t
(** [stdlib_random len] is a buffer of size [len], filled with random data.
This function is used by default (in the Unix and Lwt implementations) for
filling the ID field of the DNS packet. Internally, the {!Random} module
from the OCaml standard library is used, which is not cryptographically
secure. If desired {!Nocrypto.Rng.generate} can be passed to {!S.create}. *)

val default_resolver : string
(** [default_resolver] is the IPv4 address in dotted-decimal form of the default
resolver. Currently it is the IP address of the UncensoredDNS.org anycast
service. *)

module type S = sig
type flow
(** A flow is a network connection initialized by {!T.connect} *)
type context
(** A context is a network connection initialized by {!T.connect} *)

type +'a io
(** [io] is the type of an effect. ['err] is a polymorphic variant. *)

type io_addr
(** An address for a given flow type, usually this will consist of
IP address + a TCP/IP or UDP/IP port number, but for some flow types
(** An address for a given context type, usually this will consist of
IP address + a TCP/IP or UDP/IP port number, but for some context types
it can carry additional information for purposes of cryptographic
verification. TODO at least that would be nice in the future. TODO
*)

type ns_addr = [ `TCP | `UDP] * io_addr
(** TODO well this is kind of crude; it's a tuple to prevent having
to do endless amounts of currying things when implementing flow types,
to do endless amounts of currying things when implementing context types,
and we need to know the protocol used so we can prefix packets for
DNS-over-TCP and set correct socket options etc. therefore we can't
just use the opaque [io_addr].
Expand All @@ -44,28 +37,33 @@ module type S = sig
type t
(** The abstract state of a DNS client. *)

val create : ?rng:(int -> Cstruct.t) -> ?nameserver:ns_addr -> stack -> t
(** [create ~rng ~nameserver stack] creates the state record of the DNS client. *)
val create : ?nameserver:ns_addr -> timeout:int64 -> stack -> t
(** [create ~nameserver ~timeout stack] creates the state record of the DNS
client. We use [timeout] (ns) as a cumulative time budget for connect
and request timeouts. *)

val nameserver : t -> ns_addr
(** The address of a nameserver that is supposed to work with
the underlying flow, can be used if the user does not want to
the underlying context, can be used if the user does not want to
bother with configuring their own.*)

val rng : t -> (int -> Cstruct.t)
val rng : int -> Cstruct.t
(** [rng t] is a random number generator. *)

val connect : ?nameserver:ns_addr -> t -> (flow, [> `Msg of string ]) result io
(** [connect addr] is a new connection ([flow]) to [addr], or an error. *)
val clock : unit -> int64
(** [clock t] is the monotonic clock. *)

val connect : ?nameserver:ns_addr -> t -> (context, [> `Msg of string ]) result io
(** [connect addr] is a new connection ([context]) to [addr], or an error. *)

val send : flow -> Cstruct.t -> (unit, [> `Msg of string ]) result io
(** [send flow buffer] sends [buffer] to the [flow] upstream.*)
val send : context -> Cstruct.t -> (unit, [> `Msg of string ]) result io
(** [send context buffer] sends [buffer] to the [context] upstream.*)

val recv : flow -> (Cstruct.t, [> `Msg of string ]) result io
(** [recv flow] tries to read a [buffer] from the [flow] downstream.*)
val recv : context -> (Cstruct.t, [> `Msg of string ]) result io
(** [recv context] tries to read a [buffer] from the [context] downstream.*)

val close : flow -> unit io
(** [close flow] closes the [flow], freeing up resources. *)
val close : context -> unit io
(** [close context] closes the [context], freeing up resources. *)

val bind : 'a io -> ('a -> 'b io) -> 'b io
(** a.k.a. [>>=] *)
Expand All @@ -78,8 +76,11 @@ sig

type t

val create : ?size:int -> ?rng:(int -> Cstruct.t) -> ?nameserver:T.ns_addr -> clock:(unit -> int64) -> T.stack -> t
(** [create ~size ~rng ~nameserver ~clock stack] creates the state of the DNS client. *)
val create : ?size:int -> ?nameserver:T.ns_addr -> ?timeout:int64 -> T.stack -> t
(** [create ~size ~nameserver ~timeout stack] creates the state of the DNS client.
We use [timeout] (ns, default 3s) as a time budget for connect and request timeouts.
To specify a timeout, use [create ~timeout:(Duration.of_sec 5)].
*)

val nameserver : t -> T.ns_addr
(** [nameserver state] returns the default nameserver to be used. *)
Expand Down
2 changes: 2 additions & 0 deletions dns-client.opam
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ depends: [
"lwt" {>= "4.2.1"}
"mirage-stack" {>= "2.0.0"}
"mirage-random" {>= "2.0.0"}
"mirage-time" {>= "2.0.0"}
"mirage-clock" {>= "3.0.0"}
"mtime" {>= "1.2.0"}
"mirage-crypto-rng"
]
synopsis: "Pure DNS resolver API"
description: """
Expand Down
Loading

0 comments on commit 6ed77c8

Please sign in to comment.