diff --git a/g3proxy/doc/configuration/values/dpi.rst b/g3proxy/doc/configuration/values/dpi.rst index d8da81a5f..c8e59c0dc 100644 --- a/g3proxy/doc/configuration/values/dpi.rst +++ b/g3proxy/doc/configuration/values/dpi.rst @@ -67,11 +67,12 @@ The keys ars: protocol inspect policy ----------------------- -**type**: string +**type**: string | map Set what we should do to a specific application protocol. -The possible value are: +One can use the *string* type to define an action for any upstream traffic, regardless of the host, +the possible values for this are: - intercept @@ -89,7 +90,11 @@ The possible value are: Block the traffic. And we will try to send application level error code to the client. -.. versionadded:: 1.9.0 +For more complex setups one can also use the *map* type which +is documented in :ref:`acl rule set ` with the only +difference that the action variants are the strings defined here. + +.. versionadded:: 1.11.0 .. _conf_value_dpi_protocol_inspection: diff --git a/g3proxy/src/audit/handle.rs b/g3proxy/src/audit/handle.rs index 77078de35..c98441371 100644 --- a/g3proxy/src/audit/handle.rs +++ b/g3proxy/src/audit/handle.rs @@ -113,8 +113,8 @@ impl AuditHandle { } #[inline] - pub(crate) fn h2_inspect_policy(&self) -> ProtocolInspectPolicy { - self.auditor_config.h2_inspect_policy + pub(crate) fn h2_inspect_policy(&self) -> &ProtocolInspectPolicy { + &self.auditor_config.h2_inspect_policy } #[inline] @@ -123,13 +123,13 @@ impl AuditHandle { } #[inline] - pub(crate) fn websocket_inspect_policy(&self) -> ProtocolInspectPolicy { - self.auditor_config.websocket_inspect_policy + pub(crate) fn websocket_inspect_policy(&self) -> &ProtocolInspectPolicy { + &self.auditor_config.websocket_inspect_policy } #[inline] - pub(crate) fn smtp_inspect_policy(&self) -> ProtocolInspectPolicy { - self.auditor_config.smtp_inspect_policy + pub(crate) fn smtp_inspect_policy(&self) -> &ProtocolInspectPolicy { + &self.auditor_config.smtp_inspect_policy } #[inline] @@ -138,8 +138,8 @@ impl AuditHandle { } #[inline] - pub(crate) fn imap_inspect_policy(&self) -> ProtocolInspectPolicy { - self.auditor_config.imap_inspect_policy + pub(crate) fn imap_inspect_policy(&self) -> &ProtocolInspectPolicy { + &self.auditor_config.imap_inspect_policy } #[inline] diff --git a/g3proxy/src/config/audit/auditor.rs b/g3proxy/src/config/audit/auditor.rs index 4e200fdeb..97c2f39d2 100644 --- a/g3proxy/src/config/audit/auditor.rs +++ b/g3proxy/src/config/audit/auditor.rs @@ -85,12 +85,24 @@ impl AuditorConfig { tls_stream_dump: None, log_uri_max_chars: 1024, h1_interception: Default::default(), - h2_inspect_policy: ProtocolInspectPolicy::Intercept, + h2_inspect_policy: ProtocolInspectPolicy::builder_with_missing_action( + g3_dpi::ProtocolInspectAction::Intercept, + ) + .build(), h2_interception: Default::default(), - websocket_inspect_policy: ProtocolInspectPolicy::Intercept, - smtp_inspect_policy: ProtocolInspectPolicy::Intercept, + websocket_inspect_policy: ProtocolInspectPolicy::builder_with_missing_action( + g3_dpi::ProtocolInspectAction::Intercept, + ) + .build(), + smtp_inspect_policy: ProtocolInspectPolicy::builder_with_missing_action( + g3_dpi::ProtocolInspectAction::Intercept, + ) + .build(), smtp_interception: Default::default(), - imap_inspect_policy: ProtocolInspectPolicy::Intercept, + imap_inspect_policy: ProtocolInspectPolicy::builder_with_missing_action( + g3_dpi::ProtocolInspectAction::Intercept, + ) + .build(), imap_interception: Default::default(), icap_reqmod_service: None, icap_respmod_service: None, diff --git a/g3proxy/src/inspect/http/v1/upgrade/mod.rs b/g3proxy/src/inspect/http/v1/upgrade/mod.rs index 933f6b828..9e06d8273 100644 --- a/g3proxy/src/inspect/http/v1/upgrade/mod.rs +++ b/g3proxy/src/inspect/http/v1/upgrade/mod.rs @@ -24,7 +24,7 @@ use slog::slog_info; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; -use g3_dpi::Protocol; +use g3_dpi::{Protocol, ProtocolInspectAction}; use g3_http::client::HttpTransparentResponse; use g3_http::server::{HttpTransparentRequest, UriExt}; use g3_http::{HttpBodyReader, HttpBodyType}; @@ -153,9 +153,15 @@ where where CW: AsyncWrite + Unpin, { + let policy_action = match self.req.host.as_ref() { + Some(upstream) => self.ctx.websocket_inspect_action(upstream.host()), + None => self.ctx.websocket_inspect_missing_action(), + }; + let block_websocket = policy_action == ProtocolInspectAction::Block; + let upgrade_token_count = self.req.retain_upgrade(|p| { if matches!(p, HttpUpgradeToken::Websocket) { - return !self.ctx.websocket_inspect_policy().is_block(); + return !block_websocket; } if matches!(p, HttpUpgradeToken::ConnectIp) { return false; diff --git a/g3proxy/src/inspect/http/v2/connect/extended.rs b/g3proxy/src/inspect/http/v2/connect/extended.rs index 707323507..32f7c0487 100644 --- a/g3proxy/src/inspect/http/v2/connect/extended.rs +++ b/g3proxy/src/inspect/http/v2/connect/extended.rs @@ -23,7 +23,7 @@ use h2::{RecvStream, StreamId}; use http::{header, Request, Response, StatusCode, Version}; use slog::slog_info; -use g3_dpi::Protocol; +use g3_dpi::{Protocol, ProtocolInspectAction}; use g3_h2::{H2StreamReader, H2StreamWriter}; use g3_http::server::UriExt; use g3_slog_types::{LtDateTime, LtDuration, LtH2StreamId, LtUpstreamAddr, LtUuid}; @@ -178,7 +178,11 @@ where } }; - if self.ctx.websocket_inspect_policy().is_block() { + let policy_action = match self.upstream.as_ref() { + Some(upstream) => self.ctx.websocket_inspect_action(upstream.host()), + None => self.ctx.websocket_inspect_missing_action(), + }; + if policy_action == ProtocolInspectAction::Block { self.reply_forbidden(clt_send_rsp); intercept_log!(self, "websocket blocked by inspection policy"); return; diff --git a/g3proxy/src/inspect/http/v2/mod.rs b/g3proxy/src/inspect/http/v2/mod.rs index e1eccf90c..95897b8de 100644 --- a/g3proxy/src/inspect/http/v2/mod.rs +++ b/g3proxy/src/inspect/http/v2/mod.rs @@ -24,7 +24,7 @@ use slog::slog_info; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; -use g3_dpi::{Protocol, ProtocolInspectPolicy}; +use g3_dpi::{Protocol, ProtocolInspectAction}; use g3_h2::H2BodyTransfer; use g3_io_ext::OnceBufReader; use g3_slog_types::{LtUpstreamAddr, LtUuid}; @@ -110,15 +110,15 @@ where SC: ServerConfig + Send + Sync + 'static, { pub(crate) async fn intercept(mut self) -> ServerTaskResult<()> { - let r = match self.ctx.h2_inspect_policy() { - ProtocolInspectPolicy::Intercept => self + let r = match self.ctx.h2_inspect_action(self.upstream.host()) { + ProtocolInspectAction::Intercept => self .do_intercept() .await .map_err(|e| InterceptionError::H2(e).into_server_task_error(Protocol::Http2)), #[cfg(feature = "quic")] - ProtocolInspectPolicy::Detour => self.do_detour().await, - ProtocolInspectPolicy::Bypass => self.do_bypass().await, - ProtocolInspectPolicy::Block => self + ProtocolInspectAction::Detour => self.do_detour().await, + ProtocolInspectAction::Bypass => self.do_bypass().await, + ProtocolInspectAction::Block => self .do_block() .await .map_err(|e| InterceptionError::H2(e).into_server_task_error(Protocol::Http2)), diff --git a/g3proxy/src/inspect/imap/mod.rs b/g3proxy/src/inspect/imap/mod.rs index e8aa34d81..e129175ef 100644 --- a/g3proxy/src/inspect/imap/mod.rs +++ b/g3proxy/src/inspect/imap/mod.rs @@ -18,7 +18,7 @@ use anyhow::anyhow; use slog::slog_info; use tokio::io::AsyncWriteExt; -use g3_dpi::ProtocolInspectPolicy; +use g3_dpi::ProtocolInspectAction; use g3_imap_proto::response::ByeResponse; use g3_imap_proto::CommandPipeline; use g3_io_ext::{LineRecvVec, OnceBufReader}; @@ -130,12 +130,12 @@ where } pub(crate) async fn intercept(mut self) -> ServerTaskResult>> { - let r = match self.ctx.imap_inspect_policy() { - ProtocolInspectPolicy::Intercept => self.do_intercept().await, + let r = match self.ctx.imap_inspect_action(self.upstream.host()) { + ProtocolInspectAction::Intercept => self.do_intercept().await, #[cfg(feature = "quic")] - ProtocolInspectPolicy::Detour => self.do_detour().await.map(|_| None), - ProtocolInspectPolicy::Bypass => self.do_bypass().await.map(|_| None), - ProtocolInspectPolicy::Block => self.do_block().await.map(|_| None), + ProtocolInspectAction::Detour => self.do_detour().await.map(|_| None), + ProtocolInspectAction::Bypass => self.do_bypass().await.map(|_| None), + ProtocolInspectAction::Block => self.do_block().await.map(|_| None), }; match r { Ok(obj) => { diff --git a/g3proxy/src/inspect/mod.rs b/g3proxy/src/inspect/mod.rs index c93b368c5..af581d16b 100644 --- a/g3proxy/src/inspect/mod.rs +++ b/g3proxy/src/inspect/mod.rs @@ -25,9 +25,9 @@ use uuid::Uuid; use g3_daemon::server::ServerQuitPolicy; use g3_dpi::{ H1InterceptionConfig, H2InterceptionConfig, ImapInterceptionConfig, MaybeProtocol, - ProtocolInspectPolicy, ProtocolInspector, SmtpInterceptionConfig, + ProtocolInspectAction, ProtocolInspector, SmtpInterceptionConfig, }; -use g3_types::net::OpensslClientConfig; +use g3_types::net::{Host, OpensslClientConfig}; use crate::audit::AuditHandle; use crate::auth::{User, UserForbiddenStats, UserSite}; @@ -263,8 +263,17 @@ impl StreamInspectContext { } #[inline] - fn h2_inspect_policy(&self) -> ProtocolInspectPolicy { - self.audit_handle.h2_inspect_policy() + fn h2_inspect_action(&self, host: &Host) -> ProtocolInspectAction { + match self.audit_handle.h2_inspect_policy().check(host) { + (true, policy_action) => policy_action, + (false, missing_policy_action) => missing_policy_action, + } + } + + #[inline] + #[allow(dead_code)] + fn h2_inspect_missing_action(&self) -> ProtocolInspectAction { + self.audit_handle.h2_inspect_policy().missing_action() } #[inline] @@ -281,13 +290,32 @@ impl StreamInspectContext { } #[inline] - fn websocket_inspect_policy(&self) -> ProtocolInspectPolicy { - self.audit_handle.websocket_inspect_policy() + fn websocket_inspect_action(&self, host: &Host) -> ProtocolInspectAction { + match self.audit_handle.websocket_inspect_policy().check(host) { + (true, policy_action) => policy_action, + (false, missing_policy_action) => missing_policy_action, + } + } + + #[inline] + fn websocket_inspect_missing_action(&self) -> ProtocolInspectAction { + self.audit_handle + .websocket_inspect_policy() + .missing_action() } #[inline] - fn smtp_inspect_policy(&self) -> ProtocolInspectPolicy { - self.audit_handle.smtp_inspect_policy() + fn smtp_inspect_action(&self, host: &Host) -> ProtocolInspectAction { + match self.audit_handle.smtp_inspect_policy().check(host) { + (true, policy_action) => policy_action, + (false, missing_policy_action) => missing_policy_action, + } + } + + #[inline] + #[allow(dead_code)] + fn smtp_inspect_missing_action(&self) -> ProtocolInspectAction { + self.audit_handle.smtp_inspect_policy().missing_action() } #[inline] @@ -296,8 +324,17 @@ impl StreamInspectContext { } #[inline] - fn imap_inspect_policy(&self) -> ProtocolInspectPolicy { - self.audit_handle.imap_inspect_policy() + fn imap_inspect_action(&self, host: &Host) -> ProtocolInspectAction { + match self.audit_handle.imap_inspect_policy().check(host) { + (true, policy_action) => policy_action, + (false, missing_policy_action) => missing_policy_action, + } + } + + #[inline] + #[allow(dead_code)] + fn imap_inspect_missing_action(&self) -> ProtocolInspectAction { + self.audit_handle.imap_inspect_policy().missing_action() } #[inline] diff --git a/g3proxy/src/inspect/smtp/mod.rs b/g3proxy/src/inspect/smtp/mod.rs index 3f6ad9c89..5abd19ed1 100644 --- a/g3proxy/src/inspect/smtp/mod.rs +++ b/g3proxy/src/inspect/smtp/mod.rs @@ -18,7 +18,7 @@ use anyhow::anyhow; use slog::slog_info; use tokio::io::AsyncWriteExt; -use g3_dpi::ProtocolInspectPolicy; +use g3_dpi::ProtocolInspectAction; use g3_io_ext::{LineRecvBuf, OnceBufReader}; use g3_slog_types::{LtHost, LtUpstreamAddr, LtUuid}; use g3_smtp_proto::command::Command; @@ -121,12 +121,12 @@ where } pub(crate) async fn intercept(mut self) -> ServerTaskResult>> { - let r = match self.ctx.smtp_inspect_policy() { - ProtocolInspectPolicy::Intercept => self.do_intercept().await, + let r = match self.ctx.smtp_inspect_action(self.upstream.host()) { + ProtocolInspectAction::Intercept => self.do_intercept().await, #[cfg(feature = "quic")] - ProtocolInspectPolicy::Detour => self.do_detour().await.map(|_| None), - ProtocolInspectPolicy::Bypass => self.do_bypass().await.map(|_| None), - ProtocolInspectPolicy::Block => self.do_block().await.map(|_| None), + ProtocolInspectAction::Detour => self.do_detour().await.map(|_| None), + ProtocolInspectAction::Bypass => self.do_bypass().await.map(|_| None), + ProtocolInspectAction::Block => self.do_block().await.map(|_| None), }; match r { Ok(obj) => { diff --git a/g3proxy/src/inspect/tls/mod.rs b/g3proxy/src/inspect/tls/mod.rs index ae46b45dd..168384f20 100644 --- a/g3proxy/src/inspect/tls/mod.rs +++ b/g3proxy/src/inspect/tls/mod.rs @@ -22,7 +22,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::runtime::Handle; use g3_cert_agent::CertAgentHandle; -use g3_dpi::Protocol; +use g3_dpi::{Protocol, ProtocolInspectAction}; use g3_io_ext::{AsyncStream, FlexBufReader, OnceBufReader}; use g3_slog_types::{LtUpstreamAddr, LtUuid}; use g3_types::net::{ @@ -167,11 +167,14 @@ impl TlsInterceptObject { fn retain_alpn_protocol(&self, p: &[u8]) -> bool { if p == AlpnProtocol::Http2.identification_sequence() { - return !self.ctx.h2_inspect_policy().is_block(); + return ProtocolInspectAction::Block + != self.ctx.h2_inspect_action(self.upstream.host()); } else if p == AlpnProtocol::Smtp.identification_sequence() { - return !self.ctx.smtp_inspect_policy().is_block(); + return ProtocolInspectAction::Block + != self.ctx.smtp_inspect_action(self.upstream.host()); } else if p == AlpnProtocol::Imap.identification_sequence() { - return !self.ctx.imap_inspect_policy().is_block(); + return ProtocolInspectAction::Block + != self.ctx.imap_inspect_action(self.upstream.host()); } true } diff --git a/g3proxy/src/inspect/websocket/h1.rs b/g3proxy/src/inspect/websocket/h1.rs index 37eb7f7f7..090380702 100644 --- a/g3proxy/src/inspect/websocket/h1.rs +++ b/g3proxy/src/inspect/websocket/h1.rs @@ -18,7 +18,7 @@ use anyhow::anyhow; use slog::slog_info; use tokio::io::AsyncWriteExt; -use g3_dpi::ProtocolInspectPolicy; +use g3_dpi::ProtocolInspectAction; use g3_io_ext::LimitedWriteExt; use g3_slog_types::{LtHttpHeaderValue, LtUpstreamAddr, LtUuid}; use g3_types::net::{UpstreamAddr, WebSocketNotes}; @@ -90,12 +90,12 @@ impl H1WebsocketInterceptObject { } pub(crate) async fn intercept(mut self) -> ServerTaskResult<()> { - let r = match self.ctx.websocket_inspect_policy() { - ProtocolInspectPolicy::Intercept => self.do_intercept().await, + let r = match self.ctx.websocket_inspect_action(self.upstream.host()) { + ProtocolInspectAction::Intercept => self.do_intercept().await, #[cfg(feature = "quic")] - ProtocolInspectPolicy::Detour => self.do_detour().await, - ProtocolInspectPolicy::Bypass => self.do_bypass().await, - ProtocolInspectPolicy::Block => self.do_block().await, + ProtocolInspectAction::Detour => self.do_detour().await, + ProtocolInspectAction::Bypass => self.do_bypass().await, + ProtocolInspectAction::Block => self.do_block().await, }; match r { Ok(_) => { diff --git a/g3proxy/src/inspect/websocket/h2.rs b/g3proxy/src/inspect/websocket/h2.rs index 7694653d9..c8bdc7fc6 100644 --- a/g3proxy/src/inspect/websocket/h2.rs +++ b/g3proxy/src/inspect/websocket/h2.rs @@ -19,7 +19,7 @@ use bytes::Bytes; use h2::{RecvStream, SendStream}; use slog::slog_info; -use g3_dpi::ProtocolInspectPolicy; +use g3_dpi::ProtocolInspectAction; use g3_h2::{H2StreamReader, H2StreamWriter}; use g3_slog_types::{LtHttpHeaderValue, LtUpstreamAddr, LtUuid}; use g3_types::net::{UpstreamAddr, WebSocketNotes}; @@ -74,12 +74,12 @@ impl H2WebsocketInterceptObject { ups_r: RecvStream, ups_w: SendStream, ) { - let r = match self.ctx.websocket_inspect_policy() { - ProtocolInspectPolicy::Intercept => self.do_intercept(clt_r, clt_w, ups_r, ups_w).await, + let r = match self.ctx.websocket_inspect_action(self.upstream.host()) { + ProtocolInspectAction::Intercept => self.do_intercept(clt_r, clt_w, ups_r, ups_w).await, #[cfg(feature = "quic")] - ProtocolInspectPolicy::Detour => self.do_detour(clt_r, clt_w, ups_r, ups_w).await, - ProtocolInspectPolicy::Bypass => self.do_bypass(clt_r, clt_w, ups_r, ups_w).await, - ProtocolInspectPolicy::Block => self.do_block(clt_w, ups_w).await, + ProtocolInspectAction::Detour => self.do_detour(clt_r, clt_w, ups_r, ups_w).await, + ProtocolInspectAction::Bypass => self.do_bypass(clt_r, clt_w, ups_r, ups_w).await, + ProtocolInspectAction::Block => self.do_block(clt_w, ups_w).await, }; match r { Ok(_) => { diff --git a/lib/g3-dpi/src/config/mod.rs b/lib/g3-dpi/src/config/mod.rs index d87309990..013f45d07 100644 --- a/lib/g3-dpi/src/config/mod.rs +++ b/lib/g3-dpi/src/config/mod.rs @@ -14,9 +14,13 @@ * limitations under the License. */ +use std::fmt; use std::str::FromStr; use std::time::Duration; +use g3_types::acl::ActionContract; +use g3_types::acl_set::AclDstHostRuleSet; + mod size_limit; pub use size_limit::ProtocolInspectionSizeLimit; @@ -30,9 +34,10 @@ pub use smtp::SmtpInterceptionConfig; mod imap; pub use imap::ImapInterceptionConfig; -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] -pub enum ProtocolInspectPolicy { - #[default] +pub type ProtocolInspectPolicy = AclDstHostRuleSet; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub enum ProtocolInspectAction { Intercept, #[cfg(feature = "quic")] Detour, @@ -40,24 +45,55 @@ pub enum ProtocolInspectPolicy { Block, } -impl ProtocolInspectPolicy { +impl ProtocolInspectAction { #[inline] - pub fn is_block(&self) -> bool { - matches!(self, ProtocolInspectPolicy::Block) + fn as_str(&self) -> &'static str { + self.serialize() + } +} + +impl fmt::Display for ProtocolInspectAction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(ProtocolInspectAction::as_str(self)) } } -impl FromStr for ProtocolInspectPolicy { +impl FromStr for ProtocolInspectAction { type Err = (); + #[inline] fn from_str(s: &str) -> Result { + Self::deserialize(s).map_err(|_| ()) + } +} + +impl ActionContract for ProtocolInspectAction { + fn default_forbid() -> Self { + Self::Block + } + + fn default_permit() -> Self { + Self::Intercept + } + + fn serialize(&self) -> &'static str { + match self { + Self::Intercept => "intercept", + Self::Block => "block", + Self::Bypass => "bypass", + #[cfg(feature = "quic")] + Self::Detour => "detour", + } + } + + fn deserialize(s: &str) -> Result { match s.to_lowercase().as_str() { - "intercept" => Ok(ProtocolInspectPolicy::Intercept), + "intercept" => Ok(ProtocolInspectAction::Intercept), #[cfg(feature = "quic")] - "detour" => Ok(ProtocolInspectPolicy::Detour), - "bypass" => Ok(ProtocolInspectPolicy::Bypass), - "block" => Ok(ProtocolInspectPolicy::Block), - _ => Err(()), + "detour" => Ok(ProtocolInspectAction::Detour), + "bypass" => Ok(ProtocolInspectAction::Bypass), + "block" => Ok(ProtocolInspectAction::Block), + _ => Err(s), } } } diff --git a/lib/g3-dpi/src/lib.rs b/lib/g3-dpi/src/lib.rs index 7c994a5fd..7dffcd1ab 100644 --- a/lib/g3-dpi/src/lib.rs +++ b/lib/g3-dpi/src/lib.rs @@ -25,8 +25,9 @@ pub use protocol::{ mod config; pub use config::{ - H1InterceptionConfig, H2InterceptionConfig, ImapInterceptionConfig, ProtocolInspectPolicy, - ProtocolInspectionConfig, ProtocolInspectionSizeLimit, SmtpInterceptionConfig, + H1InterceptionConfig, H2InterceptionConfig, ImapInterceptionConfig, ProtocolInspectAction, + ProtocolInspectPolicy, ProtocolInspectionConfig, ProtocolInspectionSizeLimit, + SmtpInterceptionConfig, }; pub mod parser; diff --git a/lib/g3-json/src/value/acl_set/dst_host.rs b/lib/g3-json/src/value/acl_set/dst_host.rs index 2004e540b..f8db7b0b1 100644 --- a/lib/g3-json/src/value/acl_set/dst_host.rs +++ b/lib/g3-json/src/value/acl_set/dst_host.rs @@ -21,12 +21,7 @@ use g3_types::acl_set::AclDstHostRuleSetBuilder; pub fn as_dst_host_rule_set_builder(value: &Value) -> anyhow::Result { if let Value::Object(map) = value { - let mut builder = AclDstHostRuleSetBuilder { - exact: None, - child: None, - regex: None, - subnet: None, - }; + let mut builder = AclDstHostRuleSetBuilder::default(); for (k, v) in map { match crate::key::normalize(k).as_str() { "exact_match" | "exact" => { diff --git a/lib/g3-types/src/acl/a_hash.rs b/lib/g3-types/src/acl/a_hash.rs index 45f072758..f9e7ae4a8 100644 --- a/lib/g3-types/src/acl/a_hash.rs +++ b/lib/g3-types/src/acl/a_hash.rs @@ -19,31 +19,33 @@ use std::hash::Hash; use ahash::AHashMap; -use super::AclAction; +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclAHashRule +pub struct AclAHashRule where K: Hash + Eq, { - inner: AHashMap, - missed_action: AclAction, + inner: AHashMap, + missed_action: Action, } -impl Default for AclAHashRule +impl Default for AclAHashRule where K: Hash + Eq, + Action: ActionContract, { fn default() -> Self { - Self::new(AclAction::Forbid) + Self::new(Action::default_forbid()) } } -impl AclAHashRule +impl AclAHashRule where K: Hash + Eq, + Action: ActionContract, { - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclAHashRule { inner: AHashMap::new(), missed_action, @@ -51,16 +53,16 @@ where } #[inline] - pub fn add_node(&mut self, node: K, action: AclAction) { + pub fn add_node(&mut self, node: K, action: Action) { self.inner.insert(node, action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } - pub fn check(&self, node: &Q) -> (bool, AclAction) + pub fn check(&self, node: &Q) -> (bool, Action) where K: Borrow, Q: Hash + Eq + ?Sized, diff --git a/lib/g3-types/src/acl/child_domain.rs b/lib/g3-types/src/acl/child_domain.rs index 8dcb2ca5a..3d84f8fd7 100644 --- a/lib/g3-types/src/acl/child_domain.rs +++ b/lib/g3-types/src/acl/child_domain.rs @@ -14,44 +14,45 @@ * limitations under the License. */ -use super::{AclAction, AclRadixTrieRule, AclRadixTrieRuleBuilder}; +use super::{AclAction, AclRadixTrieRule, AclRadixTrieRuleBuilder, ActionContract}; use crate::resolve::reverse_idna_domain; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclChildDomainRuleBuilder(AclRadixTrieRuleBuilder); +pub struct AclChildDomainRuleBuilder(AclRadixTrieRuleBuilder); -impl AclChildDomainRuleBuilder { +impl AclChildDomainRuleBuilder { #[inline] - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclChildDomainRuleBuilder(AclRadixTrieRuleBuilder::new(missed_action)) } #[inline] - pub fn add_node(&mut self, domain: &str, action: AclAction) { + pub fn add_node(&mut self, domain: &str, action: Action) { self.0.add_node(reverse_idna_domain(domain), action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.0.set_missed_action(action); } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.0.missed_action() } #[inline] - pub fn build(&self) -> AclChildDomainRule { + pub fn build(&self) -> AclChildDomainRule { AclChildDomainRule(self.0.build()) } } -pub struct AclChildDomainRule(AclRadixTrieRule); +#[derive(Clone)] +pub struct AclChildDomainRule(AclRadixTrieRule); -impl AclChildDomainRule { +impl AclChildDomainRule { #[inline] - pub fn check(&self, host: &str) -> (bool, AclAction) { + pub fn check(&self, host: &str) -> (bool, Action) { let s = reverse_idna_domain(host); self.0.check(&s) } diff --git a/lib/g3-types/src/acl/exact_host.rs b/lib/g3-types/src/acl/exact_host.rs index 83565f817..e5525adec 100644 --- a/lib/g3-types/src/acl/exact_host.rs +++ b/lib/g3-types/src/acl/exact_host.rs @@ -17,19 +17,19 @@ use std::net::IpAddr; use std::sync::Arc; -use super::{AclAHashRule, AclAction}; +use super::{AclAHashRule, AclAction, ActionContract}; use crate::net::Host; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclExactHostRule { - missed_action: AclAction, - domain: AclAHashRule>, - ip: AclAHashRule, +pub struct AclExactHostRule { + missed_action: Action, + domain: AclAHashRule, Action>, + ip: AclAHashRule, } -impl AclExactHostRule { +impl AclExactHostRule { #[inline] - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclExactHostRule { missed_action, domain: AclAHashRule::new(missed_action), @@ -38,16 +38,16 @@ impl AclExactHostRule { } #[inline] - pub fn add_domain(&mut self, domain: Arc, action: AclAction) { + pub fn add_domain(&mut self, domain: Arc, action: Action) { self.domain.add_node(domain, action); } #[inline] - pub fn add_ip(&mut self, ip: IpAddr, action: AclAction) { + pub fn add_ip(&mut self, ip: IpAddr, action: Action) { self.ip.add_node(ip, action); } - pub fn add_host(&mut self, host: Host, action: AclAction) { + pub fn add_host(&mut self, host: Host, action: Action) { match host { Host::Ip(ip) => self.add_ip(ip, action), Host::Domain(domain) => self.add_domain(domain, action), @@ -55,24 +55,24 @@ impl AclExactHostRule { } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; self.domain.set_missed_action(action); self.ip.set_missed_action(action); } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } #[inline] - pub fn check_domain(&self, domain: &str) -> (bool, AclAction) { + pub fn check_domain(&self, domain: &str) -> (bool, Action) { self.domain.check(domain) } #[inline] - pub fn check_ip(&self, ip: &IpAddr) -> (bool, AclAction) { + pub fn check_ip(&self, ip: &IpAddr) -> (bool, Action) { self.ip.check(ip) } } diff --git a/lib/g3-types/src/acl/exact_port.rs b/lib/g3-types/src/acl/exact_port.rs index 7e73f5a00..71cd5e5fb 100644 --- a/lib/g3-types/src/acl/exact_port.rs +++ b/lib/g3-types/src/acl/exact_port.rs @@ -16,42 +16,42 @@ use std::ops::RangeInclusive; -use super::{AclAction, AclFxHashRule}; +use super::{AclAction, AclFxHashRule, ActionContract}; use crate::net::Ports; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclExactPortRule(AclFxHashRule); +pub struct AclExactPortRule(AclFxHashRule); -impl AclExactPortRule { +impl AclExactPortRule { #[inline] - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclExactPortRule(AclFxHashRule::new(missed_action)) } - pub fn add_port_range(&mut self, port_range: RangeInclusive, action: AclAction) { + pub fn add_port_range(&mut self, port_range: RangeInclusive, action: Action) { for port in port_range { self.0.add_node(port, action); } } - pub fn add_ports(&mut self, ports: Ports, action: AclAction) { + pub fn add_ports(&mut self, ports: Ports, action: Action) { for port in ports { self.0.add_node(port, action); } } #[inline] - pub fn add_port(&mut self, port: u16, action: AclAction) { + pub fn add_port(&mut self, port: u16, action: Action) { self.0.add_node(port, action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.0.set_missed_action(action); } #[inline] - pub fn check_port(&self, port: &u16) -> (bool, AclAction) { + pub fn check_port(&self, port: &u16) -> (bool, Action) { self.0.check(port) } } diff --git a/lib/g3-types/src/acl/fx_hash.rs b/lib/g3-types/src/acl/fx_hash.rs index 5ef1ccba8..89e0cc593 100644 --- a/lib/g3-types/src/acl/fx_hash.rs +++ b/lib/g3-types/src/acl/fx_hash.rs @@ -19,31 +19,33 @@ use std::hash::Hash; use rustc_hash::FxHashMap; -use super::AclAction; +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclFxHashRule +pub struct AclFxHashRule where K: Hash + Eq, { - inner: FxHashMap, - missed_action: AclAction, + inner: FxHashMap, + missed_action: Action, } -impl Default for AclFxHashRule +impl Default for AclFxHashRule where K: Hash + Eq, + Action: ActionContract, { fn default() -> Self { - Self::new(AclAction::Forbid) + Self::new(Action::default_forbid()) } } -impl AclFxHashRule +impl AclFxHashRule where K: Hash + Eq, + Action: ActionContract, { - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclFxHashRule { inner: FxHashMap::default(), missed_action, @@ -51,16 +53,16 @@ where } #[inline] - pub fn add_node(&mut self, node: K, action: AclAction) { + pub fn add_node(&mut self, node: K, action: Action) { self.inner.insert(node, action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } - pub fn check(&self, node: &Q) -> (bool, AclAction) + pub fn check(&self, node: &Q) -> (bool, Action) where K: Borrow, Q: Hash + Eq + ?Sized, diff --git a/lib/g3-types/src/acl/mod.rs b/lib/g3-types/src/acl/mod.rs index 996dd884a..ee5e1a9c1 100644 --- a/lib/g3-types/src/acl/mod.rs +++ b/lib/g3-types/src/acl/mod.rs @@ -40,7 +40,17 @@ pub use proxy_request::AclProxyRequestRule; pub use regex_set::{AclRegexSetRule, AclRegexSetRuleBuilder}; pub use user_agent::AclUserAgentRule; -#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)] +pub trait ActionContract: + Clone + Copy + PartialEq + Eq + PartialOrd + Ord + std::hash::Hash +{ + fn default_forbid() -> Self; + fn default_permit() -> Self; + + fn serialize(&self) -> &'static str; + fn deserialize(s: &str) -> Result; +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] pub enum AclAction { Permit, PermitAndLog, @@ -49,27 +59,39 @@ pub enum AclAction { } impl AclAction { - #[must_use] pub fn restrict(self, other: AclAction) -> AclAction { - if other > self { - other - } else { - self - } + other.max(self) } - pub fn strict_than(&self, other: AclAction) -> bool { + pub fn strict_than(self, other: AclAction) -> bool { self.gt(&other) } - pub const fn forbid_early(&self) -> bool { + pub fn forbid_early(&self) -> bool { match self { AclAction::ForbidAndLog | AclAction::Forbid => true, AclAction::PermitAndLog | AclAction::Permit => false, } } +} - pub const fn as_str(&self) -> &'static str { +impl AclAction { + #[inline] + fn as_str(&self) -> &'static str { + self.serialize() + } +} + +impl ActionContract for AclAction { + fn default_permit() -> AclAction { + AclAction::Permit + } + + fn default_forbid() -> AclAction { + AclAction::Forbid + } + + fn serialize(&self) -> &'static str { match self { AclAction::Permit => "Permit", AclAction::PermitAndLog => "PermitAndLog", @@ -77,6 +99,16 @@ impl AclAction { AclAction::ForbidAndLog => "ForbidAndLog", } } + + fn deserialize(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "permit" | "allow" | "accept" => Ok(AclAction::Permit), + "permit_log" | "allow_log" | "accept_log" => Ok(AclAction::PermitAndLog), + "forbid" | "deny" | "reject" => Ok(AclAction::Forbid), + "forbid_log" | "deny_log" | "reject_log" => Ok(AclAction::ForbidAndLog), + _ => Err(s), + } + } } impl fmt::Display for AclAction { @@ -88,14 +120,9 @@ impl fmt::Display for AclAction { impl FromStr for AclAction { type Err = (); + #[inline] fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "permit" | "allow" | "accept" => Ok(AclAction::Permit), - "permit_log" | "allow_log" | "accept_log" => Ok(AclAction::PermitAndLog), - "forbid" | "deny" | "reject" => Ok(AclAction::Forbid), - "forbid_log" | "deny_log" | "reject_log" => Ok(AclAction::ForbidAndLog), - _ => Err(()), - } + AclAction::deserialize(s).map_err(|_| ()) } } diff --git a/lib/g3-types/src/acl/network.rs b/lib/g3-types/src/acl/network.rs index 276563f94..dc1a41b8f 100644 --- a/lib/g3-types/src/acl/network.rs +++ b/lib/g3-types/src/acl/network.rs @@ -21,100 +21,81 @@ use std::sync::LazyLock; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use super::AclAction; - -static DEFAULT_EGRESS_RULE: LazyLock> = LazyLock::new(|| { - let mut m = HashMap::new(); - // forbid ipv4 unspecified 0.0.0.0/32 by default - m.insert( - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 32).unwrap(), - AclAction::Forbid, - ); - // forbid ipv4 loopback 127.0.0.0/8 by default - m.insert( - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 0)), 8).unwrap(), - AclAction::Forbid, - ); - // forbid ipv4 link-local 169.254.0.0/16 by default - m.insert( - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(169, 254, 0, 0)), 16).unwrap(), - AclAction::Forbid, - ); - // forbid ipv6 unspecified ::/128 by default - m.insert( - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 128).unwrap(), - AclAction::Forbid, - ); - // forbid ipv6 loopback ::1/128 by default - m.insert( - IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 128).unwrap(), - AclAction::Forbid, - ); - // forbid ipv6 link-local fe80::/10 by default - m.insert( - IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0)), 10).unwrap(), - AclAction::Forbid, - ); - // forbid ipv6 discard-only 100::/64 by default - m.insert( - IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0x0100, 0, 0, 0, 0, 0, 0, 0)), 64).unwrap(), - AclAction::Forbid, - ); - m -}); - -static DEFAULT_INGRESS_RULE: LazyLock> = LazyLock::new(|| { - let mut m = HashMap::new(); - // permit ipv4 loopback 127.0.0.1/32 by default - m.insert( - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 32).unwrap(), - AclAction::Permit, - ); - // permit ipv6 loopback ::1/128 by default - m.insert( - IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 128).unwrap(), - AclAction::Permit, - ); - m -}); +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclNetworkRuleBuilder { - inner: HashMap, - missed_action: AclAction, +pub struct AclNetworkRuleBuilder { + inner: HashMap, + missed_action: Action, } -impl AclNetworkRuleBuilder { - pub fn new_egress(missed_action: AclAction) -> Self { - AclNetworkRuleBuilder { - inner: DEFAULT_EGRESS_RULE.clone(), +impl AclNetworkRuleBuilder { + pub fn new_egress(missed_action: Action) -> Self { + static DEFAULT_EGRESS_RULE: LazyLock> = LazyLock::new(|| { + vec![ + // forbid ipv4 unspecified 0.0.0.0/32 by default + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 32).unwrap(), + // forbid ipv4 loopback 127.0.0.0/8 by default + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 0)), 8).unwrap(), + // forbid ipv4 link-local 169.254.0.0/16 by default + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(169, 254, 0, 0)), 16).unwrap(), + // forbid ipv6 unspecified ::/128 by default + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 128).unwrap(), + // forbid ipv6 loopback ::1/128 by default + IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 128).unwrap(), + // forbid ipv6 link-local fe80::/10 by default + IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0)), 10).unwrap(), + // forbid ipv6 discard-only 100::/64 by default + IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0x0100, 0, 0, 0, 0, 0, 0, 0)), 64).unwrap(), + ] + }); + let v = DEFAULT_EGRESS_RULE.clone(); + let mut inner = HashMap::with_capacity(v.len()); + for ip_network in v { + inner.insert(ip_network, Action::default_forbid()); + } + Self { + inner, missed_action, } } - pub fn new_ingress(missed_action: AclAction) -> Self { - AclNetworkRuleBuilder { - inner: DEFAULT_INGRESS_RULE.clone(), + pub fn new_ingress(missed_action: Action) -> Self { + static DEFAULT_INGRESS_RULE: LazyLock> = LazyLock::new(|| { + vec![ + // permit ipv4 loopback 127.0.0.1/32 by default + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 32).unwrap(), + // permit ipv6 loopback ::1/128 by default + IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 128).unwrap(), + ] + }); + let v = DEFAULT_INGRESS_RULE.clone(); + let mut inner = HashMap::with_capacity(v.len()); + for ip_network in v { + inner.insert(ip_network, Action::default_permit()); + } + Self { + inner, missed_action, } } #[inline] - pub fn add_network(&mut self, network: IpNetwork, action: AclAction) { + pub fn add_network(&mut self, network: IpNetwork, action: Action) { self.inner.insert(network, action); } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } - pub fn build(&self) -> AclNetworkRule { + pub fn build(&self) -> AclNetworkRule { let mut inner = IpNetworkTable::new(); for (net, action) in &self.inner { inner.insert(*net, *action); @@ -126,13 +107,29 @@ impl AclNetworkRuleBuilder { } } -pub struct AclNetworkRule { - inner: IpNetworkTable, - default_action: AclAction, +pub struct AclNetworkRule { + inner: IpNetworkTable, + default_action: Action, +} + +impl Clone for AclNetworkRule { + fn clone(&self) -> Self { + Self { + inner: { + let (ipv4_size, ipv6_size) = self.inner.len(); + let mut table = IpNetworkTable::with_capacity(ipv4_size, ipv6_size); + for (k, v) in self.inner.iter() { + table.insert(k, *v); + } + table + }, + default_action: self.default_action, + } + } } -impl AclNetworkRule { - pub fn check(&self, ip: IpAddr) -> (bool, AclAction) { +impl AclNetworkRule { + pub fn check(&self, ip: IpAddr) -> (bool, Action) { if let Some((_, action)) = self.inner.longest_match(ip) { (true, *action) } else { diff --git a/lib/g3-types/src/acl/proxy_request.rs b/lib/g3-types/src/acl/proxy_request.rs index c4f19a9ee..7ff5b44fb 100644 --- a/lib/g3-types/src/acl/proxy_request.rs +++ b/lib/g3-types/src/acl/proxy_request.rs @@ -14,18 +14,18 @@ * limitations under the License. */ -use super::{AclAHashRule, AclAction}; +use super::{AclAHashRule, AclAction, ActionContract}; use crate::net::ProxyRequestType; #[derive(Clone)] -pub struct AclProxyRequestRule { - missed_action: AclAction, - request: AclAHashRule, +pub struct AclProxyRequestRule { + missed_action: Action, + request: AclAHashRule, } -impl AclProxyRequestRule { +impl AclProxyRequestRule { #[inline] - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclProxyRequestRule { missed_action, request: AclAHashRule::new(missed_action), @@ -33,23 +33,23 @@ impl AclProxyRequestRule { } #[inline] - pub fn add_request_type(&mut self, request: ProxyRequestType, action: AclAction) { + pub fn add_request_type(&mut self, request: ProxyRequestType, action: Action) { self.request.add_node(request, action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; self.request.set_missed_action(action); } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } #[inline] - pub fn check_request(&self, request: &ProxyRequestType) -> (bool, AclAction) { + pub fn check_request(&self, request: &ProxyRequestType) -> (bool, Action) { self.request.check(request) } } diff --git a/lib/g3-types/src/acl/radix_trie.rs b/lib/g3-types/src/acl/radix_trie.rs index 6e6c221c1..961dc82a3 100644 --- a/lib/g3-types/src/acl/radix_trie.rs +++ b/lib/g3-types/src/acl/radix_trie.rs @@ -20,22 +20,23 @@ use std::hash::Hash; use ahash::AHashMap; use radix_trie::{Trie, TrieKey}; -use super::AclAction; +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclRadixTrieRuleBuilder +pub struct AclRadixTrieRuleBuilder where K: TrieKey + Hash, { - inner: AHashMap, - missed_action: AclAction, + inner: AHashMap, + missed_action: Action, } -impl AclRadixTrieRuleBuilder +impl AclRadixTrieRuleBuilder where K: TrieKey + Hash + Clone, + Action: ActionContract, { - pub fn new(missed_action: AclAction) -> Self { + pub fn new(missed_action: Action) -> Self { AclRadixTrieRuleBuilder { inner: AHashMap::new(), missed_action, @@ -43,21 +44,21 @@ where } #[inline] - pub fn add_node(&mut self, node: K, action: AclAction) { + pub fn add_node(&mut self, node: K, action: Action) { self.inner.insert(node, action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } - pub fn build(&self) -> AclRadixTrieRule { + pub fn build(&self) -> AclRadixTrieRule { let mut trie = Trie::new(); for (k, v) in &self.inner { @@ -71,13 +72,14 @@ where } } -pub struct AclRadixTrieRule { - inner: Trie, - missed_action: AclAction, +#[derive(Clone)] +pub struct AclRadixTrieRule { + inner: Trie, + missed_action: Action, } -impl AclRadixTrieRule { - pub fn check(&self, key: &Q) -> (bool, AclAction) +impl AclRadixTrieRule { + pub fn check(&self, key: &Q) -> (bool, Action) where K: Borrow, Q: TrieKey, diff --git a/lib/g3-types/src/acl/regex_set.rs b/lib/g3-types/src/acl/regex_set.rs index 8b3997a12..24a2f7b7e 100644 --- a/lib/g3-types/src/acl/regex_set.rs +++ b/lib/g3-types/src/acl/regex_set.rs @@ -14,114 +14,76 @@ * limitations under the License. */ -use std::collections::HashMap; - use regex::{Regex, RegexSet}; +use rustc_hash::FxHashMap; -use super::AclAction; +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclRegexSetRuleBuilder { - inner: HashMap, - missed_action: AclAction, +pub struct AclRegexSetRuleBuilder { + inner: FxHashMap, + missed_action: Action, } -impl Default for AclRegexSetRuleBuilder { +impl Default for AclRegexSetRuleBuilder { fn default() -> Self { - Self::new(AclAction::Forbid) + Self::new(Action::default_forbid()) } } -impl AclRegexSetRuleBuilder { - pub fn new(missed_action: AclAction) -> Self { +impl AclRegexSetRuleBuilder { + pub fn new(missed_action: Action) -> Self { AclRegexSetRuleBuilder { - inner: HashMap::new(), + inner: FxHashMap::default(), missed_action, } } #[inline] - pub fn add_regex(&mut self, regex: &Regex, action: AclAction) { + pub fn add_regex(&mut self, regex: &Regex, action: Action) { self.inner.insert(regex.as_str().to_string(), action); } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } - pub fn build(&self) -> AclRegexSetRule { - let mut forbid_log_v = Vec::new(); - let mut forbid_v = Vec::new(); - let mut permit_log_v = Vec::new(); - let mut permit_v = Vec::new(); + pub fn build(&self) -> AclRegexSetRule { + let mut set_map: FxHashMap> = FxHashMap::default(); for (r, action) in &self.inner { - match action { - AclAction::ForbidAndLog => forbid_log_v.push(r.as_str()), - AclAction::Forbid => forbid_v.push(r.as_str()), - AclAction::PermitAndLog => permit_log_v.push(r.as_str()), - AclAction::Permit => permit_v.push(r.as_str()), - } - } - - fn build_rs_from_vec(v: &[&str]) -> Option { - if v.is_empty() { - None - } else { - Some(RegexSet::new(v).unwrap()) - } + set_map.entry(*action).or_default().push(r.as_str()); } AclRegexSetRule { - forbid_log: build_rs_from_vec(&forbid_log_v), - forbid: build_rs_from_vec(&forbid_v), - permit_log: build_rs_from_vec(&permit_log_v), - permit: build_rs_from_vec(&permit_v), + set_map: set_map + .into_iter() + .map(|(k, v)| (k, RegexSet::new(v).unwrap())) + .collect(), missed_action: self.missed_action, } } } -pub struct AclRegexSetRule { - forbid_log: Option, - forbid: Option, - permit_log: Option, - permit: Option, - missed_action: AclAction, +#[derive(Clone)] +pub struct AclRegexSetRule { + set_map: FxHashMap, + missed_action: Action, } -impl AclRegexSetRule { - pub fn check(&self, text: &str) -> (bool, AclAction) { - if let Some(rs) = &self.forbid_log { - if rs.is_match(text) { - return (true, AclAction::ForbidAndLog); - } - } - - if let Some(rs) = &self.forbid { - if rs.is_match(text) { - return (true, AclAction::Forbid); +impl AclRegexSetRule { + pub fn check(&self, text: &str) -> (bool, Action) { + for (action, set) in &self.set_map { + if set.is_match(text) { + return (true, *action); } } - - if let Some(rs) = &self.permit_log { - if rs.is_match(text) { - return (true, AclAction::PermitAndLog); - } - } - - if let Some(rs) = &self.permit { - if rs.is_match(text) { - return (true, AclAction::Permit); - } - } - (false, self.missed_action) } } diff --git a/lib/g3-types/src/acl/user_agent.rs b/lib/g3-types/src/acl/user_agent.rs index f027e1a0d..b5beef4a5 100644 --- a/lib/g3-types/src/acl/user_agent.rs +++ b/lib/g3-types/src/acl/user_agent.rs @@ -16,45 +16,44 @@ use std::collections::BTreeMap; -use super::AclAction; +use super::{AclAction, ActionContract}; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclUserAgentRule { - inner: BTreeMap, - missed_action: AclAction, +pub struct AclUserAgentRule { + inner: BTreeMap, + missed_action: Action, } -impl Default for AclUserAgentRule { +impl Default for AclUserAgentRule { fn default() -> Self { - // default to permit all - AclUserAgentRule::new(AclAction::Permit) + AclUserAgentRule::new(Action::default_permit()) } } -impl AclUserAgentRule { - pub fn new(missed_action: AclAction) -> Self { +impl AclUserAgentRule { + pub fn new(missed_action: Action) -> Self { AclUserAgentRule { inner: BTreeMap::new(), missed_action, } } - pub fn add_ua_name(&mut self, ua: &str, action: AclAction) { + pub fn add_ua_name(&mut self, ua: &str, action: Action) { let name = ua.to_ascii_lowercase(); self.inner.insert(name, action); } #[inline] - pub fn missed_action(&self) -> AclAction { + pub fn missed_action(&self) -> Action { self.missed_action } #[inline] - pub fn set_missed_action(&mut self, action: AclAction) { + pub fn set_missed_action(&mut self, action: Action) { self.missed_action = action; } - pub fn check(&self, ua_value: &str) -> (bool, AclAction) { + pub fn check(&self, ua_value: &str) -> (bool, Action) { let value = ua_value.to_ascii_lowercase(); for (name, action) in self.inner.iter() { diff --git a/lib/g3-types/src/acl_set/dst_host.rs b/lib/g3-types/src/acl_set/dst_host.rs index 771fedc2f..58d415d22 100644 --- a/lib/g3-types/src/acl_set/dst_host.rs +++ b/lib/g3-types/src/acl_set/dst_host.rs @@ -16,39 +16,52 @@ use crate::acl::{ AclAction, AclChildDomainRule, AclChildDomainRuleBuilder, AclExactHostRule, AclNetworkRule, - AclNetworkRuleBuilder, AclRegexSetRule, AclRegexSetRuleBuilder, + AclNetworkRuleBuilder, AclRegexSetRule, AclRegexSetRuleBuilder, ActionContract, }; use crate::net::Host; #[derive(Clone, Debug, Eq, PartialEq)] -pub struct AclDstHostRuleSetBuilder { - pub exact: Option, - pub child: Option, - pub regex: Option, - pub subnet: Option, +pub struct AclDstHostRuleSetBuilder { + pub exact: Option>, + pub child: Option>, + pub regex: Option>, + pub subnet: Option>, + pub missing_action: Option, } -impl AclDstHostRuleSetBuilder { - pub fn build(&self) -> AclDstHostRuleSet { - let mut missed_action = AclAction::Permit; +impl Default for AclDstHostRuleSetBuilder { + fn default() -> Self { + Self { + exact: None, + child: None, + regex: None, + subnet: None, + missing_action: None, + } + } +} + +impl AclDstHostRuleSetBuilder { + pub fn build(&self) -> AclDstHostRuleSet { + let mut missed_action = self.missing_action.unwrap_or_else(Action::default_permit); let exact_rule = self.exact.as_ref().map(|rule| { - missed_action = missed_action.restrict(rule.missed_action()); + missed_action = rule.missed_action().max(missed_action); rule.clone() }); let child_rule = self.child.as_ref().map(|builder| { - missed_action = missed_action.restrict(builder.missed_action()); + missed_action = builder.missed_action().max(missed_action); builder.build() }); let regex_rule = self.regex.as_ref().map(|builder| { - missed_action = missed_action.restrict(builder.missed_action()); + missed_action = builder.missed_action().max(missed_action); builder.build() }); let subnet_rule = self.subnet.as_ref().map(|builder| { - missed_action = missed_action.restrict(builder.missed_action()); + missed_action = builder.missed_action().max(missed_action); builder.build() }); @@ -62,16 +75,32 @@ impl AclDstHostRuleSetBuilder { } } -pub struct AclDstHostRuleSet { - exact: Option, - child: Option, - regex: Option, - subnet: Option, - missed_action: AclAction, +#[derive(Clone)] +pub struct AclDstHostRuleSet { + exact: Option>, + child: Option>, + regex: Option>, + subnet: Option>, + missed_action: Action, } -impl AclDstHostRuleSet { - pub fn check(&self, upstream: &Host) -> (bool, AclAction) { +impl AclDstHostRuleSet { + pub fn builder() -> AclDstHostRuleSetBuilder { + AclDstHostRuleSetBuilder::default() + } + + pub fn builder_with_missing_action(action: Action) -> AclDstHostRuleSetBuilder { + AclDstHostRuleSetBuilder { + missing_action: Some(action), + ..AclDstHostRuleSetBuilder::default() + } + } + + pub fn missing_action(&self) -> Action { + self.missed_action + } + + pub fn check(&self, upstream: &Host) -> (bool, Action) { match upstream { Host::Ip(ip) => { if let Some(rule) = &self.exact { diff --git a/lib/g3-yaml/src/value/acl/child_domain.rs b/lib/g3-yaml/src/value/acl/child_domain.rs index 849f88630..70d0562d6 100644 --- a/lib/g3-yaml/src/value/acl/child_domain.rs +++ b/lib/g3-yaml/src/value/acl/child_domain.rs @@ -17,22 +17,22 @@ use anyhow::anyhow; use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclChildDomainRuleBuilder}; +use g3_types::acl::{AclChildDomainRuleBuilder, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclChildDomainRuleBuilder { +impl AclRuleYamlParser for AclChildDomainRuleBuilder { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { match value { Yaml::String(_) => { let host = crate::value::as_domain(value)?; @@ -44,10 +44,10 @@ impl AclRuleYamlParser for AclChildDomainRuleBuilder { } } -pub(crate) fn as_child_domain_rule_builder( +pub(crate) fn as_child_domain_rule_builder( value: &Yaml, -) -> anyhow::Result { - let mut builder = AclChildDomainRuleBuilder::new(AclAction::Forbid); +) -> anyhow::Result> { + let mut builder = AclChildDomainRuleBuilder::new(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/exact_host.rs b/lib/g3-yaml/src/value/acl/exact_host.rs index deaac8116..cc892950a 100644 --- a/lib/g3-yaml/src/value/acl/exact_host.rs +++ b/lib/g3-yaml/src/value/acl/exact_host.rs @@ -16,30 +16,32 @@ use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclExactHostRule}; +use g3_types::acl::{AclExactHostRule, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclExactHostRule { +impl AclRuleYamlParser for AclExactHostRule { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { let host = crate::value::as_host(value)?; self.add_host(host, action); Ok(()) } } -pub(crate) fn as_exact_host_rule(value: &Yaml) -> anyhow::Result { - let mut builder = AclExactHostRule::new(AclAction::Forbid); +pub(crate) fn as_exact_host_rule( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclExactHostRule::new(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/exact_port.rs b/lib/g3-yaml/src/value/acl/exact_port.rs index 3684371f0..a321685f8 100644 --- a/lib/g3-yaml/src/value/acl/exact_port.rs +++ b/lib/g3-yaml/src/value/acl/exact_port.rs @@ -16,30 +16,32 @@ use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclExactPortRule}; +use g3_types::acl::{AclExactPortRule, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclExactPortRule { +impl AclRuleYamlParser for AclExactPortRule { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { let ports = crate::value::as_ports(value)?; self.add_ports(ports, action); Ok(()) } } -pub fn as_exact_port_rule(value: &Yaml) -> anyhow::Result { - let mut builder = AclExactPortRule::new(AclAction::Forbid); +pub fn as_exact_port_rule( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclExactPortRule::new(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/mod.rs b/lib/g3-yaml/src/value/acl/mod.rs index 213f5d85d..f2bd5d110 100644 --- a/lib/g3-yaml/src/value/acl/mod.rs +++ b/lib/g3-yaml/src/value/acl/mod.rs @@ -14,12 +14,10 @@ * limitations under the License. */ -use std::str::FromStr; - use anyhow::{anyhow, Context}; use yaml_rust::Yaml; -use g3_types::acl::AclAction; +use g3_types::acl::ActionContract; mod child_domain; mod exact_host; @@ -39,26 +37,21 @@ pub use network::{as_egress_network_rule_builder, as_ingress_network_rule_builde pub use proxy_request::as_proxy_request_rule; pub use user_agent::as_user_agent_rule; -fn as_action(value: &Yaml) -> anyhow::Result { +fn as_action(value: &Yaml) -> anyhow::Result { if let Yaml::String(s) = value { - let action = - AclAction::from_str(s).map_err(|_| anyhow!("invalid AclAction string value"))?; + let action = Action::deserialize(s).map_err(|_| anyhow!("invalid Action string value"))?; Ok(action) } else { - Err(anyhow!( - "the yaml value type for AclAction should be string" - )) + Err(anyhow!("the yaml value type for Action should be string")) } } -trait AclRuleYamlParser { - fn get_default_found_action(&self) -> AclAction; - fn set_missed_action(&mut self, action: AclAction); - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()>; +trait AclRuleYamlParser { + fn get_default_found_action(&self) -> Action; + fn set_missed_action(&mut self, action: Action); + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()>; fn parse(&mut self, value: &Yaml) -> anyhow::Result<()> { - let default_found_action = self.get_default_found_action(); - match value { Yaml::Hash(map) => { crate::foreach_kv(map, |k, v| match crate::key::normalize(k).as_str() { @@ -68,8 +61,8 @@ trait AclRuleYamlParser { Ok(()) } _ => { - let action = AclAction::from_str(k) - .map_err(|_| anyhow!("the key {k} is not a valid AclAction"))?; + let action = Action::deserialize(k) + .map_err(|_| anyhow!("the key {k} is not a valid Action"))?; if let Yaml::Array(seq) = v { for (i, v) in seq.iter().enumerate() { self.add_rule_for_action(action, v) @@ -85,12 +78,12 @@ trait AclRuleYamlParser { } Yaml::Array(seq) => { for (i, v) in seq.iter().enumerate() { - self.add_rule_for_action(default_found_action, v) + self.add_rule_for_action(self.get_default_found_action(), v) .context(format!("invalid value for element #{i}"))?; } } _ => { - self.add_rule_for_action(default_found_action, value)?; + self.add_rule_for_action(self.get_default_found_action(), value)?; } } Ok(()) diff --git a/lib/g3-yaml/src/value/acl/network.rs b/lib/g3-yaml/src/value/acl/network.rs index 2afac9482..5ac82dc79 100644 --- a/lib/g3-yaml/src/value/acl/network.rs +++ b/lib/g3-yaml/src/value/acl/network.rs @@ -16,42 +16,48 @@ use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclNetworkRuleBuilder}; +use g3_types::acl::{AclNetworkRuleBuilder, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclNetworkRuleBuilder { +impl AclRuleYamlParser for AclNetworkRuleBuilder { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { let net = crate::value::as_ip_network(value)?; self.add_network(net, action); Ok(()) } } -pub(crate) fn as_dst_subnet_rule_builder(value: &Yaml) -> anyhow::Result { - let mut builder = AclNetworkRuleBuilder::new_egress(AclAction::Forbid); +pub(crate) fn as_dst_subnet_rule_builder( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclNetworkRuleBuilder::new_egress(Action::default_forbid()); builder.parse(value)?; Ok(builder) } -pub fn as_egress_network_rule_builder(value: &Yaml) -> anyhow::Result { - let mut builder = AclNetworkRuleBuilder::new_egress(AclAction::Forbid); +pub fn as_egress_network_rule_builder( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclNetworkRuleBuilder::new_egress(Action::default_forbid()); builder.parse(value)?; Ok(builder) } -pub fn as_ingress_network_rule_builder(value: &Yaml) -> anyhow::Result { - let mut builder = AclNetworkRuleBuilder::new_ingress(AclAction::Forbid); +pub fn as_ingress_network_rule_builder( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclNetworkRuleBuilder::new_ingress(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/proxy_request.rs b/lib/g3-yaml/src/value/acl/proxy_request.rs index b36de82f9..2718e886e 100644 --- a/lib/g3-yaml/src/value/acl/proxy_request.rs +++ b/lib/g3-yaml/src/value/acl/proxy_request.rs @@ -16,30 +16,32 @@ use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclProxyRequestRule}; +use g3_types::acl::{AclProxyRequestRule, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclProxyRequestRule { +impl AclRuleYamlParser for AclProxyRequestRule { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { let t = crate::value::as_proxy_request_type(value)?; self.add_request_type(t, action); Ok(()) } } -pub fn as_proxy_request_rule(value: &Yaml) -> anyhow::Result { - let mut builder = AclProxyRequestRule::new(AclAction::Forbid); +pub fn as_proxy_request_rule( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclProxyRequestRule::new(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/regex_set.rs b/lib/g3-yaml/src/value/acl/regex_set.rs index 9b69aea92..7f22085dc 100644 --- a/lib/g3-yaml/src/value/acl/regex_set.rs +++ b/lib/g3-yaml/src/value/acl/regex_set.rs @@ -18,22 +18,22 @@ use anyhow::anyhow; use regex::Regex; use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclRegexSetRuleBuilder}; +use g3_types::acl::{AclRegexSetRuleBuilder, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclRegexSetRuleBuilder { +impl AclRuleYamlParser for AclRegexSetRuleBuilder { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Permit + fn get_default_found_action(&self) -> Action { + Action::default_permit() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { match value { Yaml::String(_) => { let regex = as_regex(value)?; @@ -56,8 +56,10 @@ fn as_regex(value: &Yaml) -> anyhow::Result { } } -pub(crate) fn as_regex_set_rule_builder(value: &Yaml) -> anyhow::Result { - let mut builder = AclRegexSetRuleBuilder::new(AclAction::Forbid); +pub(crate) fn as_regex_set_rule_builder( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclRegexSetRuleBuilder::new(Action::default_forbid()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl/user_agent.rs b/lib/g3-yaml/src/value/acl/user_agent.rs index 712fb9b7a..be4a22495 100644 --- a/lib/g3-yaml/src/value/acl/user_agent.rs +++ b/lib/g3-yaml/src/value/acl/user_agent.rs @@ -17,22 +17,22 @@ use anyhow::Context; use yaml_rust::Yaml; -use g3_types::acl::{AclAction, AclUserAgentRule}; +use g3_types::acl::{AclUserAgentRule, ActionContract}; use super::AclRuleYamlParser; -impl AclRuleYamlParser for AclUserAgentRule { +impl AclRuleYamlParser for AclUserAgentRule { #[inline] - fn get_default_found_action(&self) -> AclAction { - AclAction::Forbid + fn get_default_found_action(&self) -> Action { + Action::default_forbid() } #[inline] - fn set_missed_action(&mut self, _action: AclAction) { - self.set_missed_action(_action); + fn set_missed_action(&mut self, action: Action) { + self.set_missed_action(action); } - fn add_rule_for_action(&mut self, action: AclAction, value: &Yaml) -> anyhow::Result<()> { + fn add_rule_for_action(&mut self, action: Action, value: &Yaml) -> anyhow::Result<()> { let ua_name = crate::value::as_ascii(value) .context("user-agent name should be valid ascii string")?; self.add_ua_name(ua_name.as_str(), action); @@ -40,8 +40,10 @@ impl AclRuleYamlParser for AclUserAgentRule { } } -pub fn as_user_agent_rule(value: &Yaml) -> anyhow::Result { - let mut builder = AclUserAgentRule::new(AclAction::Permit); +pub fn as_user_agent_rule( + value: &Yaml, +) -> anyhow::Result> { + let mut builder = AclUserAgentRule::new(Action::default_permit()); builder.parse(value)?; Ok(builder) } diff --git a/lib/g3-yaml/src/value/acl_set/dst_host.rs b/lib/g3-yaml/src/value/acl_set/dst_host.rs index 4556910cc..645e0fc7d 100644 --- a/lib/g3-yaml/src/value/acl_set/dst_host.rs +++ b/lib/g3-yaml/src/value/acl_set/dst_host.rs @@ -17,16 +17,14 @@ use anyhow::{anyhow, Context}; use yaml_rust::Yaml; -use g3_types::acl_set::AclDstHostRuleSetBuilder; +use g3_types::{acl::ActionContract, acl_set::AclDstHostRuleSetBuilder}; -pub fn as_dst_host_rule_set_builder(value: &Yaml) -> anyhow::Result { +pub fn as_dst_host_rule_set_builder( + value: &Yaml, +) -> anyhow::Result> { if let Yaml::Hash(map) = value { - let mut builder = AclDstHostRuleSetBuilder { - exact: None, - child: None, - regex: None, - subnet: None, - }; + let mut builder = AclDstHostRuleSetBuilder::default(); + crate::foreach_kv(map, |k, v| match crate::key::normalize(k).as_str() { "exact_match" | "exact" => { let exact_rule = crate::value::acl::as_exact_host_rule(v) diff --git a/lib/g3-yaml/src/value/dpi/inspect.rs b/lib/g3-yaml/src/value/dpi/inspect.rs index a1f1b00a0..1714a1610 100644 --- a/lib/g3-yaml/src/value/dpi/inspect.rs +++ b/lib/g3-yaml/src/value/dpi/inspect.rs @@ -19,16 +19,25 @@ use std::str::FromStr; use anyhow::{anyhow, Context}; use yaml_rust::Yaml; -use g3_dpi::{ProtocolInspectPolicy, ProtocolInspectionConfig, ProtocolInspectionSizeLimit}; +use g3_dpi::{ + ProtocolInspectAction, ProtocolInspectPolicy, ProtocolInspectionConfig, + ProtocolInspectionSizeLimit, +}; + +use crate::value::acl_set::as_dst_host_rule_set_builder; pub fn as_protocol_inspect_policy(value: &Yaml) -> anyhow::Result { - if let Yaml::String(s) = value { - ProtocolInspectPolicy::from_str(s) - .map_err(|_| anyhow!("invalid protocol inspect policy '{s}'")) - } else { - Err(anyhow!( - "yaml value type for 'protocol inspect policy' should be 'string'" - )) + match value { + Yaml::String(s) => { + let missing_action = ProtocolInspectAction::from_str(s) + .map_err(|_| anyhow!("invalid protocol inspect action '{s}'"))?; + let mut builder = ProtocolInspectPolicy::builder(); + builder.missing_action = Some(missing_action); + Ok(builder.build()) + } + _ => as_dst_host_rule_set_builder(value) + .map(|b| b.build()) + .map_err(|err| anyhow!("invalid protocol inspect action map: {err:?}")), } }