Skip to content

Commit f741acd

Browse files
committed
Added WHEN support for triggers
1 parent 9f1def9 commit f741acd

File tree

6 files changed

+149
-16
lines changed

6 files changed

+149
-16
lines changed

core/triggers/collection.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ import (
2727
"github.com/dolthub/dolt/go/store/hash"
2828
"github.com/dolthub/dolt/go/store/prolly"
2929
"github.com/dolthub/dolt/go/store/prolly/tree"
30-
"github.com/dolthub/go-mysql-server/sql"
3130

3231
"github.com/dolthub/doltgresql/core/id"
3332
"github.com/dolthub/doltgresql/core/rootobject/objinterface"
33+
"github.com/dolthub/doltgresql/server/plpgsql"
3434
)
3535

3636
// Collection contains a collection of triggers.
@@ -83,8 +83,8 @@ type Trigger struct {
8383
Function id.Function
8484
Timing TriggerTiming
8585
Events []TriggerEvent
86-
ForEachRow bool // When false, represents FOR EACH STATEMENT
87-
When sql.Expression // TODO: should this be PLpgSQL operations?
86+
ForEachRow bool // When false, represents FOR EACH STATEMENT
87+
When []plpgsql.InterpreterOperation
8888
Deferrable TriggerDeferrable
8989
ReferencedTableName id.Table // FROM referenced_table_name
9090
Constraint bool

core/triggers/serialization.go

+24-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/cockroachdb/errors"
2121

2222
"github.com/dolthub/doltgresql/core/id"
23+
"github.com/dolthub/doltgresql/server/plpgsql"
2324
"github.com/dolthub/doltgresql/utils"
2425
)
2526

@@ -37,14 +38,23 @@ func (trigger Trigger) Serialize(ctx context.Context) ([]byte, error) {
3738
writer.Id(trigger.Function.AsId())
3839
writer.Uint8(uint8(trigger.Timing))
3940
writer.Bool(trigger.ForEachRow)
40-
// TODO: writer.Unknown(trigger.When)
4141
writer.Uint8(uint8(trigger.Deferrable))
4242
writer.Id(trigger.ReferencedTableName.AsId())
4343
writer.Bool(trigger.Constraint)
4444
writer.String(trigger.OldTransitionName)
4545
writer.String(trigger.NewTransitionName)
4646
writer.StringSlice(trigger.Arguments)
4747
writer.String(trigger.Definition)
48+
// Write the WHEN operations
49+
writer.VariableUint(uint64(len(trigger.When)))
50+
for _, op := range trigger.When {
51+
writer.Uint16(uint16(op.OpCode))
52+
writer.String(op.PrimaryData)
53+
writer.StringSlice(op.SecondaryData)
54+
writer.String(op.Target)
55+
writer.Int32(int32(op.Index))
56+
writer.StringMap(op.Options)
57+
}
4858
// Write the events
4959
writer.VariableUint(uint64(len(trigger.Events)))
5060
for _, event := range trigger.Events {
@@ -73,14 +83,26 @@ func DeserializeTrigger(ctx context.Context, data []byte) (Trigger, error) {
7383
t.Function = id.Function(reader.Id())
7484
t.Timing = TriggerTiming(reader.Uint8())
7585
t.ForEachRow = reader.Bool()
76-
// TODO: trigger.When = reader.Unknown()
7786
t.Deferrable = TriggerDeferrable(reader.Uint8())
7887
t.ReferencedTableName = id.Table(reader.Id())
7988
t.Constraint = reader.Bool()
8089
t.OldTransitionName = reader.String()
8190
t.NewTransitionName = reader.String()
8291
t.Arguments = reader.StringSlice()
8392
t.Definition = reader.String()
93+
// Read the WHEN operations
94+
opCount := reader.VariableUint()
95+
t.When = make([]plpgsql.InterpreterOperation, opCount)
96+
for opIdx := uint64(0); opIdx < opCount; opIdx++ {
97+
op := plpgsql.InterpreterOperation{}
98+
op.OpCode = plpgsql.OpCode(reader.Uint16())
99+
op.PrimaryData = reader.String()
100+
op.SecondaryData = reader.StringSlice()
101+
op.Target = reader.String()
102+
op.Index = int(reader.Int32())
103+
op.Options = reader.StringMap()
104+
t.When[opIdx] = op
105+
}
84106
// Read the events
85107
eventCount := reader.VariableUint()
86108
t.Events = make([]TriggerEvent, eventCount)

server/ast/create_trigger.go

+32-4
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@
1515
package ast
1616

1717
import (
18+
"fmt"
19+
"regexp"
20+
21+
"github.com/cockroachdb/errors"
1822
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
1923

2024
"github.com/dolthub/doltgresql/core/id"
2125
"github.com/dolthub/doltgresql/core/triggers"
22-
pgnodes "github.com/dolthub/doltgresql/server/node"
23-
2426
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
27+
pgnodes "github.com/dolthub/doltgresql/server/node"
28+
"github.com/dolthub/doltgresql/server/plpgsql"
2529
)
2630

31+
// createTriggerWhenCapture is a regex that should only capture the contents of the WHEN expression. Although a bit
32+
// complex, this is done to ensure that the capture group contains only the WHEN expression and nothing else.
33+
var createTriggerWhenCapture = regexp.MustCompile(`(?is)create\s+(?:or\s+replace\s+)?(?:constraint\s+)?trigger\s+.*\s+for\s+(?:each\s+)?(?:row|statement)\s+when\s+\((.*)\)\s+execute\s+(?:function|procedure).*`)
34+
2735
// nodeCreateTrigger handles *tree.CreateTrigger nodes.
28-
func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement, error) {
36+
func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (_ vitess.Statement, err error) {
2937
if node.Constraint {
3038
return NotYetSupportedError("CREATE CONSTRAINT TRIGGER is not yet supported")
3139
}
@@ -76,6 +84,26 @@ func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement
7684
return NotYetSupportedError("UNKNOWN EVENT TYPE is not yet supported for CREATE TRIGGER")
7785
}
7886
}
87+
// WHEN expressions seem to behave identically to interpreted functions, so we'll parse them as interpreted functions.
88+
// To do this, we need the raw string, and we wrap it as though it were a trigger function (which has special logic
89+
// for handling NEW and OLD rows). Using a regex for this rather than modifying the parser may seem suboptimal, but
90+
// we want to retain the parser validation of using an expression, however we cannot rely on the expression's
91+
// String() function to return the **exact** same string, so we capture it with a regex.
92+
var whenOps []plpgsql.InterpreterOperation
93+
if node.When != nil {
94+
matches := createTriggerWhenCapture.FindStringSubmatch(ctx.originalQuery)
95+
if len(matches) != 2 {
96+
return nil, errors.New("unable to parse WHEN expression from CREATE TRIGGER")
97+
}
98+
whenOps, err = plpgsql.Parse(fmt.Sprintf(`CREATE FUNCTION when_wrapper() RETURNS TRIGGER AS $$
99+
BEGIN
100+
RETURN %s;
101+
END;
102+
$$ LANGUAGE plpgsql;`, matches[1]))
103+
if err != nil {
104+
return nil, err
105+
}
106+
}
79107
return vitess.InjectedStatement{
80108
Statement: pgnodes.NewCreateTrigger(
81109
id.NewTrigger(node.OnTable.Schema(), node.OnTable.Table(), node.Name.String()),
@@ -84,7 +112,7 @@ func nodeCreateTrigger(ctx *Context, node *tree.CreateTrigger) (vitess.Statement
84112
timing,
85113
events,
86114
node.ForEachRow,
87-
nil, // TODO: node.When (expr)
115+
whenOps,
88116
node.Args.ToStrings(),
89117
ctx.originalQuery,
90118
),

server/node/create_trigger.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ import (
2121
"github.com/dolthub/go-mysql-server/sql/plan"
2222
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2323

24+
"github.com/dolthub/doltgresql/core"
2425
"github.com/dolthub/doltgresql/core/functions"
26+
"github.com/dolthub/doltgresql/core/id"
2527
"github.com/dolthub/doltgresql/core/triggers"
28+
"github.com/dolthub/doltgresql/server/plpgsql"
2629
pgtypes "github.com/dolthub/doltgresql/server/types"
27-
28-
"github.com/dolthub/doltgresql/core"
29-
"github.com/dolthub/doltgresql/core/id"
3030
)
3131

3232
// CreateTrigger implements CREATE TRIGGER.
@@ -37,7 +37,7 @@ type CreateTrigger struct {
3737
Timing triggers.TriggerTiming
3838
Events []triggers.TriggerEvent
3939
ForEachRow bool
40-
When sql.Expression
40+
When []plpgsql.InterpreterOperation
4141
Arguments []string
4242
Definition string
4343
}
@@ -53,7 +53,7 @@ func NewCreateTrigger(
5353
timing triggers.TriggerTiming,
5454
events []triggers.TriggerEvent,
5555
forEachRow bool,
56-
when sql.Expression,
56+
when []plpgsql.InterpreterOperation,
5757
arguments []string,
5858
definition string) *CreateTrigger {
5959
return &CreateTrigger{
@@ -125,7 +125,7 @@ func (c *CreateTrigger) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error
125125
Timing: c.Timing,
126126
Events: c.Events,
127127
ForEachRow: c.ForEachRow,
128-
When: nil,
128+
When: c.When,
129129
Deferrable: triggers.TriggerDeferrable_NotDeferrable,
130130
ReferencedTableName: "",
131131
Constraint: false,

server/node/trigger_execution.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package node
1616

1717
import (
1818
"fmt"
19+
"strings"
1920

2021
"github.com/cockroachdb/errors"
2122
"github.com/dolthub/go-mysql-server/sql"
@@ -83,14 +84,24 @@ func (te *TriggerExecution) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, e
8384
return sourceIter, nil
8485
}
8586
trigFuncs := make([]framework.InterpretedFunction, len(te.Triggers))
87+
whens := make([]framework.InterpretedFunction, len(te.Triggers))
8688
for i, trig := range te.Triggers {
8789
trigFuncs[i], err = te.loadTriggerFunction(ctx, trig)
8890
if err != nil {
8991
return nil, err
9092
}
93+
// If we have a WHEN expression, then we need to build a "function" to execute the expression
94+
if len(trig.When) > 0 {
95+
whens[i] = framework.InterpretedFunction{
96+
ID: trigFuncs[i].ID, // Assign the same ID just so we have a valid one for later
97+
ReturnType: pgtypes.Bool,
98+
Statements: trig.When,
99+
}
100+
}
91101
}
92102
return &triggerExecutionIter{
93103
functions: trigFuncs,
104+
whens: whens,
94105
split: te.Split,
95106
treturn: te.Return,
96107
runner: te.Runner.Runner,
@@ -156,6 +167,7 @@ func (te *TriggerExecution) loadTriggerFunction(ctx *sql.Context, trigger trigge
156167
// triggerExecutionIter is the iterator for TriggerExecution.
157168
type triggerExecutionIter struct {
158169
functions []framework.InterpretedFunction
170+
whens []framework.InterpretedFunction
159171
split TriggerExecutionRowHandling
160172
treturn TriggerExecutionRowHandling
161173
runner analyzer.StatementRunner
@@ -185,7 +197,25 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
185197
case TriggerExecutionRowHandling_New:
186198
newRow = nextRow
187199
}
188-
for _, function := range t.functions {
200+
for funcIdx, function := range t.functions {
201+
if t.whens[funcIdx].ID.IsValid() {
202+
whenValue, err := plpgsql.TriggerCall(ctx, t.whens[funcIdx], t.runner, t.sch, oldRow, newRow)
203+
if err != nil {
204+
if strings.Contains(err.Error(), "no valid cast for return value") {
205+
// TODO: this error should technically be caught during parsing, but interpreted functions don't
206+
// have the ability to determine types during parsing yet (also applies to the same error below)
207+
return nil, fmt.Errorf("argument of WHEN must be type boolean")
208+
}
209+
return nil, err
210+
}
211+
whenBool, ok := whenValue.(bool)
212+
if !ok {
213+
return nil, fmt.Errorf("argument of WHEN must be type boolean")
214+
}
215+
if !whenBool {
216+
continue
217+
}
218+
}
189219
returnedValue, err := plpgsql.TriggerCall(ctx, function, t.runner, t.sch, oldRow, newRow)
190220
if err != nil {
191221
return nil, err

testing/go/trigger_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -586,5 +586,58 @@ $$ LANGUAGE plpgsql;`,
586586
},
587587
},
588588
},
589+
{
590+
Name: "WHEN on BEFORE INSERT",
591+
SetUpScript: []string{
592+
"CREATE TABLE test (pk INT PRIMARY KEY, v1 TEXT);",
593+
`CREATE FUNCTION trigger_func1() RETURNS TRIGGER AS $$
594+
BEGIN
595+
NEW.v1 := NEW.pk::text || '_' || NEW.v1;
596+
RETURN NEW;
597+
END;
598+
$$ LANGUAGE plpgsql;`,
599+
`CREATE FUNCTION trigger_func2() RETURNS TRIGGER AS $$
600+
BEGIN
601+
NEW.v1 := NEW.v1 || '_' || NEW.pk::text;
602+
RETURN NEW;
603+
END;
604+
$$ LANGUAGE plpgsql;`,
605+
`CREATE TRIGGER test_trigger1 BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk < 1) EXECUTE FUNCTION trigger_func1();`,
606+
`CREATE TRIGGER test_trigger2 BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk > 1) EXECUTE FUNCTION trigger_func2();`,
607+
},
608+
Assertions: []ScriptTestAssertion{
609+
{
610+
Query: "INSERT INTO test VALUES (0, 'hi'), (1, 'there'), (2, 'dude');",
611+
Expected: []sql.Row{},
612+
},
613+
{
614+
Query: "SELECT * FROM test;",
615+
Expected: []sql.Row{
616+
{0, "0_hi"},
617+
{1, "there"},
618+
{2, "dude_2"},
619+
},
620+
},
621+
},
622+
},
623+
{
624+
Name: "WHEN with non-boolean expression",
625+
SetUpScript: []string{
626+
"CREATE TABLE test (pk INT PRIMARY KEY, v1 TEXT);",
627+
`CREATE FUNCTION trigger_func() RETURNS TRIGGER AS $$
628+
BEGIN
629+
NEW.v1 := NEW.pk::text || '_' || NEW.v1;
630+
RETURN NEW;
631+
END;
632+
$$ LANGUAGE plpgsql;`,
633+
`CREATE TRIGGER test_trigger BEFORE INSERT ON test FOR EACH ROW WHEN (NEW.pk + 1) EXECUTE FUNCTION trigger_func();`,
634+
},
635+
Assertions: []ScriptTestAssertion{
636+
{
637+
Query: "INSERT INTO test VALUES (1, 'hi'), (2, 'there');",
638+
ExpectedErr: "argument of WHEN must be type boolean",
639+
},
640+
},
641+
},
589642
})
590643
}

0 commit comments

Comments
 (0)