Skip to content

Commit

Permalink
[Fix] Fix SupportsFilterPushDown bug when flinksql unionall (#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaoMiMao authored Aug 30, 2024
1 parent f108c3c commit 8160dd2
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.flink.api.connector.source.SplitEnumeratorContext;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.core.io.SimpleVersionedSerializer;
import org.apache.flink.util.StringUtils;

import org.apache.doris.flink.cfg.DorisOptions;
import org.apache.doris.flink.cfg.DorisReadOptions;
Expand Down Expand Up @@ -64,14 +65,18 @@ public class DorisSource<OUT>
private final Boundedness boundedness;
private final DorisDeserializationSchema<OUT> deserializer;

private final List<String> resolvedFilterQuery;

/**
 * Creates a Doris source from the given options.
 *
 * @param options Doris connection options, stored as given
 * @param readOptions Doris read options, stored as given
 * @param boundedness boundedness value, stored as given
 * @param resolvedFilterQuery filter predicates that are joined with {@code " AND "} and merged
 *     into the read options' filter query when the enumerator is created; {@code null} is
 *     treated as "no filters"
 * @param deserializer deserialization schema, stored as given
 */
public DorisSource(
        DorisOptions options,
        DorisReadOptions readOptions,
        Boundedness boundedness,
        List<String> resolvedFilterQuery,
        DorisDeserializationSchema<OUT> deserializer) {
    this.options = options;
    this.readOptions = readOptions;
    this.boundedness = boundedness;
    // createEnumerator calls isEmpty() on this list unconditionally, so never store null.
    // Copy the list so later changes to the caller's list cannot alter this source.
    this.resolvedFilterQuery =
            resolvedFilterQuery == null
                    ? new ArrayList<>()
                    : new ArrayList<>(resolvedFilterQuery);
    this.deserializer = deserializer;
}

Expand All @@ -95,6 +100,15 @@ public SourceReader<OUT, DorisSourceSplit> createReader(SourceReaderContext read
public SplitEnumerator<DorisSourceSplit, PendingSplitsCheckpoint> createEnumerator(
SplitEnumeratorContext<DorisSourceSplit> context) throws Exception {
List<DorisSourceSplit> dorisSourceSplits = new ArrayList<>();
if (!resolvedFilterQuery.isEmpty()) {
String filterQuery = String.join(" AND ", resolvedFilterQuery);
if (StringUtils.isNullOrWhitespaceOnly(readOptions.getFilterQuery())) {
readOptions.setFilterQuery(filterQuery);
} else {
readOptions.setFilterQuery(
String.join(" AND ", readOptions.getFilterQuery(), filterQuery));
}
}
List<PartitionDefinition> partitions =
RestService.findPartitions(options, readOptions, LOG);
for (int index = 0; index < partitions.size(); index++) {
Expand Down Expand Up @@ -147,6 +161,7 @@ public static class DorisSourceBuilder<OUT> {
// Boundedness
private Boundedness boundedness;
private DorisDeserializationSchema<OUT> deserializer;
private List<String> resolvedFilterQuery = new ArrayList<>();

DorisSourceBuilder() {
boundedness = Boundedness.BOUNDED;
Expand All @@ -173,11 +188,17 @@ public DorisSourceBuilder<OUT> setDeserializer(
return this;
}

/**
 * Sets the filter predicates that are passed to the built source.
 *
 * @param resolvedFilterQuery filter predicates; {@code null} resets the builder to an empty
 *     list
 * @return this builder
 */
public DorisSourceBuilder<OUT> setResolvedFilterQuery(List<String> resolvedFilterQuery) {
    // Keep the field non-null, like its initial value: the source calls isEmpty() on the
    // list when creating the enumerator and would otherwise fail with a NullPointerException.
    this.resolvedFilterQuery =
            resolvedFilterQuery == null ? new ArrayList<>() : resolvedFilterQuery;
    return this;
}

/**
 * Builds a {@link DorisSource} from the values collected by this builder.
 *
 * @return a new source instance
 */
public DorisSource<OUT> build() {
// Fall back to default read options when none were supplied.
if (readOptions == null) {
readOptions = DorisReadOptions.builder().build();
}
// NOTE(review): the four-argument call below looks like the pre-change line of this diff;
// the five-argument call after it is the one matching the constructor that accepts
// resolvedFilterQuery. Confirm only the latter exists in the actual file.
return new DorisSource<>(options, readOptions, boundedness, deserializer);
return new DorisSource<>(
options, readOptions, boundedness, resolvedFilterQuery, deserializer);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ public ChangelogMode getChangelogMode() {

@Override
public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderContext) {
if (StringUtils.isNullOrWhitespaceOnly(readOptions.getFilterQuery())) {
String filterQuery = resolvedFilterQuery.stream().collect(Collectors.joining(" AND "));
readOptions.setFilterQuery(filterQuery);
}
if (StringUtils.isNullOrWhitespaceOnly(readOptions.getReadFields())) {
String[] selectFields =
DataType.getFieldNames(physicalRowDataType).toArray(new String[0]);
Expand Down Expand Up @@ -127,6 +123,7 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext runtimeProviderCon
DorisSource.<RowData>builder()
.setDorisReadOptions(readOptions)
.setDorisOptions(options)
.setResolvedFilterQuery(resolvedFilterQuery)
.setDeserializer(
new RowDataDeserializationSchema(
(RowType) physicalRowDataType.getLogicalType()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ public class DorisSourceITCase extends DorisTestBase {
static final String TABLE_READ_TBL_OLD_API = "tbl_read_tbl_old_api";
static final String TABLE_READ_TBL_ALL_OPTIONS = "tbl_read_tbl_all_options";
static final String TABLE_READ_TBL_PUSH_DOWN = "tbl_read_tbl_push_down";
static final String TABLE_READ_TBL_PUSH_DOWN_WITH_UNION_ALL =
"tbl_read_tbl_push_down_with_union_all";

@Test
public void testSource() throws Exception {
Expand Down Expand Up @@ -77,7 +79,7 @@ public void testSource() throws Exception {
actual.add(iterator.next().toString());
}
}
List<String> expected = Arrays.asList("[doris, 18]", "[flink, 10]");
List<String> expected = Arrays.asList("[doris, 18]", "[flink, 10]", "[apache, 12]");
Assert.assertArrayEquals(actual.toArray(), expected.toArray());
}

Expand All @@ -102,7 +104,7 @@ options, new SimpleListDeserializationSchema()))
actual.add(iterator.next().toString());
}
}
List<String> expected = Arrays.asList("[doris, 18]", "[flink, 10]");
List<String> expected = Arrays.asList("[doris, 18]", "[flink, 10]", "[apache, 12]");
Assert.assertArrayEquals(actual.toArray(), expected.toArray());
}

Expand Down Expand Up @@ -136,7 +138,7 @@ public void testTableSource() throws Exception {
actual.add(iterator.next().toString());
}
}
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]"};
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]", "+I[apache, 12]"};
Assert.assertArrayEquals(expected, actual.toArray());

// filter query
Expand Down Expand Up @@ -182,7 +184,7 @@ public void testTableSourceOldApi() throws Exception {
actual.add(iterator.next().toString());
}
}
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]"};
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]", "+I[apache, 12]"};
Assert.assertArrayEquals(expected, actual.toArray());
}

Expand Down Expand Up @@ -228,7 +230,7 @@ public void testTableSourceAllOptions() throws Exception {
actual.add(iterator.next().toString());
}
}
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]"};
String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]", "+I[apache, 12]"};
Assert.assertArrayEquals(expected, actual.toArray());
}

Expand All @@ -242,6 +244,7 @@ public void testTableSourceFilterAndProjectionPushDown() throws Exception {
String sourceDDL =
String.format(
"CREATE TABLE doris_source ("
+ " name STRING,"
+ " age INT"
+ ") WITH ("
+ " 'connector' = 'doris',"
Expand All @@ -267,6 +270,46 @@ public void testTableSourceFilterAndProjectionPushDown() throws Exception {
Assert.assertArrayEquals(expected, actual.toArray());
}

/**
 * Reads the same Doris table twice with different {@code age} filters and combines the two
 * scans with {@code UNION ALL}, checking that each branch returns only its own matching row.
 *
 * <p>The result rows are compared regardless of order, because {@code UNION ALL} gives no
 * ordering guarantee between the two scans.
 */
@Test
public void testTableSourceFilterWithUnionAll() throws Exception {
    initializeTable(TABLE_READ_TBL_PUSH_DOWN_WITH_UNION_ALL);
    final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(1);
    final StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

    String sourceDDL =
            String.format(
                    "CREATE TABLE doris_source ("
                            + " name STRING,"
                            + " age INT"
                            + ") WITH ("
                            + " 'connector' = 'doris',"
                            + " 'fenodes' = '%s',"
                            + " 'table.identifier' = '%s',"
                            + " 'username' = '%s',"
                            + " 'password' = '%s'"
                            + ")",
                    getFenodes(),
                    DATABASE + "." + TABLE_READ_TBL_PUSH_DOWN_WITH_UNION_ALL,
                    USERNAME,
                    PASSWORD);
    tEnv.executeSql(sourceDDL);
    TableResult tableResult =
            tEnv.executeSql(
                    " SELECT * FROM doris_source where age = '18'"
                            + " UNION ALL "
                            + "SELECT * FROM doris_source where age = '10' ");

    List<String> actual = new ArrayList<>();
    try (CloseableIterator<Row> iterator = tableResult.collect()) {
        while (iterator.hasNext()) {
            actual.add(iterator.next().toString());
        }
    }
    String[] expected = new String[] {"+I[doris, 18]", "+I[flink, 10]"};

    // Sort both sides so the assertion does not depend on which scan emits its row first.
    String[] actualRows = actual.toArray(new String[0]);
    Arrays.sort(actualRows);
    Arrays.sort(expected);
    Assert.assertArrayEquals(expected, actualRows);
}

private void initializeTable(String table) throws Exception {
try (Connection connection =
DriverManager.getConnection(
Expand All @@ -288,6 +331,8 @@ private void initializeTable(String table) throws Exception {
String.format("insert into %s.%s values ('doris',18)", DATABASE, table));
statement.execute(
String.format("insert into %s.%s values ('flink',10)", DATABASE, table));
statement.execute(
String.format("insert into %s.%s values ('apache',12)", DATABASE, table));
}
}
}

0 comments on commit 8160dd2

Please sign in to comment.