Skip to content

Commit 65f2e3b

Browse files
authored
Merge pull request #2992 from dolthub/zachmu/aggs
Hacky extension point for aggregate function determination
2 parents 6e3c190 + 78556f7 commit 65f2e3b

File tree

4 files changed

+126
-109
lines changed

4 files changed

+126
-109
lines changed

sql/planbuilder/aggregates.go

Lines changed: 121 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro
9595
// 3) an index into selects
9696
// 4) a simple non-aggregate expression
9797
groupings := make([]sql.Expression, 0)
98-
if fromScope.groupBy == nil {
99-
fromScope.initGroupBy()
100-
}
98+
fromScope.initGroupBy()
99+
101100
g := fromScope.groupBy
102101
for _, e := range groupby {
103102
var col scopeColumn
@@ -194,9 +193,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
194193
// - grouping cols projection
195194
// - aggregate expressions
196195
// - output projection
197-
if fromScope.groupBy == nil {
198-
fromScope.initGroupBy()
199-
}
196+
fromScope.initGroupBy()
200197

201198
group := fromScope.groupBy
202199
outScope := group.outScope
@@ -257,7 +254,10 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
257254
return outScope
258255
}
259256

260-
func isAggregateFunc(name string) bool {
257+
// IsAggregateFunc is a hacky "extension point" to allow for other dialects to declare additional aggregate functions
258+
var IsAggregateFunc = IsMySQLAggregateFuncName
259+
260+
func IsMySQLAggregateFuncName(name string) bool {
261261
switch name {
262262
case "avg", "bit_and", "bit_or", "bit_xor", "count",
263263
"group_concat", "json_arrayagg", "json_objectagg",
@@ -278,111 +278,63 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
278278
b.handleErr(err)
279279
}
280280

281-
if inScope.groupBy == nil {
282-
inScope.initGroupBy()
283-
}
281+
inScope.initGroupBy()
284282
gb := inScope.groupBy
285283

286284
if strings.EqualFold(name, "count") {
287285
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
288-
var agg sql.Aggregation
289-
if e.Distinct {
290-
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
291-
} else {
292-
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
293-
}
294-
b.qFlags.Set(sql.QFlagCountStar)
295-
aggName := strings.ToLower(agg.String())
296-
gf := gb.getAggRef(aggName)
297-
if gf != nil {
298-
// if we've already computed use reference here
299-
return gf
300-
}
301-
302-
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
303-
id := gb.outScope.newColumn(col)
304-
col.id = id
305-
306-
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
307-
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
308-
col.scalar = agg
309-
310-
gb.addAggStr(col)
311-
return col.scalarGf()
286+
return b.buildCountStarAggregate(e, gb)
312287
}
313288
}
314289

315290
if strings.EqualFold(name, "jsonarray") {
316291
// TODO we don't have any tests for this
317292
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
318-
var agg sql.Aggregation
319-
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
320-
b.qFlags.Set(sql.QFlagStar)
321-
322-
//if e.Distinct {
323-
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
324-
//}
325-
aggName := strings.ToLower(agg.String())
326-
gf := gb.getAggRef(aggName)
327-
if gf != nil {
328-
// if we've already computed use reference here
329-
return gf
330-
}
331-
332-
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
333-
id := gb.outScope.newColumn(col)
334-
335-
agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
336-
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
337-
col.scalar = agg
338-
339-
col.id = id
340-
gb.addAggStr(col)
341-
return col.scalarGf()
293+
return b.buildJsonArrayStarAggregate(gb)
342294
}
343295
}
344296

345297
if strings.EqualFold(name, "any_value") {
346298
b.qFlags.Set(sql.QFlagAnyAgg)
347299
}
348300

349-
var args []sql.Expression
350-
for _, arg := range e.Exprs {
351-
e := b.selectExprToExpression(inScope, arg)
352-
switch e := e.(type) {
353-
case *expression.GetField:
354-
if e.TableId() == 0 {
355-
// TODO: not sure where this came from but it's not true
356-
// aliases are not valid aggregate arguments, the alias must be masking a column
357-
gf := b.selectExprToExpression(inScope.parent, arg)
358-
var ok bool
359-
e, ok = gf.(*expression.GetField)
360-
if !ok || e.TableId() == 0 {
361-
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
362-
}
363-
}
364-
args = append(args, e)
365-
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
366-
gb.addInCol(col)
367-
case *expression.Star:
368-
err := sql.ErrStarUnsupported.New()
369-
b.handleErr(err)
370-
case *plan.Subquery:
371-
args = append(args, e)
372-
col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
373-
gb.addInCol(col)
374-
default:
375-
args = append(args, e)
376-
col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
377-
gb.addInCol(col)
378-
}
301+
args := b.buildAggFunctionArgs(inScope, e, gb)
302+
agg := b.newAggregation(e, name, args)
303+
304+
if name == "count" {
305+
b.qFlags.Set(sql.QFlagCount)
379306
}
380307

308+
aggType := agg.Type()
309+
if name == "avg" || name == "sum" {
310+
aggType = types.Float64
311+
}
312+
313+
aggName := strings.ToLower(plan.AliasSubqueryString(agg))
314+
if id, ok := gb.outScope.getExpr(aggName, true); ok {
315+
// if we've already computed use reference here
316+
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
317+
return gf
318+
}
319+
320+
col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
321+
id := gb.outScope.newColumn(col)
322+
323+
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
324+
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
325+
col.scalar = agg
326+
327+
col.id = id
328+
gb.addAggStr(col)
329+
return col.scalarGf()
330+
}
331+
332+
// newAggregation creates a new aggregation function instance from the arguments given
333+
func (b *Builder) newAggregation(e *ast.FuncExpr, name string, args []sql.Expression) sql.Aggregation {
381334
var agg sql.Aggregation
382335
if e.Distinct && name == "count" {
383336
agg = aggregation.NewCountDistinct(args...)
384337
} else {
385-
386338
// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
387339
// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
388340
if e.Distinct {
@@ -412,39 +364,104 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
412364
b.handleErr(err)
413365
}
414366
}
367+
return agg
368+
}
415369

416-
if name == "count" {
417-
b.qFlags.Set(sql.QFlagCount)
370+
// buildAggFunctionArgs builds the arguments for an aggregate function
371+
func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *groupBy) []sql.Expression {
372+
var args []sql.Expression
373+
for _, arg := range e.Exprs {
374+
e := b.selectExprToExpression(inScope, arg)
375+
switch e := e.(type) {
376+
case *expression.GetField:
377+
if e.TableId() == 0 {
378+
// TODO: not sure where this came from but it's not true
379+
// aliases are not valid aggregate arguments, the alias must be masking a column
380+
gf := b.selectExprToExpression(inScope.parent, arg)
381+
var ok bool
382+
e, ok = gf.(*expression.GetField)
383+
if !ok || e.TableId() == 0 {
384+
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
385+
}
386+
}
387+
args = append(args, e)
388+
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
389+
gb.addInCol(col)
390+
case *expression.Star:
391+
err := sql.ErrStarUnsupported.New()
392+
b.handleErr(err)
393+
case *plan.Subquery:
394+
args = append(args, e)
395+
col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
396+
gb.addInCol(col)
397+
default:
398+
args = append(args, e)
399+
col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
400+
gb.addInCol(col)
401+
}
418402
}
403+
return args
404+
}
419405

420-
aggType := agg.Type()
421-
if name == "avg" || name == "sum" {
422-
aggType = types.Float64
406+
// buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
407+
func (b *Builder) buildJsonArrayStarAggregate(gb *groupBy) sql.Expression {
408+
var agg sql.Aggregation
409+
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
410+
b.qFlags.Set(sql.QFlagStar)
411+
412+
// if e.Distinct {
413+
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
414+
// }
415+
aggName := strings.ToLower(agg.String())
416+
gf := gb.getAggRef(aggName)
417+
if gf != nil {
418+
// if we've already computed use reference here
419+
return gf
423420
}
424421

425-
aggName := strings.ToLower(plan.AliasSubqueryString(agg))
426-
if id, ok := gb.outScope.getExpr(aggName, true); ok {
422+
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
423+
id := gb.outScope.newColumn(col)
424+
425+
agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
426+
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
427+
col.scalar = agg
428+
429+
col.id = id
430+
gb.addAggStr(col)
431+
return col.scalarGf()
432+
}
433+
434+
// buildCountStarAggregate builds a COUNT(*) aggregate function
435+
func (b *Builder) buildCountStarAggregate(e *ast.FuncExpr, gb *groupBy) sql.Expression {
436+
var agg sql.Aggregation
437+
if e.Distinct {
438+
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
439+
} else {
440+
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
441+
}
442+
b.qFlags.Set(sql.QFlagCountStar)
443+
aggName := strings.ToLower(agg.String())
444+
gf := gb.getAggRef(aggName)
445+
if gf != nil {
427446
// if we've already computed use reference here
428-
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
429447
return gf
430448
}
431449

432-
col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
450+
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
433451
id := gb.outScope.newColumn(col)
452+
col.id = id
434453

435454
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
436455
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
437456
col.scalar = agg
438457

439-
col.id = id
440458
gb.addAggStr(col)
441459
return col.scalarGf()
442460
}
443461

462+
// buildGroupConcat builds a GROUP_CONCAT aggregate function
444463
func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression {
445-
if inScope.groupBy == nil {
446-
inScope.initGroupBy()
447-
}
464+
inScope.initGroupBy()
448465
gb := inScope.groupBy
449466

450467
args := make([]sql.Expression, len(e.Exprs))
@@ -794,7 +811,7 @@ func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where)
794811
return false, nil
795812
case *ast.FuncExpr:
796813
name := n.Name.Lowered()
797-
if isAggregateFunc(name) {
814+
if IsAggregateFunc(name) {
798815
// record aggregate
799816
// TODO: this should get projScope as well
800817
_ = b.buildAggregateFunc(fromScope, name, n)
@@ -874,9 +891,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
874891
if having == nil {
875892
return
876893
}
877-
if fromScope.groupBy == nil {
878-
fromScope.initGroupBy()
879-
}
894+
fromScope.initGroupBy()
880895

881896
havingScope := b.newScope()
882897
if fromScope.parent != nil {

sql/planbuilder/scalar.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
153153
return b.buildNameConst(inScope, v)
154154
} else if name == "icu_version" {
155155
return expression.NewLiteral(icuVersion, types.MustCreateString(query.Type_VARCHAR, int64(len(icuVersion)), sql.Collation_Default))
156-
} else if isAggregateFunc(name) && v.Over == nil {
156+
} else if IsAggregateFunc(name) && v.Over == nil {
157157
// TODO this assumes aggregate is in the same scope
158158
// also need to avoid nested aggregates
159159
return b.buildAggregateFunc(inScope, name, v)

sql/planbuilder/scope.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,9 @@ func (s *scope) initProc() {
224224
// initGroupBy creates a container scope for aggregation
225225
// functions and function inputs.
226226
func (s *scope) initGroupBy() {
227-
s.groupBy = &groupBy{outScope: s.replace()}
227+
if s.groupBy == nil {
228+
s.groupBy = &groupBy{outScope: s.replace()}
229+
}
228230
}
229231

230232
// pushSubquery creates a new scope with the subquery already initialized.

sql/planbuilder/show.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression {
614614
return expression.NewLiteral(v.String(), types.LongText)
615615
case *ast.FuncExpr:
616616
// todo(max): more specific validation for nested ASOF functions
617-
if isWindowFunc(v.Name.Lowered()) || isAggregateFunc(v.Name.Lowered()) {
617+
if isWindowFunc(v.Name.Lowered()) || IsAggregateFunc(v.Name.Lowered()) {
618618
err := sql.ErrInvalidAsOfExpression.New(v)
619619
b.handleErr(err)
620620
}

0 commit comments

Comments
 (0)