@@ -95,9 +95,8 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro
95
95
// 3) an index into selects
96
96
// 4) a simple non-aggregate expression
97
97
groupings := make ([]sql.Expression , 0 )
98
- if fromScope .groupBy == nil {
99
- fromScope .initGroupBy ()
100
- }
98
+ fromScope .initGroupBy ()
99
+
101
100
g := fromScope .groupBy
102
101
for _ , e := range groupby {
103
102
var col scopeColumn
@@ -194,9 +193,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
194
193
// - grouping cols projection
195
194
// - aggregate expressions
196
195
// - output projection
197
- if fromScope .groupBy == nil {
198
- fromScope .initGroupBy ()
199
- }
196
+ fromScope .initGroupBy ()
200
197
201
198
group := fromScope .groupBy
202
199
outScope := group .outScope
@@ -257,7 +254,10 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
257
254
return outScope
258
255
}
259
256
260
- func isAggregateFunc (name string ) bool {
257
+ // IsAggregateFunc is a hacky "extension point" to allow for other dialects to declare additional aggregate functions
258
+ var IsAggregateFunc = IsMySQLAggregateFuncName
259
+
260
+ func IsMySQLAggregateFuncName (name string ) bool {
261
261
switch name {
262
262
case "avg" , "bit_and" , "bit_or" , "bit_xor" , "count" ,
263
263
"group_concat" , "json_arrayagg" , "json_objectagg" ,
@@ -278,111 +278,63 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
278
278
b .handleErr (err )
279
279
}
280
280
281
- if inScope .groupBy == nil {
282
- inScope .initGroupBy ()
283
- }
281
+ inScope .initGroupBy ()
284
282
gb := inScope .groupBy
285
283
286
284
if strings .EqualFold (name , "count" ) {
287
285
if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
288
- var agg sql.Aggregation
289
- if e .Distinct {
290
- agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
291
- } else {
292
- agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
293
- }
294
- b .qFlags .Set (sql .QFlagCountStar )
295
- aggName := strings .ToLower (agg .String ())
296
- gf := gb .getAggRef (aggName )
297
- if gf != nil {
298
- // if we've already computed use reference here
299
- return gf
300
- }
301
-
302
- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
303
- id := gb .outScope .newColumn (col )
304
- col .id = id
305
-
306
- agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
307
- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
308
- col .scalar = agg
309
-
310
- gb .addAggStr (col )
311
- return col .scalarGf ()
286
+ return b .buildCountStarAggregate (e , gb )
312
287
}
313
288
}
314
289
315
290
if strings .EqualFold (name , "jsonarray" ) {
316
291
// TODO we don't have any tests for this
317
292
if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
318
- var agg sql.Aggregation
319
- agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
320
- b .qFlags .Set (sql .QFlagStar )
321
-
322
- //if e.Distinct {
323
- // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
324
- //}
325
- aggName := strings .ToLower (agg .String ())
326
- gf := gb .getAggRef (aggName )
327
- if gf != nil {
328
- // if we've already computed use reference here
329
- return gf
330
- }
331
-
332
- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
333
- id := gb .outScope .newColumn (col )
334
-
335
- agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
336
- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
337
- col .scalar = agg
338
-
339
- col .id = id
340
- gb .addAggStr (col )
341
- return col .scalarGf ()
293
+ return b .buildJsonArrayStarAggregate (gb )
342
294
}
343
295
}
344
296
345
297
if strings .EqualFold (name , "any_value" ) {
346
298
b .qFlags .Set (sql .QFlagAnyAgg )
347
299
}
348
300
349
- var args []sql.Expression
350
- for _ , arg := range e .Exprs {
351
- e := b .selectExprToExpression (inScope , arg )
352
- switch e := e .(type ) {
353
- case * expression.GetField :
354
- if e .TableId () == 0 {
355
- // TODO: not sure where this came from but it's not true
356
- // aliases are not valid aggregate arguments, the alias must be masking a column
357
- gf := b .selectExprToExpression (inScope .parent , arg )
358
- var ok bool
359
- e , ok = gf .(* expression.GetField )
360
- if ! ok || e .TableId () == 0 {
361
- b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
362
- }
363
- }
364
- args = append (args , e )
365
- col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
366
- gb .addInCol (col )
367
- case * expression.Star :
368
- err := sql .ErrStarUnsupported .New ()
369
- b .handleErr (err )
370
- case * plan.Subquery :
371
- args = append (args , e )
372
- col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
373
- gb .addInCol (col )
374
- default :
375
- args = append (args , e )
376
- col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
377
- gb .addInCol (col )
378
- }
301
+ args := b .buildAggFunctionArgs (inScope , e , gb )
302
+ agg := b .newAggregation (e , name , args )
303
+
304
+ if name == "count" {
305
+ b .qFlags .Set (sql .QFlagCount )
379
306
}
380
307
308
+ aggType := agg .Type ()
309
+ if name == "avg" || name == "sum" {
310
+ aggType = types .Float64
311
+ }
312
+
313
+ aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
314
+ if id , ok := gb .outScope .getExpr (aggName , true ); ok {
315
+ // if we've already computed use reference here
316
+ gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
317
+ return gf
318
+ }
319
+
320
+ col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
321
+ id := gb .outScope .newColumn (col )
322
+
323
+ agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
324
+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
325
+ col .scalar = agg
326
+
327
+ col .id = id
328
+ gb .addAggStr (col )
329
+ return col .scalarGf ()
330
+ }
331
+
332
+ // newAggregation creates a new aggregation function instanc from the arguments given
333
+ func (b * Builder ) newAggregation (e * ast.FuncExpr , name string , args []sql.Expression ) sql.Aggregation {
381
334
var agg sql.Aggregation
382
335
if e .Distinct && name == "count" {
383
336
agg = aggregation .NewCountDistinct (args ... )
384
337
} else {
385
-
386
338
// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
387
339
// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
388
340
if e .Distinct {
@@ -412,39 +364,104 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
412
364
b .handleErr (err )
413
365
}
414
366
}
367
+ return agg
368
+ }
415
369
416
- if name == "count" {
417
- b .qFlags .Set (sql .QFlagCount )
370
+ // buildAggFunctionArgs builds the arguments for an aggregate function
371
+ func (b * Builder ) buildAggFunctionArgs (inScope * scope , e * ast.FuncExpr , gb * groupBy ) []sql.Expression {
372
+ var args []sql.Expression
373
+ for _ , arg := range e .Exprs {
374
+ e := b .selectExprToExpression (inScope , arg )
375
+ switch e := e .(type ) {
376
+ case * expression.GetField :
377
+ if e .TableId () == 0 {
378
+ // TODO: not sure where this came from but it's not true
379
+ // aliases are not valid aggregate arguments, the alias must be masking a column
380
+ gf := b .selectExprToExpression (inScope .parent , arg )
381
+ var ok bool
382
+ e , ok = gf .(* expression.GetField )
383
+ if ! ok || e .TableId () == 0 {
384
+ b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
385
+ }
386
+ }
387
+ args = append (args , e )
388
+ col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
389
+ gb .addInCol (col )
390
+ case * expression.Star :
391
+ err := sql .ErrStarUnsupported .New ()
392
+ b .handleErr (err )
393
+ case * plan.Subquery :
394
+ args = append (args , e )
395
+ col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
396
+ gb .addInCol (col )
397
+ default :
398
+ args = append (args , e )
399
+ col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
400
+ gb .addInCol (col )
401
+ }
418
402
}
403
+ return args
404
+ }
419
405
420
- aggType := agg .Type ()
421
- if name == "avg" || name == "sum" {
422
- aggType = types .Float64
406
+ // buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
407
+ func (b * Builder ) buildJsonArrayStarAggregate (gb * groupBy ) sql.Expression {
408
+ var agg sql.Aggregation
409
+ agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
410
+ b .qFlags .Set (sql .QFlagStar )
411
+
412
+ // if e.Distinct {
413
+ // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
414
+ // }
415
+ aggName := strings .ToLower (agg .String ())
416
+ gf := gb .getAggRef (aggName )
417
+ if gf != nil {
418
+ // if we've already computed use reference here
419
+ return gf
423
420
}
424
421
425
- aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
426
- if id , ok := gb .outScope .getExpr (aggName , true ); ok {
422
+ col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
423
+ id := gb .outScope .newColumn (col )
424
+
425
+ agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
426
+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
427
+ col .scalar = agg
428
+
429
+ col .id = id
430
+ gb .addAggStr (col )
431
+ return col .scalarGf ()
432
+ }
433
+
434
+ // buildCountStarAggregate builds a COUNT(*) aggregate function
435
+ func (b * Builder ) buildCountStarAggregate (e * ast.FuncExpr , gb * groupBy ) sql.Expression {
436
+ var agg sql.Aggregation
437
+ if e .Distinct {
438
+ agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
439
+ } else {
440
+ agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
441
+ }
442
+ b .qFlags .Set (sql .QFlagCountStar )
443
+ aggName := strings .ToLower (agg .String ())
444
+ gf := gb .getAggRef (aggName )
445
+ if gf != nil {
427
446
// if we've already computed use reference here
428
- gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
429
447
return gf
430
448
}
431
449
432
- col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
450
+ col := scopeColumn {col : strings . ToLower ( agg . String ()) , scalar : agg , typ : agg . Type () , nullable : agg .IsNullable ()}
433
451
id := gb .outScope .newColumn (col )
452
+ col .id = id
434
453
435
454
agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
436
455
gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
437
456
col .scalar = agg
438
457
439
- col .id = id
440
458
gb .addAggStr (col )
441
459
return col .scalarGf ()
442
460
}
443
461
462
+ // buildGroupConcat builds a GROUP_CONCAT aggregate function
444
463
func (b * Builder ) buildGroupConcat (inScope * scope , e * ast.GroupConcatExpr ) sql.Expression {
445
- if inScope .groupBy == nil {
446
- inScope .initGroupBy ()
447
- }
464
+ inScope .initGroupBy ()
448
465
gb := inScope .groupBy
449
466
450
467
args := make ([]sql.Expression , len (e .Exprs ))
@@ -794,7 +811,7 @@ func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where)
794
811
return false , nil
795
812
case * ast.FuncExpr :
796
813
name := n .Name .Lowered ()
797
- if isAggregateFunc (name ) {
814
+ if IsAggregateFunc (name ) {
798
815
// record aggregate
799
816
// TODO: this should get projScope as well
800
817
_ = b .buildAggregateFunc (fromScope , name , n )
@@ -874,9 +891,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
874
891
if having == nil {
875
892
return
876
893
}
877
- if fromScope .groupBy == nil {
878
- fromScope .initGroupBy ()
879
- }
894
+ fromScope .initGroupBy ()
880
895
881
896
havingScope := b .newScope ()
882
897
if fromScope .parent != nil {
0 commit comments