Skip to content

Commit

Permalink
Propagate SSLSession in ConnectionContext to enable SASL/SCRAM channe…
Browse files Browse the repository at this point in the history
…l binding.

[#645]
  • Loading branch information
mp911de committed Apr 8, 2024
1 parent 26761e8 commit 83d39c8
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 19 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@
<artifactId>scram-client</artifactId>
<version>${scram-client.version}</version>
</dependency>
<dependency>
<groupId>com.ongres.scram</groupId>
<artifactId>scram-common</artifactId>
<version>${scram-client.version}</version>
</dependency>
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler;
import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler;
import io.r2dbc.postgresql.client.Client;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.client.ConnectionSettings;
import io.r2dbc.postgresql.client.PostgresStartupParameterProvider;
import io.r2dbc.postgresql.client.StartupMessageFlow;
Expand All @@ -46,7 +47,7 @@ public Mono<Client> connect(SocketAddress endpoint, ConnectionSettings settings)

return this.upstreamFunction.connect(endpoint, settings)
.delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow
.exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(),
.exchange(auth -> getAuthenticationHandler(auth, credentials, client.getContext()), client, this.configuration.getDatabase(), credentials.getUsername(),
getParameterProvider(this.configuration, settings)))
.handle(ExceptionFactory.INSTANCE::handleErrorResponse));
}
Expand All @@ -55,13 +56,13 @@ private static PostgresStartupParameterProvider getParameterProvider(PostgresqlC
return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings);
}

protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) {
protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword, ConnectionContext context) {
if (PasswordAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername());
} else if (SASLAuthenticationHandler.supports(message)) {
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername());
return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername(), context);
} else {
throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import com.ongres.scram.client.ScramClient;
import com.ongres.scram.common.StringPreparation;
import com.ongres.scram.common.exception.ScramException;

import com.ongres.scram.common.util.TlsServerEndpoint;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
Expand All @@ -14,26 +15,40 @@
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.ByteBufferUtils;
import reactor.core.Exceptions;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

public class SASLAuthenticationHandler implements AuthenticationHandler {

private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);

private final CharSequence password;

private final String username;

private final ConnectionContext context;

private ScramClient scramClient;

/**
* Create a new handler.
*
* @param password the password to use for authentication
* @param username the username to use for authentication
* @param context the connection context
* @throws IllegalArgumentException if {@code password} or {@code user} is {@code null}
*/
public SASLAuthenticationHandler(CharSequence password, String username) {
public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) {
this.password = Assert.requireNonNull(password, "password must not be null");
this.username = Assert.requireNonNull(username, "username must not be null");
this.context = Assert.requireNonNull(context, "context must not be null");
}

/**
Expand Down Expand Up @@ -67,14 +82,44 @@ public FrontendMessage handle(AuthenticationMessage message) {
}

private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
this.scramClient = ScramClient.builder()

char[] password = new char[this.password.length()];
for (int i = 0; i < password.length; i++) {
password[i] = this.password.charAt(i);
}

ScramClient.FinalBuildStage builder = ScramClient.builder()
.advertisedMechanisms(message.getAuthenticationMechanisms())
.username(username) // ignored by the server, use startup message
.password(password.toString().toCharArray())
.stringPreparation(StringPreparation.POSTGRESQL_PREPARATION)
.build();
.username(this.username) // ignored by the server, use startup message
.password(password)
.stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);

SSLSession sslSession = this.context.getSslSession();

return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName());
if (sslSession != null && sslSession.isValid()) {
builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
}

this.scramClient = builder.build();

return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName());
}

private static byte[] extractSslEndpoint(SSLSession sslSession) {
try {
Certificate[] certificates = sslSession.getPeerCertificates();
if (certificates != null && certificates.length > 0) {
Certificate peerCert = certificates[0]; // First certificate is the peer's certificate
if (peerCert instanceof X509Certificate) {
X509Certificate cert = (X509Certificate) peerCert;
return TlsServerEndpoint.getChannelBindingData(cert);

}
}
} catch (CertificateException | SSLException e) {
LOG.debug("Cannot extract X509Certificate from SSL session", e);
}
return new byte[0];
}

private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
Expand Down
27 changes: 24 additions & 3 deletions src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import reactor.util.Loggers;

import javax.annotation.Nullable;
import javax.net.ssl.SSLSession;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

/**
* Value object capturing diagnostic connection context. Allows for log-message post-processing with {@link #getMessage(String) if the logger category for
Expand Down Expand Up @@ -50,6 +52,8 @@ public final class ConnectionContext {

private final String connectionIdPrefix;

private final Supplier<SSLSession> sslSession;

/**
* Create a new {@link ConnectionContext} with a unique connection Id.
*/
Expand All @@ -58,13 +62,15 @@ public ConnectionContext() {
this.connectionCounter = incrementConnectionCounter();
this.connectionIdPrefix = getConnectionIdPrefix();
this.channelId = null;
this.sslSession = () -> null;
}

private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter) {
private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter, Supplier<SSLSession> sslSession) {
this.processId = processId;
this.channelId = channelId;
this.connectionCounter = connectionCounter;
this.connectionIdPrefix = getConnectionIdPrefix();
this.sslSession = sslSession;
}

private String incrementConnectionCounter() {
Expand Down Expand Up @@ -101,14 +107,29 @@ public String getMessage(String original) {
return original;
}

@Nullable
public SSLSession getSslSession() {
return this.sslSession.get();
}

/**
* Create a new {@link ConnectionContext} by associating the {@code channelId}.
*
* @param channelId the channel identifier.
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code channelId}.
*/
public ConnectionContext withChannelId(String channelId) {
return new ConnectionContext(this.processId, channelId, this.connectionCounter);
return new ConnectionContext(this.processId, channelId, this.connectionCounter, this.sslSession);
}

/**
* Create a new {@link ConnectionContext} by associating the {@code sslSession}.
*
* @param sslSession the SSL session supplier.
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code sslSession}.
*/
public ConnectionContext withSslSession(Supplier<SSLSession> sslSession) {
return new ConnectionContext(this.processId, this.channelId, this.connectionCounter, sslSession);
}

/**
Expand All @@ -118,7 +139,7 @@ public ConnectionContext withChannelId(String channelId) {
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code processId}.
*/
public ConnectionContext withProcessId(int processId) {
return new ConnectionContext(processId, this.channelId, this.connectionCounter);
return new ConnectionContext(processId, this.channelId, this.connectionCounter, this.sslSession);
}

}
19 changes: 18 additions & 1 deletion src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
Expand Down Expand Up @@ -148,7 +149,23 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
this.connection = connection;
this.byteBufAllocator = connection.outbound().alloc();
this.context = new ConnectionContext().withChannelId(connection.channel().toString());

ConnectionContext connectionContext = new ConnectionContext().withChannelId(connection.channel().toString());
SslHandler sslHandler = this.connection.channel().pipeline().get(SslHandler.class);

if (sslHandler == null) {
SSLSessionHandlerAdapter handlerAdapter = this.connection.channel().pipeline().get(SSLSessionHandlerAdapter.class);
if (handlerAdapter != null) {
sslHandler = handlerAdapter.getSslHandler();
}
}

if (sslHandler != null) {
SslHandler toUse = sslHandler;
connectionContext = connectionContext.withSslSession(() -> toUse.engine().getSession());
}

this.context = connectionContext;

AtomicReference<Throwable> receiveError = new AtomicReference<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (negotiating) {
if (this.negotiating) {
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
}
super.channelActive(ctx);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (negotiating) {
if (this.negotiating) {
// If we receive channel inactive before negotiated, then the inbound has closed early.
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
completeHandshakeExceptionally(e);
Expand All @@ -63,7 +63,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (negotiating) {
if (this.negotiating) {
ByteBuf buf = (ByteBuf) msg;
char response = (char) buf.readByte();
try {
Expand All @@ -79,7 +79,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
} finally {
buf.release();
negotiating = false;
this.negotiating = false;
}
} else {
super.channelRead(ctx, msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,21 @@ void exchangeSslWithClientCertNoCert() {
.expectError(R2dbcPermissionDeniedException.class));
}

@Test
void exchangeSslWitScram() {
client(
c -> c
.sslRootCert(SERVER.getServerCrt())
.username("test-ssl-scram")
.password("test-ssl-scram"),
c -> c.map(client -> client.createStatement("SELECT 10")
.execute()
.flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
.as(StepVerifier::create)
.expectNext(10)
.verifyComplete()));
}

@Test
void exchangeSslWithPassword() {
client(
Expand Down
1 change: 1 addition & 0 deletions src/test/resources/pg_hba.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
hostnossl all test all md5
hostnossl all test-scram all scram-sha-256
hostssl all test-ssl all password
hostssl all test-ssl-scram all scram-sha-256
hostssl all test-ssl-with-cert all cert
local all all md5

0 comments on commit 83d39c8

Please sign in to comment.