Skip to content

Commit 669e4b1

Browse files
committed
从join on 关联的条件中获取字段信息
1 parent ce71b01 commit 669e4b1

File tree

1 file changed

+67
-25
lines changed

1 file changed

+67
-25
lines changed

core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ private Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
175175
}
176176
break;
177177
case JOIN:
178-
return dealJoinNode((SqlJoin) sqlNode, sideTableSet, queueInfo, parentWhere, parentSelectList);
178+
Set<Tuple2<String, String>> joinFieldSet = Sets.newHashSet();
179+
return dealJoinNode((SqlJoin) sqlNode, sideTableSet, queueInfo, parentWhere, parentSelectList, joinFieldSet);
179180
case AS:
180181
SqlNode info = ((SqlBasicCall)sqlNode).getOperands()[0];
181182
SqlNode alias = ((SqlBasicCall) sqlNode).getOperands()[1];
@@ -248,7 +249,7 @@ private SqlBasicCall buildAsSqlNode(String internalTableName, SqlNode newSource)
248249
* @return
249250
*/
250251
private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<Object> queueInfo,
251-
SqlNode parentWhere, SqlNodeList parentSelectList) {
252+
SqlNode parentWhere, SqlNodeList parentSelectList, Set<Tuple2<String, String>> joinFieldSet) {
252253
SqlNode leftNode = joinNode.getLeft();
253254
SqlNode rightNode = joinNode.getRight();
254255
JoinType joinType = joinNode.getJoinType();
@@ -261,12 +262,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
261262

262263
//如果是连续join 判断是否已经处理过添加到执行队列
263264
Boolean alreadyOffer = false;
265+
extractJoinField(joinNode.getCondition(), joinFieldSet);
264266

265267
if(leftNode.getKind() == IDENTIFIER){
266268
leftTbName = leftNode.toString();
267269
} else if (leftNode.getKind() == JOIN) {
268270
//处理连续join
269-
Tuple2<Boolean, SqlBasicCall> nestJoinResult = dealNestJoin((SqlJoin) leftNode, sideTableSet, queueInfo, parentWhere, parentSelectList);
271+
Tuple2<Boolean, SqlBasicCall> nestJoinResult = dealNestJoin((SqlJoin) leftNode, sideTableSet,
272+
queueInfo, parentWhere, parentSelectList, joinFieldSet);
270273
alreadyOffer = nestJoinResult.f0;
271274
leftTbName = nestJoinResult.f1.getOperands()[0].toString();
272275
leftTbAlias = nestJoinResult.f1.getOperands()[1].toString();
@@ -320,7 +323,8 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
320323
}
321324

322325
if(tableInfo.getLeftNode().getKind() != AS){
323-
extractTemporaryQuery(tableInfo.getLeftNode(), tableInfo.getLeftTableAlias(), (SqlBasicCall) parentWhere, parentSelectList, queueInfo);
326+
extractTemporaryQuery(tableInfo.getLeftNode(), tableInfo.getLeftTableAlias(), (SqlBasicCall) parentWhere,
327+
parentSelectList, queueInfo, joinFieldSet);
324328
}else {
325329
SqlKind asNodeFirstKind = ((SqlBasicCall)tableInfo.getLeftNode()).operands[0].getKind();
326330
if(asNodeFirstKind == SELECT){
@@ -331,11 +335,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
331335
return tableInfo;
332336
}
333337

338+
334339
//构建新的查询
335-
private Tuple2<Boolean, SqlBasicCall> dealNestJoin(SqlJoin joinNode, Set<String> sideTableSet, Queue<Object> queueInfo, SqlNode parentWhere, SqlNodeList selectList){
340+
private Tuple2<Boolean, SqlBasicCall> dealNestJoin(SqlJoin joinNode, Set<String> sideTableSet,
341+
Queue<Object> queueInfo, SqlNode parentWhere,
342+
SqlNodeList selectList, Set<Tuple2<String, String>> joinFieldSet){
336343
SqlNode rightNode = joinNode.getRight();
337344
Tuple2<String, String> rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, selectList);
338-
JoinInfo joinInfo = dealJoinNode(joinNode, sideTableSet, queueInfo, parentWhere, selectList);
345+
JoinInfo joinInfo = dealJoinNode(joinNode, sideTableSet, queueInfo, parentWhere, selectList, joinFieldSet);
339346

340347
String rightTableName = rightTableNameAndAlias.f0;
341348
boolean rightIsSide = checkIsSideTable(rightTableName, sideTableSet);
@@ -352,23 +359,23 @@ private Tuple2<Boolean, SqlBasicCall> dealNestJoin(SqlJoin joinNode, Set<String>
352359
return Tuple2.of(alreadyOffer, TableUtils.buildAsNodeByJoinInfo(joinInfo, null, null));
353360
}
354361

355-
public boolean checkAndRemoveCondition(Set<String> fromTableNameSet, SqlBasicCall parentWhere, List<SqlBasicCall> extractContition){
362+
public boolean checkAndRemoveCondition(Set<String> fromTableNameSet, SqlBasicCall parentWhere, List<SqlBasicCall> extractCondition){
356363

357364
if(parentWhere == null){
358365
return false;
359366
}
360367

361368
SqlKind kind = parentWhere.getKind();
362369
if(kind == AND){
363-
boolean removeLeft = checkAndRemoveCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[0], extractContition);
364-
boolean removeRight = checkAndRemoveCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[1], extractContition);
370+
boolean removeLeft = checkAndRemoveCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[0], extractCondition);
371+
boolean removeRight = checkAndRemoveCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[1], extractCondition);
365372
//DO remove
366373
if(removeLeft){
367-
extractContition.add(removeWhereConditionNode(parentWhere, 0));
374+
extractCondition.add(removeWhereConditionNode(parentWhere, 0));
368375
}
369376

370377
if(removeRight){
371-
extractContition.add(removeWhereConditionNode(parentWhere, 1));
378+
extractCondition.add(removeWhereConditionNode(parentWhere, 1));
372379
}
373380

374381
return false;
@@ -385,7 +392,8 @@ public boolean checkAndRemoveCondition(Set<String> fromTableNameSet, SqlBasicCal
385392
}
386393

387394
private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall parentWhere,
388-
SqlNodeList parentSelectList, Queue<Object> queueInfo){
395+
SqlNodeList parentSelectList, Queue<Object> queueInfo,
396+
Set<Tuple2<String, String>> joinFieldSet){
389397
try{
390398
//父一级的where 条件中如果只和临时查询相关的条件都截取进来
391399
Set<String> fromTableNameSet = Sets.newHashSet();
@@ -394,8 +402,9 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
394402
getFromTableInfo(node, fromTableNameSet);
395403
checkAndRemoveCondition(fromTableNameSet, parentWhere, extractCondition);
396404

397-
List<String> extractSelectField = extractSelectList(parentSelectList, fromTableNameSet);
398-
String extractSelectFieldStr = buildSelectNode(extractSelectField);
405+
Set<String> extractSelectField = extractSelectFields(parentSelectList, fromTableNameSet);
406+
Set<String> fieldFromJoinCondition = extractSelectFieldFromJoinCondition(joinFieldSet, fromTableNameSet);
407+
String extractSelectFieldStr = buildSelectNode(extractSelectField, fieldFromJoinCondition);
399408
String extractConditionStr = buildCondition(extractCondition);
400409

401410
String tmpSelectSql = String.format(SELECT_TEMP_SQL,
@@ -425,19 +434,50 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
425434
* @param fromTableNameSet
426435
* @return
427436
*/
428-
private List<String> extractSelectList(SqlNodeList parentSelectList, Set<String> fromTableNameSet){
429-
List<String> extractFieldList = Lists.newArrayList();
437+
private Set<String> extractSelectFields(SqlNodeList parentSelectList, Set<String> fromTableNameSet){
438+
Set<String> extractFieldList = Sets.newHashSet();
430439
for(SqlNode selectNode : parentSelectList.getList()){
431440
extractSelectField(selectNode, extractFieldList, fromTableNameSet);
432441
}
433442

434443
return extractFieldList;
435444
}
436445

437-
private void extractSelectField(SqlNode selectNode, List<String> extractFieldList, Set<String> fromTableNameSet){
446+
private Set<String> extractSelectFieldFromJoinCondition(Set<Tuple2<String, String>> joinFieldSet, Set<String> fromTableNameSet){
447+
Set<String> extractFieldList = Sets.newHashSet();
448+
for(Tuple2<String, String> field : joinFieldSet){
449+
if(fromTableNameSet.contains(field.f0)){
450+
extractFieldList.add(field.f0 + "." + field.f1);
451+
}
452+
}
453+
454+
return extractFieldList;
455+
}
456+
457+
/**
458+
* 从join的条件中获取字段信息
459+
* @param condition
460+
* @param joinFieldSet
461+
*/
462+
private void extractJoinField(SqlNode condition, Set<Tuple2<String, String>> joinFieldSet){
463+
SqlKind joinKind = condition.getKind();
464+
if( joinKind == AND ){
465+
extractJoinField(((SqlBasicCall)condition).operands[0], joinFieldSet);
466+
extractJoinField(((SqlBasicCall)condition).operands[1], joinFieldSet);
467+
}else if( joinKind == EQUALS ){
468+
extractJoinField(((SqlBasicCall)condition).operands[0], joinFieldSet);
469+
extractJoinField(((SqlBasicCall)condition).operands[1], joinFieldSet);
470+
}else{
471+
Preconditions.checkState(((SqlIdentifier)condition).names.size() == 2, "join condition must be format table.field");
472+
Tuple2<String, String> tuple2 = Tuple2.of(((SqlIdentifier)condition).names.get(0), ((SqlIdentifier)condition).names.get(1));
473+
joinFieldSet.add(tuple2);
474+
}
475+
}
476+
477+
private void extractSelectField(SqlNode selectNode, Set<String> extractFieldSet, Set<String> fromTableNameSet){
438478
if (selectNode.getKind() == AS) {
439479
SqlNode leftNode = ((SqlBasicCall) selectNode).getOperands()[0];
440-
extractSelectField(leftNode, extractFieldList, fromTableNameSet);
480+
extractSelectField(leftNode, extractFieldSet, fromTableNameSet);
441481

442482
}else if(selectNode.getKind() == IDENTIFIER) {
443483
SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode;
@@ -448,7 +488,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
448488

449489
String tableName = sqlIdentifier.names.get(0);
450490
if(fromTableNameSet.contains(tableName)){
451-
extractFieldList.add(sqlIdentifier.toString());
491+
extractFieldSet.add(sqlIdentifier.toString());
452492
}
453493

454494
}else if( AGGREGATE.contains(selectNode.getKind())
@@ -493,7 +533,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
493533
continue;
494534
}
495535

496-
extractSelectField(sqlNode, extractFieldList, fromTableNameSet);
536+
extractSelectField(sqlNode, extractFieldSet, fromTableNameSet);
497537
}
498538

499539
}else if(selectNode.getKind() == CASE){
@@ -505,15 +545,15 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
505545

506546
for(int i=0; i<whenOperands.size(); i++){
507547
SqlNode oneOperand = whenOperands.get(i);
508-
extractSelectField(oneOperand, extractFieldList, fromTableNameSet);
548+
extractSelectField(oneOperand, extractFieldSet, fromTableNameSet);
509549
}
510550

511551
for(int i=0; i<thenOperands.size(); i++){
512552
SqlNode oneOperand = thenOperands.get(i);
513-
extractSelectField(oneOperand, extractFieldList, fromTableNameSet);
553+
extractSelectField(oneOperand, extractFieldSet, fromTableNameSet);
514554
}
515555

516-
extractSelectField(elseNode, extractFieldList, fromTableNameSet);
556+
extractSelectField(elseNode, extractFieldSet, fromTableNameSet);
517557
}else {
518558
//do nothing
519559
}
@@ -566,12 +606,14 @@ public String buildCondition(List<SqlBasicCall> conditionList){
566606
return " where " + StringUtils.join(conditionList, " AND ");
567607
}
568608

569-
public String buildSelectNode(List<String> extractSelectField){
609+
public String buildSelectNode(Set<String> extractSelectField, Set<String> joinFieldSet){
570610
if(CollectionUtils.isEmpty(extractSelectField)){
571611
throw new RuntimeException("no field is used");
572612
}
573613

574-
return StringUtils.join(extractSelectField, ",");
614+
Sets.SetView view = Sets.union(extractSelectField, joinFieldSet);
615+
616+
return StringUtils.join(view, ",");
575617
}
576618

577619
public SqlBasicCall buildDefaultCondition(){

0 commit comments

Comments
 (0)