diff --git a/protocol.zig b/protocol.zig index 32bec9b..eef0542 100644 --- a/protocol.zig +++ b/protocol.zig @@ -23,11 +23,6 @@ const QueryParameters = @import("QueryParameters.zig"); // // See ยง2 in native_protocol_v5.spec. -pub const FrameFormat = enum { - compressed, - uncompressed, -}; - pub const Frame = struct { /// Indicates that the payload includes one or more complete envelopes and can be fully processed immediately. is_self_contained: bool, @@ -36,10 +31,22 @@ pub const Frame = struct { /// * if not self contained, the payload is one part of a large envelope payload: []const u8, + const uncompressed_header_size = 6; + const compressed_header_size = 8; const trailer_size = 4; const max_payload_size = 131071; const initial_crc32_bytes = "\xFA\x2D\x55\xCA"; + pub const Format = enum { + compressed, + uncompressed, + }; + + pub const Containment = enum { + self_contained, + not_self_contained, + }; + fn crc24(input: anytype) usize { // Don't use @sizeOf because it contains a padding byte const Size = @typeInfo(@TypeOf(input)).Int.bits / 8; @@ -86,7 +93,7 @@ pub const Frame = struct { /// /// If there's not enough data or if the input data is corrupted somehow, an error is returned. /// Otherwise a result containing both the frame and the number of bytes consumed is returned. - pub fn read(allocator: mem.Allocator, data: []const u8, format: FrameFormat) Frame.ReadError!ReadResult { + pub fn read(allocator: mem.Allocator, data: []const u8, format: Format) Frame.ReadError!ReadResult { // TODO(vincent): do we really need the reader abstraction here ? var source = io.StreamSource{ .const_buffer = io.fixedBufferStream(data) }; const reader = source.reader(); @@ -142,17 +149,15 @@ pub const Frame = struct { } fn readUncompressed(allocator: mem.Allocator, reader: anytype) Frame.ReadError!ReadResult { - const header_size = 6; - // Read and parse header - const header = try readHeader(reader, header_size); + const header = try readHeader(reader, uncompressed_header_size); const first3b: u24 = mem.readInt(u24, header[0..3], .little); const payload_length: u17 = @intCast(first3b & 0x1FFFF); const is_self_contained = first3b & (1 << 17) != 0; const computed_crc24 = crc24(first3b); - const expected_crc24 = @as(usize, @intCast(mem.readInt(u24, header[3..header_size], .little))); + const expected_crc24 = @as(usize, @intCast(mem.readInt(u24, header[3..uncompressed_header_size], .little))); if (computed_crc24 != expected_crc24) return error.InvalidHeaderChecksum; // Read payload and trailer @@ -166,15 +171,13 @@ pub const Frame = struct { .payload = payload_and_trailer.payload, .is_self_contained = is_self_contained, }, - .consumed = header_size + payload_and_trailer.length(), + .consumed = uncompressed_header_size + payload_and_trailer.length(), }; } fn readCompressed(allocator: mem.Allocator, reader: anytype) Frame.ReadError!ReadResult { - const header_size = 8; - // Read and parse header - const header = try readHeader(reader, header_size); + const header = try readHeader(reader, compressed_header_size); const first5b: u40 = mem.readInt(u40, header[0..5], .little); const compressed_length: u17 = @intCast(first5b & 0x1FFFF); @@ -182,7 +185,7 @@ pub const Frame = struct { const is_self_contained = first5b & (1 << 34) != 0; const computed_crc24 = crc24(first5b); - const expected_crc24 = @as(usize, @intCast(mem.readInt(u24, header[5..header_size], .little))); + const expected_crc24 = @as(usize, @intCast(mem.readInt(u24, header[5..compressed_header_size], .little))); if (computed_crc24 != expected_crc24) return error.InvalidHeaderChecksum; // Read payload and trailer @@ -202,9 +205,51 @@ pub const Frame = struct { .payload = payload, .is_self_contained = is_self_contained, }, - .consumed = header_size + payload_and_trailer.length(), + .consumed = compressed_header_size + payload_and_trailer.length(), }; } + + const EncodeError = error{ + PayloadTooBig, + } || mem.Allocator.Error; + + fn encode(allocator: mem.Allocator, payload: []const u8, containment: Containment, format: Format) Frame.EncodeError!Frame { + switch (format) { + .uncompressed => return encodeUncompressed(allocator, payload, containment), + .compressed => return encodeCompressed(allocator, payload, containment), + } + } + + fn encodeUncompressed(allocator: mem.Allocator, payload: []const u8, containment: Containment) Frame.EncodeError!Frame { + if (payload.len > max_payload_size) return error.PayloadTooBig; + + // Create header + var header3b: u24 = @as(u17, @intCast(payload.len)); + if (containment == .self_contained) { + header3b |= 1 << 17; + } + const header_crc24 = crc24(header3b); + + var buf: [uncompressed_header_size]u8 = undefined; + mem.writeInt(u24, buf[0..3], header3b, .little); + mem.writeInt(u24, buf[3..uncompressed_header_size], @truncate(header_crc24), .little); + + std.debug.print("header3b: {d}, crc: {d}, buf: {s}\n", .{ + header3b, + header_crc24, + std.fmt.fmtSliceHexLower(&buf), + }); + + _ = allocator; + return error.OutOfMemory; + } + + fn encodeCompressed(allocator: mem.Allocator, payload: []const u8, containment: Containment) Frame.EncodeError!Frame { + _ = allocator; + _ = payload; + _ = containment; + return error.OutOfMemory; + } }; test "frame reader: QUERY message" { @@ -213,7 +258,7 @@ test "frame reader: QUERY message" { const TestCase = struct { data: []const u8, - format: FrameFormat, + format: Frame.Format, }; const testCases = &[_]TestCase{ @@ -265,7 +310,7 @@ test "frame reader: QUERY message incomplete" { const frame_data = @embedFile("testdata/query_frame_compressed.bin"); const test_data = frame_data ++ [_]u8{'z'} ** 2000; - const test_format: FrameFormat = .compressed; + const test_format: Frame.Format = .compressed; const tmp1 = Frame.read(arena.allocator(), test_data[0..1], test_format); try testing.expectError(error.UnexpectedEOF, tmp1); @@ -402,6 +447,50 @@ test "frame reader: RESULT message" { } } +test "frame write: PREPARE message" { + var arena = testutils.arenaAllocator(); + defer arena.deinit(); + + const protocol_version = try ProtocolVersion.init(5); + + // Write the message to a buffer + + const message = PrepareMessage{ + .query = "SELECT 1 FROM foobar", + .keyspace = null, + }; + + var mw = try MessageWriter.init(arena.allocator()); + try message.write(protocol_version, &mw); + + // Create and write the envelope to a buffer + + const envelope = Envelope{ + .header = EnvelopeHeader{ + .version = protocol_version, + .flags = 0, + .stream = 0, + .opcode = .prepare, + .body_len = 0, + }, + .body = mw.getWritten(), + }; + + var envelope_writer_buffer = std.ArrayList(u8).init(arena.allocator()); + var envelope_writer = EnvelopeWriter(std.ArrayList(u8).Writer).init(envelope_writer_buffer.writer()); + + try envelope_writer.write(envelope); + + // Write the frame to a buffer + + const frame_payload = try envelope_writer_buffer.toOwnedSlice(); + const frame = try Frame.encode(arena.allocator(), frame_payload, .self_contained, .uncompressed); + + std.debug.print("frame: {s}\n", .{ + std.fmt.fmtSliceHexLower(frame.payload), + }); +} + // // //