diff --git a/pb-jelly-gen/proto/rust/extensions.proto b/pb-jelly-gen/proto/rust/extensions.proto index 7c498dc..60cac2c 100644 --- a/pb-jelly-gen/proto/rust/extensions.proto +++ b/pb-jelly-gen/proto/rust/extensions.proto @@ -33,15 +33,15 @@ extend google.protobuf.FieldOptions { // // It doesn't make sense to specify this option for a repeating field, as a // missing field always deserializes as an empty Vec. - // In proto3, this option is also ignored for primitive types, which are - // always non-nullable + // In proto3, this option is also ignored for non-`optional` primitive types, + // which are already non-nullable. // // Beware that Default may not make sense for all message types. In // particular, fields using `OneofOptions.nullable=false` or // `EnumOptions.err_if_default_or_unknown=true` will simply default to their // first variant, and will _not_ trigger a deserialization error. That // behaviour may change in the future (TODO). - optional bool nullable_field = 50008; + optional bool nullable_field = 50008 [default=true]; } extend google.protobuf.EnumOptions { diff --git a/pb-jelly-gen/src/codegen.rs b/pb-jelly-gen/src/codegen.rs index b375a54..cc91447 100644 --- a/pb-jelly-gen/src/codegen.rs +++ b/pb-jelly-gen/src/codegen.rs @@ -293,54 +293,55 @@ impl<'a> RustType<'a> { } } - fn default(&self, msg_name: &str) -> String { - if let Some(oneof) = self.oneof { - if oneof_nullable(oneof) { - return "None".to_string(); - } else { - return self.oneof_val(msg_name, "::std::default::Default::default()"); + fn proto2_default(&self) -> String { + if let Some(ref default_value) = self.field.default_value { + if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_STRING) { + return format!("\"{default_value}\".into()"); } - } - // Proto 3 doesn't have configurable default values. - if !self.is_proto3 { - if let Some(ref default_value) = self.field.default_value { - if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_STRING) { - return format!("Some(\"{default_value}\".into())"); - } + if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_BYTES) { + return format!("b\"{default_value}\".to_vec()"); + } - if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_BYTES) { - return format!("Some(b\"{default_value}\".to_vec())"); + if let Some(primitive) = self.field.r#type.and_then(get_primitive_type) { + let typ_name = primitive.rust_type; + if typ_name.contains("::pb") { + return format!("{typ_name}({default_value})"); } - - if let Some(primitive) = self.field.r#type.and_then(get_primitive_type) { - let typ_name = primitive.rust_type; - if typ_name.contains("::pb") { - return format!("Some({typ_name}({default_value}))"); - } - if typ_name.starts_with('f') && !default_value.contains('.') { - return format!("Some({default_value}.)"); - } - return format!("Some({default_value})"); - } - - if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_ENUM) { - let proto_type = self.ctx.find(self.field.get_type_name()); - let (crate_, mod_parts) = self.ctx.crate_from_proto_filename(self.proto_file.get_name()); - let value = format!( - "{}::{}", - proto_type.rust_name(self.ctx, &crate_, &mod_parts), - default_value - ); - return format!("Some({value})"); + if typ_name.starts_with('f') && !default_value.contains('.') { + return format!("{default_value}."); } + return format!("{default_value}"); + } - panic!( - "Default not supported on field {:?} of type {:?}", - self.field.get_name(), - self.field.r#type + if self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_ENUM) { + let proto_type = self.ctx.find(self.field.get_type_name()); + let (crate_, mod_parts) = self.ctx.crate_from_proto_filename(self.proto_file.get_name()); + let value = format!( + "{}::{}", + proto_type.rust_name(self.ctx, &crate_, &mod_parts), + default_value ); + return format!("{value}"); } + + panic!( + "Default not supported on field {:?} of type {:?}", + self.field.get_name(), + self.field.r#type + ); + } else { + "::std::default::Default::default()".to_string() + } + } + + fn default(&self) -> String { + assert!(self.oneof.is_none()); + + // Proto 3 doesn't have configurable default values. + if !self.is_proto3 && self.field.default_value.is_some() { + // TODO: this is incorrect; defaults are specified to be inserted at get-time. + return format!("::std::option::Option::Some({})", self.proto2_default()); } "::std::default::Default::default()".to_string() @@ -378,65 +379,41 @@ impl<'a> RustType<'a> { fn is_grpc_slices(&self) -> bool { self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_BYTES) - && self - .field - .get_options() - .get_extension(extensions::GRPC_SLICES) - .unwrap() - .unwrap_or(false) + && self.field.get_options().get_extension(extensions::GRPC_SLICES) } fn is_blob(&self) -> bool { self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_BYTES) - && self - .field - .get_options() - .get_extension(extensions::BLOB) - .unwrap() - .unwrap_or(false) + && self.field.get_options().get_extension(extensions::BLOB) } fn is_lazy_bytes(&self) -> bool { self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_BYTES) - && self - .field - .get_options() - .get_extension(extensions::ZERO_COPY) - .unwrap() - .unwrap_or(false) + && self.field.get_options().get_extension(extensions::ZERO_COPY) } fn is_small_string_optimization(&self) -> bool { self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_STRING) - && self - .field - .get_options() - .get_extension(extensions::SSO) - .unwrap() - .unwrap_or(false) + && self.field.get_options().get_extension(extensions::SSO) } fn is_boxed(&self) -> bool { self.field.r#type == Some(FieldDescriptorProto_Type::TYPE_MESSAGE) - && (self - .field - .get_options() - .get_extension(extensions::BOX_IT) - .unwrap() - .unwrap_or(false) + && (self.field.get_options().get_extension(extensions::BOX_IT) || self.ctx.implicitly_boxed.contains(&(self.field as *const _))) } fn has_custom_type(&self) -> bool { - self.field - .get_options() - .get_extension(extensions::TYPE) - .unwrap() - .is_some() + self.custom_type().is_some() } fn custom_type(&self) -> Option { - self.field.get_options().get_extension(extensions::TYPE).unwrap() + let ty = self.field.get_options().get_extension(extensions::TYPE); + if ty.is_empty() { + None + } else { + Some(ty) + } } fn is_nullable(&self) -> bool { @@ -453,7 +430,7 @@ impl<'a> RustType<'a> { if let Some(nullable_field) = self .field .get_options() - .get_extension(extensions::NULLABLE_FIELD) + .get_extension_opt(extensions::NULLABLE_FIELD) .unwrap() { // We still allow overriding nullability as an extension @@ -485,7 +462,7 @@ impl<'a> RustType<'a> { fn set_method(&self) -> (String, String) { assert!(!self.is_repeated()); - match self.field.r#type.unwrap() { + match self.field.get_type() { FieldDescriptorProto_Type::TYPE_FLOAT => ("f32".to_string(), "v".to_string()), FieldDescriptorProto_Type::TYPE_DOUBLE => ("f64".to_string(), "v".to_string()), FieldDescriptorProto_Type::TYPE_INT32 => ("i32".to_string(), "v".to_string()), @@ -578,6 +555,7 @@ impl<'a> RustType<'a> { assert!(!self.is_repeated()); let name = escape_name(self.field.get_name()); + // TODO: this does not respect default values match self.field.r#type { Some(FieldDescriptorProto_Type::TYPE_FLOAT) => ("f32".to_string(), format!("self.{name}.unwrap_or(0.)")), Some(FieldDescriptorProto_Type::TYPE_DOUBLE) => ("f64".to_string(), format!("self.{name}.unwrap_or(0.)")), @@ -720,27 +698,15 @@ fn oneof_msg_name(parent_msg_name: &str, oneof: &OneofDescriptorProto) -> String } fn oneof_nullable(oneof: &OneofDescriptorProto) -> bool { - oneof - .get_options() - .get_extension(extensions::NULLABLE) - .unwrap() - .unwrap_or(true) + oneof.get_options().get_extension(extensions::NULLABLE) } fn enum_err_if_default_or_unknown(enum_: &EnumDescriptorProto) -> bool { - enum_ - .get_options() - .get_extension(extensions::ERR_IF_DEFAULT_OR_UNKNOWN) - .unwrap() - .unwrap_or(false) + enum_.get_options().get_extension(extensions::ERR_IF_DEFAULT_OR_UNKNOWN) } fn enum_closed(enum_: &EnumDescriptorProto) -> bool { - enum_ - .get_options() - .get_extension(extensions::CLOSED_ENUM) - .unwrap() - .unwrap_or(false) + enum_.get_options().get_extension(extensions::CLOSED_ENUM) } fn block_with<'a, 'ctx>( @@ -882,11 +848,7 @@ impl<'a, 'ctx> CodeWriter<'a, 'ctx> { indentation: 0, content: String::new(), is_proto3: proto_file.get_syntax() == "proto3", - derive_serde: proto_file - .get_options() - .get_extension(extensions::SERDE_DERIVE) - .unwrap() - .unwrap_or(false), + derive_serde: proto_file.get_options().get_extension(extensions::SERDE_DERIVE), source_code_info_by_scl: proto_file .get_source_code_info() .location @@ -1191,11 +1153,7 @@ impl<'a, 'ctx> CodeWriter<'a, 'ctx> { assert_eq!(ctx.indentation, 0); let name = [path, &[msg_type.get_name()]].concat().join("_"); - let preserve_unrecognized = msg_type - .get_options() - .get_extension(extensions::PRESERVE_UNRECOGNIZED) - .unwrap() - == Some(true); + let preserve_unrecognized = msg_type.get_options().get_extension(extensions::PRESERVE_UNRECOGNIZED); let has_extensions = !msg_type.extension_range.is_empty(); let escaped_name = escape_name(&name); @@ -1397,13 +1355,22 @@ impl<'a, 'ctx> CodeWriter<'a, 'ctx> { for field in &msg_type.field { let typ = ctx.rust_type(Some(msg_type), field); if typ.oneof.is_none() { - ctx.write(format!("{}: {},", escape_name(field.get_name()), typ.default(&name))); + ctx.write(format!("{}: {},", escape_name(field.get_name()), typ.default())); } } for &oneof in &oneof_decls { - let oneof_field = oneof_fields[oneof.get_name()][0]; - let typ = ctx.rust_type(Some(msg_type), oneof_field); - ctx.write(format!("{}: {},", escape_name(oneof.get_name()), typ.default(&name))); + let default_value = if oneof_nullable(oneof) { + "None".into() + } else { + let oneof_field = oneof_fields[oneof.get_name()][0]; + let typ = ctx.rust_type(Some(msg_type), oneof_field); + typ.oneof_val(&name, "::std::default::Default::default()") + }; + + ctx.write(format!( + "{oneof_name}: {default_value},", + oneof_name = escape_name(oneof.get_name()) + )); } if preserve_unrecognized { ctx.write("_unrecognized: Vec::new(),"); @@ -1911,28 +1878,36 @@ buf, typ, ::pb_jelly::wire_format::Type::{expected_wire_format}, \"{msg_name}\", .join("_"); let rust_type = self.rust_type(None, extension_field); let extendee = self.ctx.find(extension_field.get_extendee()); - let kind = if extension_field.get_label() == FieldDescriptorProto_Label::LABEL_REPEATED { + let is_repeated = extension_field.get_label() == FieldDescriptorProto_Label::LABEL_REPEATED; + let kind = if is_repeated { "RepeatedExtension" } else { "SingularExtension" }; - self.write(format!( - "pub const {}: ::pb_jelly::extensions::{}<{}, {}> = - ::pb_jelly::extensions::{}::new( - {}, - ::pb_jelly::wire_format::Type::{}, - \"{}\", - );", - name, - kind, - extendee.rust_name(self.ctx, &crate_, &mod_parts), - rust_type.rust_type(), - kind, - extension_field.get_number(), - rust_type.wire_format(), - extension_field.get_name(), - )); + block_with( + self, + format!( + "pub const {name}: ::pb_jelly::extensions::{kind}<{extendee}, {field_type}> =", + extendee = extendee.rust_name(self.ctx, &crate_, &mod_parts), + field_type = rust_type.rust_type(), + ), + "", + "", + |ctx| { + block_with(ctx, format!("::pb_jelly::extensions::{kind}::new"), "(", ");", |ctx| { + ctx.write(format!("{field_number},", field_number = extension_field.get_number())); + ctx.write(format!( + "::pb_jelly::wire_format::Type::{wire_format},", + wire_format = rust_type.wire_format() + )); + ctx.write(format!("\"{field_name}\",", field_name = extension_field.get_name(),)); + if !is_repeated { + ctx.write(format!("|| {},", rust_type.proto2_default())); + } + }); + }, + ); } } @@ -2154,7 +2129,7 @@ struct Impls { } /// Given message types, keyed by their `proto_name()`s, detect recursive fields -/// that would otherwise cause an infinite-size type and add the `box_it` extension to them. +/// that would otherwise cause an infinite-size type and mark them as `implicitly_boxed`. fn box_recursive_fields( types: IndexMap>, implicitly_boxed: &mut IndexSet<*const FieldDescriptorProto>, @@ -2170,7 +2145,7 @@ fn box_recursive_fields( field.get_type() == FieldDescriptorProto_Type::TYPE_MESSAGE && types.contains_key(field.get_type_name()) && field.get_label() != FieldDescriptorProto_Label::LABEL_REPEATED - && field.get_options().get_extension(extensions::BOX_IT).unwrap() != Some(true) + && !field.get_options().get_extension(extensions::BOX_IT) }) .map(FieldDescriptorProto::get_type_name) .collect() @@ -2259,8 +2234,6 @@ impl<'a> Context<'a> { if descriptor .get_options() .get_extension(extensions::PRESERVE_UNRECOGNIZED) - .unwrap() - == Some(true) { impls_copy = false; // Preserve unparsed has a Vec which is not Copy } @@ -2339,17 +2312,13 @@ impl<'a> Context<'a> { if descriptor .get_options() .get_extension(extensions::PRESERVE_UNRECOGNIZED) - .unwrap() - == Some(true) { // TODO: this check isn't really necessary, but it is useful assert!( field_type .msg_typ() .get_options() - .get_extension(extensions::PRESERVE_UNRECOGNIZED) - .unwrap() - == Some(true), + .get_extension(extensions::PRESERVE_UNRECOGNIZED), "{} preserves unrecognized but child message {} does not", type_name, field_type.proto_name(), diff --git a/pb-jelly-gen/src/protos.rs b/pb-jelly-gen/src/protos.rs index df6d01a..0875357 100644 --- a/pb-jelly-gen/src/protos.rs +++ b/pb-jelly-gen/src/protos.rs @@ -5079,8 +5079,8 @@ pub mod google { input_type: ::std::default::Default::default(), output_type: ::std::default::Default::default(), options: ::std::default::Default::default(), - client_streaming: Some(false), - server_streaming: Some(false), + client_streaming: ::std::option::Option::Some(false), + server_streaming: ::std::option::Option::Some(false), } } } @@ -5656,17 +5656,17 @@ pub mod google { FileOptions { java_package: ::std::default::Default::default(), java_outer_classname: ::std::default::Default::default(), - java_multiple_files: Some(false), + java_multiple_files: ::std::option::Option::Some(false), java_generate_equals_and_hash: ::std::default::Default::default(), - java_string_check_utf8: Some(false), - optimize_for: Some(FileOptions_OptimizeMode::SPEED), + java_string_check_utf8: ::std::option::Option::Some(false), + optimize_for: ::std::option::Option::Some(FileOptions_OptimizeMode::SPEED), go_package: ::std::default::Default::default(), - cc_generic_services: Some(false), - java_generic_services: Some(false), - py_generic_services: Some(false), - php_generic_services: Some(false), - deprecated: Some(false), - cc_enable_arenas: Some(true), + cc_generic_services: ::std::option::Option::Some(false), + java_generic_services: ::std::option::Option::Some(false), + py_generic_services: ::std::option::Option::Some(false), + php_generic_services: ::std::option::Option::Some(false), + deprecated: ::std::option::Option::Some(false), + cc_enable_arenas: ::std::option::Option::Some(true), objc_class_prefix: ::std::default::Default::default(), csharp_namespace: ::std::default::Default::default(), swift_prefix: ::std::default::Default::default(), @@ -6447,9 +6447,9 @@ pub mod google { impl ::std::default::Default for MessageOptions { fn default() -> Self { MessageOptions { - message_set_wire_format: Some(false), - no_standard_descriptor_accessor: Some(false), - deprecated: Some(false), + message_set_wire_format: ::std::option::Option::Some(false), + no_standard_descriptor_accessor: ::std::option::Option::Some(false), + deprecated: ::std::option::Option::Some(false), map_entry: ::std::default::Default::default(), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), @@ -6789,12 +6789,12 @@ pub mod google { impl ::std::default::Default for FieldOptions { fn default() -> Self { FieldOptions { - ctype: Some(FieldOptions_CType::STRING), + ctype: ::std::option::Option::Some(FieldOptions_CType::STRING), packed: ::std::default::Default::default(), - jstype: Some(FieldOptions_JSType::JS_NORMAL), - lazy: Some(false), - deprecated: Some(false), - weak: Some(false), + jstype: ::std::option::Option::Some(FieldOptions_JSType::JS_NORMAL), + lazy: ::std::option::Option::Some(false), + deprecated: ::std::option::Option::Some(false), + weak: ::std::option::Option::Some(false), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), } @@ -7218,7 +7218,7 @@ pub mod google { fn default() -> Self { EnumOptions { allow_alias: ::std::default::Default::default(), - deprecated: Some(false), + deprecated: ::std::option::Option::Some(false), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), } @@ -7404,7 +7404,7 @@ pub mod google { impl ::std::default::Default for EnumValueOptions { fn default() -> Self { EnumValueOptions { - deprecated: Some(false), + deprecated: ::std::option::Option::Some(false), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), } @@ -7568,7 +7568,7 @@ pub mod google { impl ::std::default::Default for ServiceOptions { fn default() -> Self { ServiceOptions { - deprecated: Some(false), + deprecated: ::std::option::Option::Some(false), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), } @@ -7742,8 +7742,8 @@ pub mod google { impl ::std::default::Default for MethodOptions { fn default() -> Self { MethodOptions { - deprecated: Some(false), - idempotency_level: Some(MethodOptions_IdempotencyLevel::IDEMPOTENCY_UNKNOWN), + deprecated: ::std::option::Option::Some(false), + idempotency_level: ::std::option::Option::Some(MethodOptions_IdempotencyLevel::IDEMPOTENCY_UNKNOWN), uninterpreted_option: ::std::default::Default::default(), _extensions: ::pb_jelly::Unrecognized::default(), } @@ -9268,53 +9268,65 @@ pub mod rust { pub mod extensions { /// Generate this field in a Box as opposed to inline Option pub const BOX_IT: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50000, - ::pb_jelly::wire_format::Type::Varint, - "box_it", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50000, + ::pb_jelly::wire_format::Type::Varint, + "box_it", + || ::std::default::Default::default(), + ); + /// Generates a `Lazy` pub const GRPC_SLICES: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50003, - ::pb_jelly::wire_format::Type::Varint, - "grpc_slices", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50003, + ::pb_jelly::wire_format::Type::Varint, + "grpc_slices", + || ::std::default::Default::default(), + ); + /// Generates a `Lazy` pub const BLOB: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50010, - ::pb_jelly::wire_format::Type::Varint, - "blob", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50010, + ::pb_jelly::wire_format::Type::Varint, + "blob", + || ::std::default::Default::default(), + ); + /// Use a different Rust type which implements `pb::Message` to represent the field. /// All paths must be fully qualified, as in `::my_crate::full::path::to::type`. /// This only works with proto3. pub const TYPE: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50004, - ::pb_jelly::wire_format::Type::LengthDelimited, - "type", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50004, + ::pb_jelly::wire_format::Type::LengthDelimited, + "type", + || ::std::default::Default::default(), + ); + /// Generate this `bytes` field using a Lazy to enable zero-copy deserialization. pub const ZERO_COPY: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50007, - ::pb_jelly::wire_format::Type::Varint, - "zero_copy", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50007, + ::pb_jelly::wire_format::Type::Varint, + "zero_copy", + || ::std::default::Default::default(), + ); + /// Generate this `string` field using a type that supports a small string optimization. pub const SSO: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50009, - ::pb_jelly::wire_format::Type::Varint, - "sso", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50009, + ::pb_jelly::wire_format::Type::Varint, + "sso", + || ::std::default::Default::default(), + ); + /// If false, make this field's Rust type non-Optional. If the field is /// missing on the wire during deserialization, it will remain as @@ -9322,8 +9334,8 @@ pub mod rust { /// It doesn't make sense to specify this option for a repeating field, as a /// missing field always deserializes as an empty Vec. - /// In proto3, this option is also ignored for primitive types, which are - /// always non-nullable + /// In proto3, this option is also ignored for non-`optional` primitive types, + /// which are already non-nullable. /// Beware that Default may not make sense for all message types. In /// particular, fields using `OneofOptions.nullable=false` or @@ -9331,11 +9343,13 @@ pub mod rust { /// first variant, and will _not_ trigger a deserialization error. That /// behaviour may change in the future (TODO). pub const NULLABLE_FIELD: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50008, - ::pb_jelly::wire_format::Type::Varint, - "nullable_field", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50008, + ::pb_jelly::wire_format::Type::Varint, + "nullable_field", + || true, + ); + /// Setting this to true on an enum means the generated enum won't even have a value for the /// 0-value, and any message that would've parsed to having the value be 0 fail instead. @@ -9349,11 +9363,13 @@ pub mod rust { /// more desirable to just fail at parse time. If the client has updated *past* the server, it may /// send a value that the server does not know how to handle. We *also* fail this at parse time. pub const ERR_IF_DEFAULT_OR_UNKNOWN: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50002, - ::pb_jelly::wire_format::Type::Varint, - "err_if_default_or_unknown", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50002, + ::pb_jelly::wire_format::Type::Varint, + "err_if_default_or_unknown", + || ::std::default::Default::default(), + ); + /// Setting this to true means that an enum's variants are considered exhaustive. /// A Rust `enum` will be generated, rather than a wrapper around `i32`. This @@ -9366,11 +9382,13 @@ pub mod rust { /// values are still allowed. The two options are incompatible, as /// `err_if_default_or_unknown` is strictly stronger. pub const CLOSED_ENUM: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50008, - ::pb_jelly::wire_format::Type::Varint, - "closed_enum", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50008, + ::pb_jelly::wire_format::Type::Varint, + "closed_enum", + || ::std::default::Default::default(), + ); + /// Setting this to true adds an extra field to the deserialized message, which includes /// a serialized representation of unrecognized fields. @@ -9380,25 +9398,30 @@ pub mod rust { /// _unrecognized: Vec, /// } pub const PRESERVE_UNRECOGNIZED: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50006, - ::pb_jelly::wire_format::Type::Varint, - "preserve_unrecognized", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50006, + ::pb_jelly::wire_format::Type::Varint, + "preserve_unrecognized", + || ::std::default::Default::default(), + ); + /// If false, this oneof must have a field set. Parse error if no variant (or unrecognized /// variant) is set. pub const NULLABLE: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50001, - ::pb_jelly::wire_format::Type::Varint, - "nullable", - ); + ::pb_jelly::extensions::SingularExtension::new( + 50001, + ::pb_jelly::wire_format::Type::Varint, + "nullable", + || true, + ); + pub const SERDE_DERIVE: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 50005, - ::pb_jelly::wire_format::Type::Varint, - "serde_derive", - );} + ::pb_jelly::extensions::SingularExtension::new( + 50005, + ::pb_jelly::wire_format::Type::Varint, + "serde_derive", + || false, + );} } diff --git a/pb-jelly/src/extensions.rs b/pb-jelly/src/extensions.rs index 3aacb3f..048817c 100644 --- a/pb-jelly/src/extensions.rs +++ b/pb-jelly/src/extensions.rs @@ -13,10 +13,19 @@ use crate::{ /// Indicates that a message type has extension ranges defined. /// See for details. pub trait Extensible: Message { + /// Attempts to read the given extension field from `self`. + /// + /// If the field was not present, or any value failed to deserialize, returns the + /// default value for the extension field. This is either the declared default if + /// specified, or the empty value. + fn get_extension>(&self, extension: E) -> E::Value { + extension.get_or_default(self) + } + /// Attempts to read the given extension field from `self`. /// /// Returns `Err(_)` if the field was found but could not be deserialized as the declared field type. - fn get_extension>(&self, extension: E) -> io::Result { + fn get_extension_opt>(&self, extension: E) -> io::Result { extension.get(self) } @@ -30,8 +39,10 @@ pub trait Extensible: Message { /// Abstracts over [SingularExtension]/[RepeatedExtension]. pub trait Extension { type Extendee: Extensible; + type MaybeValue; type Value; - fn get(&self, m: &Self::Extendee) -> io::Result; + fn get(&self, m: &Self::Extendee) -> io::Result; + fn get_or_default(&self, m: &Self::Extendee) -> Self::Value; } /// An extension field. See for details. @@ -39,15 +50,17 @@ pub struct SingularExtension { pub field_number: u32, pub wire_format: wire_format::Type, pub name: &'static str, + pub default: fn() -> U, _phantom: PhantomData U>, } impl SingularExtension { - pub const fn new(field_number: u32, wire_format: wire_format::Type, name: &'static str) -> Self { + pub const fn new(field_number: u32, wire_format: wire_format::Type, name: &'static str, default: fn() -> U) -> Self { Self { field_number, wire_format, name, + default, _phantom: PhantomData, } } @@ -62,7 +75,8 @@ impl Clone for SingularExtension { impl Extension for SingularExtension { type Extendee = T; - type Value = Option; + type MaybeValue = Option; + type Value = U; fn get(&self, m: &Self::Extendee) -> io::Result> { Ok(match m._extensions().get_singular_field(self.field_number) { @@ -80,6 +94,10 @@ impl Extension for SingularExtension { None => None, }) } + + fn get_or_default(&self, m: &Self::Extendee) -> Self::Value { + self.get(m).ok().flatten().unwrap_or_else(self.default) + } } /// A `repeated` extension field. See for details. @@ -110,6 +128,7 @@ impl Clone for RepeatedExtension { impl Extension for RepeatedExtension { type Extendee = T; + type MaybeValue = Vec; type Value = Vec; fn get(&self, m: &Self::Extendee) -> io::Result> { @@ -130,4 +149,8 @@ impl Extension for RepeatedExtension { } Ok(result) } + + fn get_or_default(&self, m: &Self::Extendee) -> Self::Value { + self.get(m).unwrap_or_default() + } } diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected index 1424801..9eb2941 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected @@ -113,6 +113,7 @@ impl ::pb_jelly::extensions::Extensible for Msg { pub struct FakeMsg { pub base_field: ::std::option::Option, pub singular_primitive: ::std::option::Option, + pub singular_primitive_with_default: ::std::option::Option, pub singular_message: ::std::option::Option, pub repeated_primitive: ::std::vec::Vec, pub repeated_message: ::std::vec::Vec, @@ -136,6 +137,15 @@ impl FakeMsg { pub fn get_singular_primitive(&self) -> i32 { self.singular_primitive.unwrap_or(0) } + pub fn has_singular_primitive_with_default(&self) -> bool { + self.singular_primitive_with_default.is_some() + } + pub fn set_singular_primitive_with_default(&mut self, v: i32) { + self.singular_primitive_with_default = Some(v); + } + pub fn get_singular_primitive_with_default(&self) -> i32 { + self.singular_primitive_with_default.unwrap_or(0) + } pub fn has_singular_message(&self) -> bool { self.singular_message.is_some() } @@ -178,6 +188,7 @@ impl ::std::default::Default for FakeMsg { FakeMsg { base_field: ::std::default::Default::default(), singular_primitive: ::std::default::Default::default(), + singular_primitive_with_default: ::std::default::Default::default(), singular_message: ::std::default::Default::default(), repeated_primitive: ::std::default::Default::default(), repeated_message: ::std::default::Default::default(), @@ -211,10 +222,19 @@ impl ::pb_jelly::Message for FakeMsg { label: ::pb_jelly::Label::Optional, oneof_index: None, }, + ::pb_jelly::FieldDescriptor { + name: "singular_primitive_with_default", + full_name: "pbtest.FakeMsg.singular_primitive_with_default", + index: 2, + number: 102, + typ: ::pb_jelly::wire_format::Type::Varint, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, ::pb_jelly::FieldDescriptor { name: "singular_message", full_name: "pbtest.FakeMsg.singular_message", - index: 2, + index: 3, number: 301, typ: ::pb_jelly::wire_format::Type::LengthDelimited, label: ::pb_jelly::Label::Optional, @@ -223,7 +243,7 @@ impl ::pb_jelly::Message for FakeMsg { ::pb_jelly::FieldDescriptor { name: "repeated_primitive", full_name: "pbtest.FakeMsg.repeated_primitive", - index: 3, + index: 4, number: 300, typ: ::pb_jelly::wire_format::Type::Varint, label: ::pb_jelly::Label::Repeated, @@ -232,7 +252,7 @@ impl ::pb_jelly::Message for FakeMsg { ::pb_jelly::FieldDescriptor { name: "repeated_message", full_name: "pbtest.FakeMsg.repeated_message", - index: 4, + index: 5, number: 200, typ: ::pb_jelly::wire_format::Type::LengthDelimited, label: ::pb_jelly::Label::Repeated, @@ -259,6 +279,13 @@ impl ::pb_jelly::Message for FakeMsg { singular_primitive_size += l; } size += singular_primitive_size; + let mut singular_primitive_with_default_size = 0; + if let Some(ref val) = self.singular_primitive_with_default { + let l = ::pb_jelly::Message::compute_size(val); + singular_primitive_with_default_size += ::pb_jelly::wire_format::serialized_length(102); + singular_primitive_with_default_size += l; + } + size += singular_primitive_with_default_size; let mut singular_message_size = 0; if let Some(ref val) = self.singular_message { let l = ::pb_jelly::Message::compute_size(val); @@ -289,6 +316,10 @@ impl ::pb_jelly::Message for FakeMsg { ::pb_jelly::wire_format::write(101, ::pb_jelly::wire_format::Type::Varint, w)?; ::pb_jelly::Message::serialize(val, w)?; } + if let Some(ref val) = self.singular_primitive_with_default { + ::pb_jelly::wire_format::write(102, ::pb_jelly::wire_format::Type::Varint, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } for val in &self.repeated_message { ::pb_jelly::wire_format::write(200, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; let l = ::pb_jelly::Message::compute_size(val); @@ -322,6 +353,10 @@ impl ::pb_jelly::Message for FakeMsg { let val = ::pb_jelly::helpers::deserialize_known_length::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "FakeMsg", 101)?; self.singular_primitive = Some(val); } + 102 => { + let val = ::pb_jelly::helpers::deserialize_known_length::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "FakeMsg", 102)?; + self.singular_primitive_with_default = Some(val); + } 301 => { let val = ::pb_jelly::helpers::deserialize_length_delimited::(buf, typ, "FakeMsg", 301)?; self.singular_message = Some(val); @@ -357,6 +392,9 @@ impl ::pb_jelly::Reflection for FakeMsg { "singular_primitive" => { ::pb_jelly::reflection::FieldMut::Value(self.singular_primitive.get_or_insert_with(::std::default::Default::default)) } + "singular_primitive_with_default" => { + ::pb_jelly::reflection::FieldMut::Value(self.singular_primitive_with_default.get_or_insert_with(::std::default::Default::default)) + } "singular_message" => { ::pb_jelly::reflection::FieldMut::Value(self.singular_message.get_or_insert_with(::std::default::Default::default)) } @@ -374,30 +412,45 @@ impl ::pb_jelly::Reflection for FakeMsg { } pub const SINGULAR_PRIMITIVE: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 101, - ::pb_jelly::wire_format::Type::Varint, - "singular_primitive", - ); + ::pb_jelly::extensions::SingularExtension::new( + 101, + ::pb_jelly::wire_format::Type::Varint, + "singular_primitive", + || ::std::default::Default::default(), + ); + + +pub const SINGULAR_PRIMITIVE_WITH_DEFAULT: ::pb_jelly::extensions::SingularExtension = + ::pb_jelly::extensions::SingularExtension::new( + 102, + ::pb_jelly::wire_format::Type::Varint, + "singular_primitive_with_default", + || 123, + ); + pub const SINGULAR_MESSAGE: ::pb_jelly::extensions::SingularExtension = - ::pb_jelly::extensions::SingularExtension::new( - 301, - ::pb_jelly::wire_format::Type::LengthDelimited, - "singular_message", - ); + ::pb_jelly::extensions::SingularExtension::new( + 301, + ::pb_jelly::wire_format::Type::LengthDelimited, + "singular_message", + || ::std::default::Default::default(), + ); + pub const REPEATED_PRIMITIVE: ::pb_jelly::extensions::RepeatedExtension = - ::pb_jelly::extensions::RepeatedExtension::new( - 300, - ::pb_jelly::wire_format::Type::Varint, - "repeated_primitive", - ); + ::pb_jelly::extensions::RepeatedExtension::new( + 300, + ::pb_jelly::wire_format::Type::Varint, + "repeated_primitive", + ); + pub const REPEATED_MESSAGE: ::pb_jelly::extensions::RepeatedExtension = - ::pb_jelly::extensions::RepeatedExtension::new( - 200, - ::pb_jelly::wire_format::Type::LengthDelimited, - "repeated_message", - ); + ::pb_jelly::extensions::RepeatedExtension::new( + 200, + ::pb_jelly::wire_format::Type::LengthDelimited, + "repeated_message", + ); + diff --git a/pb-test/proto/includes/rust/extensions.proto b/pb-test/proto/includes/rust/extensions.proto deleted file mode 100644 index 7c498dc..0000000 --- a/pb-test/proto/includes/rust/extensions.proto +++ /dev/null @@ -1,93 +0,0 @@ -// Protocol Buffers for Rust with Gadgets -// -// Note: While proto3 still supports extensions for custom fields (https://github.com/protocolbuffers/protobuf/issues/1460) -// it does not support default values. Unfortunately the OneOfOptions::nullable field has a default value of true, while -// in proto3 booleans have a default value of false. We can't migrate this file to proto3 without deprecating nullable -// and changing its semantics. - -syntax = "proto2"; -package rust; - -import "google/protobuf/descriptor.proto"; - -extend google.protobuf.FieldOptions { - // Generate this field in a Box as opposed to inline Option - optional bool box_it = 50000; - // Generates a `Lazy` - optional bool grpc_slices = 50003; - // Generates a `Lazy` - optional bool blob = 50010; - // Use a different Rust type which implements `pb::Message` to represent the field. - // All paths must be fully qualified, as in `::my_crate::full::path::to::type`. - // This only works with proto3. - optional string type = 50004; - - // Generate this `bytes` field using a Lazy to enable zero-copy deserialization. - optional bool zero_copy = 50007; - // Generate this `string` field using a type that supports a small string optimization. - optional bool sso = 50009; - - // If false, make this field's Rust type non-Optional. If the field is - // missing on the wire during deserialization, it will remain as - // Default::default(). - // - // It doesn't make sense to specify this option for a repeating field, as a - // missing field always deserializes as an empty Vec. - // In proto3, this option is also ignored for primitive types, which are - // always non-nullable - // - // Beware that Default may not make sense for all message types. In - // particular, fields using `OneofOptions.nullable=false` or - // `EnumOptions.err_if_default_or_unknown=true` will simply default to their - // first variant, and will _not_ trigger a deserialization error. That - // behaviour may change in the future (TODO). - optional bool nullable_field = 50008; -} - -extend google.protobuf.EnumOptions { - // Setting this to true on an enum means the generated enum won't even have a value for the - // 0-value, and any message that would've parsed to having the value be 0 fail instead. - // - // If an enum field doesn't appear in the wire format of proto3, the 0 value is assumed. So if - // an enum field is added to a message in use between a client and server, and the client hasn't - // been recompiled, then all received messages on the server side will get the 0 value. As such, - // it's common in cases when there's not an obvious (and safe) default value to make the 0 value - // an explicit unknown/invalid value. In those cases, they become cumbersome to use in Rust - // because match statements will always require a branch for the unknown/invalid case, but it's - // more desirable to just fail at parse time. If the client has updated *past* the server, it may - // send a value that the server does not know how to handle. We *also* fail this at parse time. - optional bool err_if_default_or_unknown = 50002; - - // Setting this to true means that an enum's variants are considered exhaustive. - // A Rust `enum` will be generated, rather than a wrapper around `i32`. This - // makes matching against the enum simpler as there is no need to match - // unknown variants. Instead, deserialization will fail if an unknown - // variant is found over the wire. That makes adding or removing variants - // potentially unsafe when it comes to version skew. - // - // This option differs from `err_if_default_or_unknown` because default - // values are still allowed. The two options are incompatible, as - // `err_if_default_or_unknown` is strictly stronger. - optional bool closed_enum = 50008; -} - -extend google.protobuf.MessageOptions { - // Setting this to true adds an extra field to the deserialized message, which includes - // a serialized representation of unrecognized fields. - // Eg. - // MyMessage { - // field: u32, - // _unrecognized: Vec, - // } - optional bool preserve_unrecognized = 50006; -} - -extend google.protobuf.OneofOptions { - // If false, this oneof must have a field set. Parse error if no variant (or unrecognized - // variant) is set. - optional bool nullable = 50001 [default=true]; -} - -extend google.protobuf.FileOptions { - optional bool serde_derive = 50005 [default=false]; -} diff --git a/pb-test/proto/packages/pbtest/extensions.proto b/pb-test/proto/packages/pbtest/extensions.proto index 8e22963..213a573 100644 --- a/pb-test/proto/packages/pbtest/extensions.proto +++ b/pb-test/proto/packages/pbtest/extensions.proto @@ -11,6 +11,7 @@ message Msg { extend Msg { optional int32 singular_primitive = 101; + optional int32 singular_primitive_with_default = 102 [default=123]; optional ForeignMessage3 singular_message = 301; repeated int32 repeated_primitive = 300; repeated ForeignMessage3 repeated_message = 200; @@ -20,6 +21,7 @@ message FakeMsg { optional int32 base_field = 250; optional int32 singular_primitive = 101; + optional int32 singular_primitive_with_default = 102; optional ForeignMessage3 singular_message = 301; repeated int32 repeated_primitive = 300; repeated ForeignMessage3 repeated_message = 200; diff --git a/pb-test/src/pbtest.rs b/pb-test/src/pbtest.rs index 67d7bb9..8608bd4 100644 --- a/pb-test/src/pbtest.rs +++ b/pb-test/src/pbtest.rs @@ -946,35 +946,53 @@ fn test_mutual_recursion() { #[test] fn test_extensions() { - check_roundtrip(extensions::FakeMsg::default()); - check_roundtrip(extensions::FakeMsg { + let default_msg = check_roundtrip(extensions::FakeMsg::default()); + assert_eq!(default_msg.get_extension(extensions::SINGULAR_PRIMITIVE), 0); + assert_eq!( + default_msg.get_extension(extensions::SINGULAR_PRIMITIVE_WITH_DEFAULT), + 123 + ); + + let defined_msg = check_roundtrip(extensions::FakeMsg { base_field: Some(39), singular_primitive: Some(123), + singular_primitive_with_default: Some(1234), singular_message: Some(ForeignMessage3 { c: 321 }), repeated_primitive: vec![456, 789], repeated_message: vec![ForeignMessage3 { c: 654 }, ForeignMessage3 { c: 987 }], }); + assert_eq!(defined_msg.get_extension(extensions::SINGULAR_PRIMITIVE), 123); + assert_eq!( + defined_msg.get_extension(extensions::SINGULAR_PRIMITIVE_WITH_DEFAULT), + 1234 + ); // Check that serializing a FakeMsg and deserializing into Msg preserves the extension fields, // and that those fields can be read using `get_extension()`. - fn check_roundtrip(orig: extensions::FakeMsg) { + fn check_roundtrip(orig: extensions::FakeMsg) -> extensions::Msg { let m = extensions::Msg::deserialize_from_slice(&orig.serialize_to_vec()).unwrap(); assert_eq!(m.base_field, orig.base_field); assert_eq!( - m.get_extension(extensions::SINGULAR_PRIMITIVE).unwrap(), + m.get_extension_opt(extensions::SINGULAR_PRIMITIVE).unwrap(), orig.singular_primitive ); assert_eq!( - m.get_extension(extensions::SINGULAR_MESSAGE).unwrap(), + m.get_extension_opt(extensions::SINGULAR_PRIMITIVE_WITH_DEFAULT) + .unwrap(), + orig.singular_primitive_with_default + ); + assert_eq!( + m.get_extension_opt(extensions::SINGULAR_MESSAGE).unwrap(), orig.singular_message, ); assert_eq!( - m.get_extension(extensions::REPEATED_PRIMITIVE).unwrap(), + m.get_extension_opt(extensions::REPEATED_PRIMITIVE).unwrap(), orig.repeated_primitive, ); assert_eq!( - m.get_extension(extensions::REPEATED_MESSAGE).unwrap(), + m.get_extension_opt(extensions::REPEATED_MESSAGE).unwrap(), orig.repeated_message ); + m } }