@@ -175,7 +175,8 @@ private Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
175
175
}
176
176
break ;
177
177
case JOIN :
178
- return dealJoinNode ((SqlJoin ) sqlNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
178
+ Set <Tuple2 <String , String >> joinFieldSet = Sets .newHashSet ();
179
+ return dealJoinNode ((SqlJoin ) sqlNode , sideTableSet , queueInfo , parentWhere , parentSelectList , joinFieldSet );
179
180
case AS :
180
181
SqlNode info = ((SqlBasicCall )sqlNode ).getOperands ()[0 ];
181
182
SqlNode alias = ((SqlBasicCall ) sqlNode ).getOperands ()[1 ];
@@ -248,7 +249,7 @@ private SqlBasicCall buildAsSqlNode(String internalTableName, SqlNode newSource)
248
249
* @return
249
250
*/
250
251
private JoinInfo dealJoinNode (SqlJoin joinNode , Set <String > sideTableSet , Queue <Object > queueInfo ,
251
- SqlNode parentWhere , SqlNodeList parentSelectList ) {
252
+ SqlNode parentWhere , SqlNodeList parentSelectList , Set < Tuple2 < String , String >> joinFieldSet ) {
252
253
SqlNode leftNode = joinNode .getLeft ();
253
254
SqlNode rightNode = joinNode .getRight ();
254
255
JoinType joinType = joinNode .getJoinType ();
@@ -261,12 +262,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
261
262
262
263
//如果是连续join 判断是否已经处理过添加到执行队列
263
264
Boolean alreadyOffer = false ;
265
+ extractJoinField (joinNode .getCondition (), joinFieldSet );
264
266
265
267
if (leftNode .getKind () == IDENTIFIER ){
266
268
leftTbName = leftNode .toString ();
267
269
} else if (leftNode .getKind () == JOIN ) {
268
270
//处理连续join
269
- Tuple2 <Boolean , SqlBasicCall > nestJoinResult = dealNestJoin ((SqlJoin ) leftNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
271
+ Tuple2 <Boolean , SqlBasicCall > nestJoinResult = dealNestJoin ((SqlJoin ) leftNode , sideTableSet ,
272
+ queueInfo , parentWhere , parentSelectList , joinFieldSet );
270
273
alreadyOffer = nestJoinResult .f0 ;
271
274
leftTbName = nestJoinResult .f1 .getOperands ()[0 ].toString ();
272
275
leftTbAlias = nestJoinResult .f1 .getOperands ()[1 ].toString ();
@@ -320,7 +323,8 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
320
323
}
321
324
322
325
if (tableInfo .getLeftNode ().getKind () != AS ){
323
- extractTemporaryQuery (tableInfo .getLeftNode (), tableInfo .getLeftTableAlias (), (SqlBasicCall ) parentWhere , parentSelectList , queueInfo );
326
+ extractTemporaryQuery (tableInfo .getLeftNode (), tableInfo .getLeftTableAlias (), (SqlBasicCall ) parentWhere ,
327
+ parentSelectList , queueInfo , joinFieldSet );
324
328
}else {
325
329
SqlKind asNodeFirstKind = ((SqlBasicCall )tableInfo .getLeftNode ()).operands [0 ].getKind ();
326
330
if (asNodeFirstKind == SELECT ){
@@ -331,11 +335,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
331
335
return tableInfo ;
332
336
}
333
337
338
+
334
339
//构建新的查询
335
- private Tuple2 <Boolean , SqlBasicCall > dealNestJoin (SqlJoin joinNode , Set <String > sideTableSet , Queue <Object > queueInfo , SqlNode parentWhere , SqlNodeList selectList ){
340
+ private Tuple2 <Boolean , SqlBasicCall > dealNestJoin (SqlJoin joinNode , Set <String > sideTableSet ,
341
+ Queue <Object > queueInfo , SqlNode parentWhere ,
342
+ SqlNodeList selectList , Set <Tuple2 <String , String >> joinFieldSet ){
336
343
SqlNode rightNode = joinNode .getRight ();
337
344
Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , selectList );
338
- JoinInfo joinInfo = dealJoinNode (joinNode , sideTableSet , queueInfo , parentWhere , selectList );
345
+ JoinInfo joinInfo = dealJoinNode (joinNode , sideTableSet , queueInfo , parentWhere , selectList , joinFieldSet );
339
346
340
347
String rightTableName = rightTableNameAndAlias .f0 ;
341
348
boolean rightIsSide = checkIsSideTable (rightTableName , sideTableSet );
@@ -352,23 +359,23 @@ private Tuple2<Boolean, SqlBasicCall> dealNestJoin(SqlJoin joinNode, Set<String>
352
359
return Tuple2 .of (alreadyOffer , TableUtils .buildAsNodeByJoinInfo (joinInfo , null , null ));
353
360
}
354
361
355
- public boolean checkAndRemoveCondition (Set <String > fromTableNameSet , SqlBasicCall parentWhere , List <SqlBasicCall > extractContition ){
362
+ public boolean checkAndRemoveCondition (Set <String > fromTableNameSet , SqlBasicCall parentWhere , List <SqlBasicCall > extractCondition ){
356
363
357
364
if (parentWhere == null ){
358
365
return false ;
359
366
}
360
367
361
368
SqlKind kind = parentWhere .getKind ();
362
369
if (kind == AND ){
363
- boolean removeLeft = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[0 ], extractContition );
364
- boolean removeRight = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[1 ], extractContition );
370
+ boolean removeLeft = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[0 ], extractCondition );
371
+ boolean removeRight = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[1 ], extractCondition );
365
372
//DO remove
366
373
if (removeLeft ){
367
- extractContition .add (removeWhereConditionNode (parentWhere , 0 ));
374
+ extractCondition .add (removeWhereConditionNode (parentWhere , 0 ));
368
375
}
369
376
370
377
if (removeRight ){
371
- extractContition .add (removeWhereConditionNode (parentWhere , 1 ));
378
+ extractCondition .add (removeWhereConditionNode (parentWhere , 1 ));
372
379
}
373
380
374
381
return false ;
@@ -385,7 +392,8 @@ public boolean checkAndRemoveCondition(Set<String> fromTableNameSet, SqlBasicCal
385
392
}
386
393
387
394
private void extractTemporaryQuery (SqlNode node , String tableAlias , SqlBasicCall parentWhere ,
388
- SqlNodeList parentSelectList , Queue <Object > queueInfo ){
395
+ SqlNodeList parentSelectList , Queue <Object > queueInfo ,
396
+ Set <Tuple2 <String , String >> joinFieldSet ){
389
397
try {
390
398
//父一级的where 条件中如果只和临时查询相关的条件都截取进来
391
399
Set <String > fromTableNameSet = Sets .newHashSet ();
@@ -394,8 +402,9 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
394
402
getFromTableInfo (node , fromTableNameSet );
395
403
checkAndRemoveCondition (fromTableNameSet , parentWhere , extractCondition );
396
404
397
- List <String > extractSelectField = extractSelectList (parentSelectList , fromTableNameSet );
398
- String extractSelectFieldStr = buildSelectNode (extractSelectField );
405
+ Set <String > extractSelectField = extractSelectFields (parentSelectList , fromTableNameSet );
406
+ Set <String > fieldFromJoinCondition = extractSelectFieldFromJoinCondition (joinFieldSet , fromTableNameSet );
407
+ String extractSelectFieldStr = buildSelectNode (extractSelectField , fieldFromJoinCondition );
399
408
String extractConditionStr = buildCondition (extractCondition );
400
409
401
410
String tmpSelectSql = String .format (SELECT_TEMP_SQL ,
@@ -425,19 +434,50 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
425
434
* @param fromTableNameSet
426
435
* @return
427
436
*/
428
- private List <String > extractSelectList (SqlNodeList parentSelectList , Set <String > fromTableNameSet ){
429
- List <String > extractFieldList = Lists . newArrayList ();
437
+ private Set <String > extractSelectFields (SqlNodeList parentSelectList , Set <String > fromTableNameSet ){
438
+ Set <String > extractFieldList = Sets . newHashSet ();
430
439
for (SqlNode selectNode : parentSelectList .getList ()){
431
440
extractSelectField (selectNode , extractFieldList , fromTableNameSet );
432
441
}
433
442
434
443
return extractFieldList ;
435
444
}
436
445
437
- private void extractSelectField (SqlNode selectNode , List <String > extractFieldList , Set <String > fromTableNameSet ){
446
+ private Set <String > extractSelectFieldFromJoinCondition (Set <Tuple2 <String , String >> joinFieldSet , Set <String > fromTableNameSet ){
447
+ Set <String > extractFieldList = Sets .newHashSet ();
448
+ for (Tuple2 <String , String > field : joinFieldSet ){
449
+ if (fromTableNameSet .contains (field .f0 )){
450
+ extractFieldList .add (field .f0 + "." + field .f1 );
451
+ }
452
+ }
453
+
454
+ return extractFieldList ;
455
+ }
456
+
457
+ /**
458
+ * 从join的条件中获取字段信息
459
+ * @param condition
460
+ * @param joinFieldSet
461
+ */
462
+ private void extractJoinField (SqlNode condition , Set <Tuple2 <String , String >> joinFieldSet ){
463
+ SqlKind joinKind = condition .getKind ();
464
+ if ( joinKind == AND ){
465
+ extractJoinField (((SqlBasicCall )condition ).operands [0 ], joinFieldSet );
466
+ extractJoinField (((SqlBasicCall )condition ).operands [1 ], joinFieldSet );
467
+ }else if ( joinKind == EQUALS ){
468
+ extractJoinField (((SqlBasicCall )condition ).operands [0 ], joinFieldSet );
469
+ extractJoinField (((SqlBasicCall )condition ).operands [1 ], joinFieldSet );
470
+ }else {
471
+ Preconditions .checkState (((SqlIdentifier )condition ).names .size () == 2 , "join condition must be format table.field" );
472
+ Tuple2 <String , String > tuple2 = Tuple2 .of (((SqlIdentifier )condition ).names .get (0 ), ((SqlIdentifier )condition ).names .get (1 ));
473
+ joinFieldSet .add (tuple2 );
474
+ }
475
+ }
476
+
477
+ private void extractSelectField (SqlNode selectNode , Set <String > extractFieldSet , Set <String > fromTableNameSet ){
438
478
if (selectNode .getKind () == AS ) {
439
479
SqlNode leftNode = ((SqlBasicCall ) selectNode ).getOperands ()[0 ];
440
- extractSelectField (leftNode , extractFieldList , fromTableNameSet );
480
+ extractSelectField (leftNode , extractFieldSet , fromTableNameSet );
441
481
442
482
}else if (selectNode .getKind () == IDENTIFIER ) {
443
483
SqlIdentifier sqlIdentifier = (SqlIdentifier ) selectNode ;
@@ -448,7 +488,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
448
488
449
489
String tableName = sqlIdentifier .names .get (0 );
450
490
if (fromTableNameSet .contains (tableName )){
451
- extractFieldList .add (sqlIdentifier .toString ());
491
+ extractFieldSet .add (sqlIdentifier .toString ());
452
492
}
453
493
454
494
}else if ( AGGREGATE .contains (selectNode .getKind ())
@@ -493,7 +533,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
493
533
continue ;
494
534
}
495
535
496
- extractSelectField (sqlNode , extractFieldList , fromTableNameSet );
536
+ extractSelectField (sqlNode , extractFieldSet , fromTableNameSet );
497
537
}
498
538
499
539
}else if (selectNode .getKind () == CASE ){
@@ -505,15 +545,15 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
505
545
506
546
for (int i =0 ; i <whenOperands .size (); i ++){
507
547
SqlNode oneOperand = whenOperands .get (i );
508
- extractSelectField (oneOperand , extractFieldList , fromTableNameSet );
548
+ extractSelectField (oneOperand , extractFieldSet , fromTableNameSet );
509
549
}
510
550
511
551
for (int i =0 ; i <thenOperands .size (); i ++){
512
552
SqlNode oneOperand = thenOperands .get (i );
513
- extractSelectField (oneOperand , extractFieldList , fromTableNameSet );
553
+ extractSelectField (oneOperand , extractFieldSet , fromTableNameSet );
514
554
}
515
555
516
- extractSelectField (elseNode , extractFieldList , fromTableNameSet );
556
+ extractSelectField (elseNode , extractFieldSet , fromTableNameSet );
517
557
}else {
518
558
//do nothing
519
559
}
@@ -566,12 +606,14 @@ public String buildCondition(List<SqlBasicCall> conditionList){
566
606
return " where " + StringUtils .join (conditionList , " AND " );
567
607
}
568
608
569
- public String buildSelectNode (List <String > extractSelectField ){
609
+ public String buildSelectNode (Set <String > extractSelectField , Set < String > joinFieldSet ){
570
610
if (CollectionUtils .isEmpty (extractSelectField )){
571
611
throw new RuntimeException ("no field is used" );
572
612
}
573
613
574
- return StringUtils .join (extractSelectField , "," );
614
+ Sets .SetView view = Sets .union (extractSelectField , joinFieldSet );
615
+
616
+ return StringUtils .join (view , "," );
575
617
}
576
618
577
619
public SqlBasicCall buildDefaultCondition (){
0 commit comments