Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
vrischmann committed Aug 20, 2024
1 parent 0c423e8 commit 8cbf4c1
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 84 deletions.
6 changes: 3 additions & 3 deletions client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ pub const Client = struct {
);
}

const read_message = try self.connection.readMessage(allocator, .{
const read_message = try self.connection.nextMessage(allocator, .{
.message_allocator = self.allocator,
});
switch (read_message) {
Expand Down Expand Up @@ -284,7 +284,7 @@ pub const Client = struct {
}

// Read either RESULT or ERROR
return switch (try self.connection.readMessage(allocator, .{})) {
return switch (try self.connection.nextMessage(allocator, .{})) {
.result => |result_message| {
return switch (result_message.result) {
.Rows => |rows| blk: {
Expand Down Expand Up @@ -375,7 +375,7 @@ pub const Client = struct {
}

// Read either RESULT or ERROR
return switch (try self.connection.readMessage(allocator, .{})) {
return switch (try self.connection.nextMessage(allocator, .{})) {
.result => |result_message| {
return switch (result_message.result) {
.Rows => |rows| blk: {
Expand Down
110 changes: 76 additions & 34 deletions connection.zig
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ const Frame = protocol.Frame;
const Envelope = protocol.Envelope;
const EnvelopeFlags = protocol.EnvelopeFlags;
const EnvelopeHeader = protocol.EnvelopeHeader;
const EnvelopeReader = protocol.EnvelopeReader;
const EnvelopeWriter = protocol.EnvelopeWriter;

const CQLVersion = protocol.CQLVersion;
Expand Down Expand Up @@ -90,7 +89,6 @@ pub const Connection = struct {
const BufferedWriterType = io.BufferedWriter(4096, std.net.Stream.Writer);

// TODO(vincent): we probably don't need the reader and writer abstractions here.
const EnvelopeReaderType = EnvelopeReader(BufferedReaderType.Reader);
const EnvelopeWriterType = EnvelopeWriter(std.ArrayList(u8).Writer);

/// Contains the state that is negotiated with a node as part of the handshake.
Expand All @@ -101,25 +99,23 @@ pub const Connection = struct {
allocator: mem.Allocator,
options: InitOptions,

framing: struct {
enabled: bool = false,
format: Frame.Format = undefined,
},

socket: std.net.Stream,

buffered_reader: BufferedReaderType,
buffered_writer: BufferedWriterType,

/// Helpers types needed to encode and decode the CQL protocol.
envelope_reader: EnvelopeReaderType,
envelope_writer_buffer: std.ArrayList(u8),
envelope_writer: EnvelopeWriterType,
message_reader: MessageReader,
message_writer: MessageWriter,

// Negotiated with the server
// connection state
negotiated_state: NegotiatedState,
framing: struct {
enabled: bool = false,
format: Frame.Format = undefined,
},

pub fn initIp4(self: *Self, allocator: mem.Allocator, seed_address: net.Address, options: InitOptions) !void {
self.allocator = allocator;
Expand All @@ -131,7 +127,6 @@ pub const Connection = struct {
self.buffered_reader = BufferedReaderType{ .unbuffered_reader = self.socket.reader() };
self.buffered_writer = BufferedWriterType{ .unbuffered_writer = self.socket.writer() };

self.envelope_reader = EnvelopeReaderType.init(self.buffered_reader.reader());
self.envelope_writer_buffer = std.ArrayList(u8).init(allocator);
self.envelope_writer = EnvelopeWriterType.init(self.envelope_writer_buffer.writer());

Expand Down Expand Up @@ -186,7 +181,7 @@ pub const Connection = struct {
);

fba.reset();
switch (try self.readMessage(fba.allocator(), .{})) {
switch (try self.nextMessage(fba.allocator(), .{})) {
.supported => |fr| self.negotiated_state.cql_version = fr.cql_versions[0],
.@"error" => |err| {
diags.message = err.message;
Expand All @@ -207,7 +202,7 @@ pub const Connection = struct {
},
.{},
);
switch (try self.readMessage(fba.allocator(), .{})) {
switch (try self.nextMessage(fba.allocator(), .{})) {
.ready => return,
.authenticate => |fr| {
try self.authenticate(fba.allocator(), diags, fr.authenticator);
Expand Down Expand Up @@ -246,7 +241,7 @@ pub const Connection = struct {
}

// Read either AUTH_CHALLENGE, AUTH_SUCCESS or ERROR
switch (try self.readMessage(allocator, .{})) {
switch (try self.nextMessage(allocator, .{})) {
.auth_challenge => unreachable,
.auth_success => return,
.@"error" => |err| {
Expand Down Expand Up @@ -362,42 +357,70 @@ pub const Connection = struct {

pub const ReadMessageOptions = struct {
/// If set the message will be allocated using this allocator instead of the allocator
/// passed in `readMessage`.
/// passed in `nextMessage`.
///
/// This is useful if you want the message to have a different lifecycle, for example
/// if you need to store the message in a list or a map.
message_allocator: ?mem.Allocator = null,
};

pub fn readMessage(self: *Self, allocator: mem.Allocator, options: ReadMessageOptions) !Message {
pub fn nextMessage(self: *Self, allocator: mem.Allocator, options: ReadMessageOptions) !Message {
var messages = try self.readMessages(allocator, options);
defer messages.deinit();

debug.assert(messages.items.len == 1);

return messages.pop();
}

fn readMessages(self: *Self, allocator: mem.Allocator, options: ReadMessageOptions) !std.ArrayList(Message) {
var result = std.ArrayList(Message).init(allocator);

if (self.options.protocol_version.isAtMost(4)) {
try self.readMessageV4(allocator, options, &result);
} else if (self.options.protocol_version.isAtLeast(5)) {
try self.readMessageV4(allocator, options, &result);
}

return result;
}

fn readMessagesV5(self: *Self, allocator: mem.Allocator, options: ReadMessageOptions, messages: *std.ArrayList(Message)) !void {
const buffer = blk: {
var buf: [256 * 1024]u8 = undefined;
const n = try self.buffered_reader.read(&buf);
if (n <= 0) {
return error.UnexpectedEOF;
}
break :blk buf[0..n];
};

const result = try Frame.decode(allocator, buffer, self.framing.format);
debug.assert(result.consumed == buffer.len);
debug.assert(result.frame.is_self_contained);
debug.assert(result.frame.payload.len > 0);

var fbs = io.StreamSource{ .const_buffer = io.fixedBufferStream(result.frame.payload) };
const envelope = try Envelope.read(allocator, fbs.reader());

const message = try self.decodeMessage(options.message_allocator orelse allocator, envelope);

try messages.append(message);
}

fn readMessageV4(self: *Self, allocator: mem.Allocator, options: ReadMessageOptions, messages: *std.ArrayList(Message)) !void {
const envelope = try self.readEnvelope(allocator);
defer envelope.deinit(allocator);

self.message_reader.reset(envelope.body);

const message_allocator = if (options.message_allocator) |message_allocator|
message_allocator
else
allocator;
const message = try self.decodeMessage(options.message_allocator orelse allocator, envelope);

return switch (envelope.header.opcode) {
.@"error" => Message{ .@"error" = try ErrorMessage.read(message_allocator, &self.message_reader) },
.startup => Message{ .startup = try StartupMessage.read(message_allocator, &self.message_reader) },
.ready => Message{ .ready = ReadyMessage{} },
.options => Message{ .options = {} },
.supported => Message{ .supported = try SupportedMessage.read(message_allocator, &self.message_reader) },
.result => Message{ .result = try ResultMessage.read(message_allocator, self.options.protocol_version, &self.message_reader) },
.register => Message{ .register = {} },
.event => Message{ .event = try EventMessage.read(message_allocator, &self.message_reader) },
.authenticate => Message{ .authenticate = try AuthenticateMessage.read(message_allocator, &self.message_reader) },
.auth_challenge => Message{ .auth_challenge = try AuthChallengeMessage.read(message_allocator, &self.message_reader) },
.auth_success => Message{ .auth_success = try AuthSuccessMessage.read(message_allocator, &self.message_reader) },
else => std.debug.panic("invalid read message {}\n", .{envelope.header.opcode}),
};
try messages.append(message);
}

fn readEnvelope(self: *Self, allocator: mem.Allocator) !Envelope {
var envelope = try self.envelope_reader.read(allocator);
var envelope = try Envelope.read(allocator, self.buffered_reader.reader());

if (envelope.header.flags & EnvelopeFlags.Compression == EnvelopeFlags.Compression) {
const compression = self.options.compression orelse return error.InvalidCompressedFrame;
Expand All @@ -417,4 +440,23 @@ pub const Connection = struct {

return envelope;
}

fn decodeMessage(self: *Self, message_allocator: mem.Allocator, envelope: Envelope) !Message {
const message = switch (envelope.header.opcode) {
.@"error" => Message{ .@"error" = try ErrorMessage.read(message_allocator, &self.message_reader) },
.startup => Message{ .startup = try StartupMessage.read(message_allocator, &self.message_reader) },
.ready => Message{ .ready = ReadyMessage{} },
.options => Message{ .options = {} },
.supported => Message{ .supported = try SupportedMessage.read(message_allocator, &self.message_reader) },
.result => Message{ .result = try ResultMessage.read(message_allocator, self.options.protocol_version, &self.message_reader) },
.register => Message{ .register = {} },
.event => Message{ .event = try EventMessage.read(message_allocator, &self.message_reader) },
.authenticate => Message{ .authenticate = try AuthenticateMessage.read(message_allocator, &self.message_reader) },
.auth_challenge => Message{ .auth_challenge = try AuthChallengeMessage.read(message_allocator, &self.message_reader) },
.auth_success => Message{ .auth_success = try AuthSuccessMessage.read(message_allocator, &self.message_reader) },
else => std.debug.panic("invalid read message {}\n", .{envelope.header.opcode}),
};

return message;
}
};
74 changes: 27 additions & 47 deletions protocol.zig
Original file line number Diff line number Diff line change
Expand Up @@ -556,51 +556,37 @@ pub const Envelope = struct {
pub fn deinit(self: @This(), allocator: mem.Allocator) void {
allocator.free(self.body);
}
};

pub fn EnvelopeReader(comptime ReaderType: type) type {
return struct {
const Self = @This();

reader: ReaderType,
pub fn read(allocator: mem.Allocator, reader: anytype) !Envelope {
var buf: [EnvelopeHeader.size]u8 = undefined;

pub fn init(in: ReaderType) Self {
return Self{
.reader = in,
};
const n_header_read = try reader.readAll(&buf);
if (n_header_read != EnvelopeHeader.size) {
return error.UnexpectedEOF;
}

pub fn read(self: *Self, allocator: mem.Allocator) !Envelope {
var buf: [EnvelopeHeader.size]u8 = undefined;

const n_header_read = try self.reader.readAll(&buf);
if (n_header_read != EnvelopeHeader.size) {
return error.UnexpectedEOF;
}

const header = EnvelopeHeader{
.version = ProtocolVersion{ .version = buf[0] },
.flags = buf[1],
.stream = mem.readInt(i16, @ptrCast(buf[2..4]), .big),
.opcode = @enumFromInt(buf[4]),
.body_len = mem.readInt(u32, @ptrCast(buf[5..9]), .big),
};
const header = EnvelopeHeader{
.version = ProtocolVersion{ .version = buf[0] },
.flags = buf[1],
.stream = mem.readInt(i16, @ptrCast(buf[2..4]), .big),
.opcode = @enumFromInt(buf[4]),
.body_len = mem.readInt(u32, @ptrCast(buf[5..9]), .big),
};

const len = @as(usize, header.body_len);
const len = @as(usize, header.body_len);

const body = try allocator.alloc(u8, len);
const n_read = try self.reader.readAll(body);
if (n_read != len) {
return error.UnexpectedEOF;
}

return Envelope{
.header = header,
.body = body,
};
const body = try allocator.alloc(u8, len);
const n_read = try reader.readAll(body);
if (n_read != len) {
return error.UnexpectedEOF;
}
};
}

return Envelope{
.header = header,
.body = body,
};
}
};

pub fn EnvelopeWriter(comptime WriterType: type) type {
return struct {
Expand Down Expand Up @@ -2099,11 +2085,7 @@ test "envelope header: read and write" {

// deserialize the header

const reader = fbs.reader();

var envelope_reader = EnvelopeReader(@TypeOf(reader)).init(reader);

const envelope = try envelope_reader.read(testing.allocator);
const envelope = try Envelope.read(testing.allocator, fbs.reader());
const header = envelope.header;

try testing.expect(header.version.is(4));
Expand Down Expand Up @@ -3771,13 +3753,11 @@ test "supported message" {

/// Reads an enevelope from the provided buffer.
/// Only intended to be used for tests.
fn testReadEnvelope(_allocator: mem.Allocator, data: []const u8) !Envelope {
fn testReadEnvelope(allocator: mem.Allocator, data: []const u8) !Envelope {
var source = io.StreamSource{ .const_buffer = io.fixedBufferStream(data) };
const reader = source.reader();

var fr = EnvelopeReader(@TypeOf(reader)).init(reader);

return fr.read(_allocator);
return Envelope.read(allocator, reader);
}

/// Reads a frame from the provided buffer.
Expand Down

0 comments on commit 8cbf4c1

Please sign in to comment.