From 03ab22130e35463dabcbc2107bb27799d47ed92b Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Thu, 20 Feb 2025 09:24:29 -0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20override=20builtin=20functions=20(#?= =?UTF-8?q?5156)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Today, there is no official way to override builtin functions like `length`. This change makes the smallest possible change in MQL to give providers a single pattern for overriding builtin functions. The core change is an additional check when executing a bound function. We check if the resource defines the builtin function like `length`, and if so, we prioritize it. If the provider doesn't define it, we execute the function as we did before: we first look up the builtin function and, if it does not exist, we run the resource function. Here is an example of this pattern to override the `length` builtin function. Consider a resource that exposes a function that loads a large number of resources: ``` cloud { resources() []resources } ``` The provider implementation would look something like this: ```go func (c *mqlCloud) resources() ([]interface{}, error) { // API call to fetch all resources } ``` When running the MQL query `cloud.resources.length`, we first load all the resources, and then we count them. Counting does not need any information about the resources themselves, so loading them is wasted work that could delay policies if we do things like: ``` cloud.resources.length < 5 && cloud.resources.length > 1 ``` An alternative to improve these resources is to override the builtin function with a more performant implementation.
``` cloud { // update function with a new custom list resource resources() customResources } // definition of a custom list resource customResources { []resource // overrides builtin function length() int } ``` The new implementation moves the logic that loads the resources into a `list()` function and exposes a new function that overrides the builtin function. It would look like: ```go func (c *mqlCloud) resources() (*mqlCustomResources, error) { // Only initializes the custom list resource } func (c *mqlCustomResources) list() ([]interface{}, error) { // API call to fetch all resources } // length() overrides the builtin function func (c *mqlCustomResources) length() (int64, error) { // This should be a more performant way to count the "resources" } ``` Additionally, this change moves the `findField()` function from the `compiler` to the resource `Schema`. --------- Signed-off-by: Salim Afiune Maya --- llx/builtin.go | 12 ++ mql/mql_test.go | 22 +++ mqlc/builtin_resource.go | 2 +- mqlc/mqlc.go | 37 +---- providers-sdk/v1/resources/schema.go | 39 +++++ .../testutils/mockprovider/resources/all.go | 40 +++++ .../mockprovider/resources/mockprovider.lr | 13 ++ .../mockprovider/resources/mockprovider.lr.go | 156 ++++++++++++++++++ providers/extensible_schema.go | 19 +++ providers/extensible_schema_test.go | 9 +- 10 files changed, 311 insertions(+), 38 deletions(-) diff --git a/llx/builtin.go b/llx/builtin.go index 6e30031eb7..87d479d935 100644 --- a/llx/builtin.go +++ b/llx/builtin.go @@ -862,6 +862,17 @@ func BuiltinFunctionV2(typ types.Type, name string) (*chunkHandlerV2, error) { func (e *blockExecutor) runBoundFunction(bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) { log.Trace().Uint64("ref", ref).Str("id", chunk.Id).Msg("exec> run bound function") + // check if the resource defines the function to allow providers to override + // builtin functions like `length` or any other function + if bind.Type.IsResource() && bind.Value != nil { + rr := 
bind.Value.(Resource) + resource := e.ctx.runtime.Schema().Lookup(rr.MqlName()) + _, _, override := e.ctx.runtime.Schema().FindField(resource, chunk.Id) + if override { + return runResourceFunction(e, bind, chunk, ref) + } + } + fh, err := BuiltinFunctionV2(bind.Type, chunk.Id) if err == nil { res, dref, err := fh.f(e, bind, chunk, ref) @@ -879,5 +890,6 @@ func (e *blockExecutor) runBoundFunction(bind *RawData, chunk *Chunk, ref uint64 if bind.Type.IsResource() { return runResourceFunction(e, bind, chunk, ref) } + return nil, 0, err } diff --git a/mql/mql_test.go b/mql/mql_test.go index 38f730fe7e..b4935162bd 100644 --- a/mql/mql_test.go +++ b/mql/mql_test.go @@ -350,3 +350,25 @@ func TestDictMethods(t *testing.T) { }, }) } + +func TestBuiltinFunctionOverride(t *testing.T) { + x := testutils.InitTester(testutils.LinuxMock()) + x.TestSimple(t, []testutils.SimpleTest{ + // This access the resource length property, + // which overrides the builtin function `length` + { + Code: "mos.groups.length", + ResultIndex: 0, Expectation: int64(5), + }, + // This calls the native builtin `length` function + { + Code: "mos.groups.list.length", + ResultIndex: 0, Expectation: int64(7), + }, + // Same here, builtin `length` function + { + Code: "muser.groups.length", + ResultIndex: 0, Expectation: int64(2), + }, + }) +} diff --git a/mqlc/builtin_resource.go b/mqlc/builtin_resource.go index 6f00e5570c..a8aac87008 100644 --- a/mqlc/builtin_resource.go +++ b/mqlc/builtin_resource.go @@ -35,7 +35,7 @@ func compileResourceDefault(c *compiler, typ types.Type, ref uint64, id string, } } - fieldPath, fieldinfos, ok := c.findField(resource, id) + fieldPath, fieldinfos, ok := c.Schema.FindField(resource, id) if !ok { addFieldSuggestions(publicFieldsInfo(c, resource), id, c.Result) return "", errors.New("cannot find field '" + id + "' in resource " + resource.Name) diff --git a/mqlc/mqlc.go b/mqlc/mqlc.go index 047749a548..92fbee2ab7 100644 --- a/mqlc/mqlc.go +++ b/mqlc/mqlc.go @@ -891,41 
+891,6 @@ func filterEmptyExpressions(expressions []*parser.Expression) []*parser.Expressi return res } -type fieldPath []string - -// TODO: embed this into the Schema LookupField call! -func (c *compiler) findField(resource *resources.ResourceInfo, fieldName string) (fieldPath, []*resources.Field, bool) { - fieldInfo, ok := resource.Fields[fieldName] - if ok { - return fieldPath{fieldName}, []*resources.Field{fieldInfo}, true - } - - for _, f := range resource.Fields { - if f.IsEmbedded { - typ := types.Type(f.Type) - nextResource := c.Schema.Lookup(typ.ResourceName()) - if nextResource == nil { - continue - } - childFieldPath, childFieldInfos, ok := c.findField(nextResource, fieldName) - if ok { - fp := make(fieldPath, len(childFieldPath)+1) - fieldInfos := make([]*resources.Field, len(childFieldPath)+1) - fp[0] = f.Name - fieldInfos[0] = f - for i, n := range childFieldPath { - fp[i+1] = n - } - for i, f := range childFieldInfos { - fieldInfos[i+1] = f - } - return fp, fieldInfos, true - } - } - } - return nil, nil, false -} - // compile a bound identifier to its binding // example: user { name } , where name is compiled bound to the user // it will return false if it cannot bind the identifier @@ -942,7 +907,7 @@ func (c *compiler) compileBoundIdentifierWithMqlCtx(id string, binding *variable return true, types.Nil, errors.New("cannot find resource that is called by '" + id + "' of type " + typ.Label()) } - fieldPath, fieldinfos, ok := c.findField(resource, id) + fieldPath, fieldinfos, ok := c.Schema.FindField(resource, id) if ok { fieldinfo := fieldinfos[len(fieldinfos)-1] c.CompilerConfig.Stats.CallField(resource.Name, fieldinfo) diff --git a/providers-sdk/v1/resources/schema.go b/providers-sdk/v1/resources/schema.go index 1e37aa0126..1c69716ab4 100644 --- a/providers-sdk/v1/resources/schema.go +++ b/providers-sdk/v1/resources/schema.go @@ -3,9 +3,14 @@ package resources +import ( + "go.mondoo.com/cnquery/v11/types" +) + type ResourcesSchema interface { 
Lookup(resource string) *ResourceInfo LookupField(resource string, field string) (*ResourceInfo, *Field) + FindField(resource *ResourceInfo, field string) (FieldPath, []*Field, bool) AllResources() map[string]*ResourceInfo } @@ -121,6 +126,40 @@ func (s *Schema) LookupField(resource string, field string) (*ResourceInfo, *Fie return res, res.Fields[field] } +type FieldPath []string + +func (s *Schema) FindField(resource *ResourceInfo, field string) (FieldPath, []*Field, bool) { + fieldInfo, ok := resource.Fields[field] + if ok { + return FieldPath{field}, []*Field{fieldInfo}, true + } + + for _, f := range resource.Fields { + if f.IsEmbedded { + typ := types.Type(f.Type) + nextResource := s.Lookup(typ.ResourceName()) + if nextResource == nil { + continue + } + childFieldPath, childFieldInfos, ok := s.FindField(nextResource, field) + if ok { + fp := make(FieldPath, len(childFieldPath)+1) + fieldInfos := make([]*Field, len(childFieldPath)+1) + fp[0] = f.Name + fieldInfos[0] = f + for i, n := range childFieldPath { + fp[i+1] = n + } + for i, f := range childFieldInfos { + fieldInfos[i+1] = f + } + return fp, fieldInfos, true + } + } + } + return nil, nil, false +} + func (s *Schema) AllResources() map[string]*ResourceInfo { return s.Resources } diff --git a/providers-sdk/v1/testutils/mockprovider/resources/all.go b/providers-sdk/v1/testutils/mockprovider/resources/all.go index 94de802b20..240c83acb6 100644 --- a/providers-sdk/v1/testutils/mockprovider/resources/all.go +++ b/providers-sdk/v1/testutils/mockprovider/resources/all.go @@ -4,6 +4,8 @@ package resources import ( + "fmt" + "go.mondoo.com/cnquery/v11/llx" "go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin" ) @@ -56,3 +58,41 @@ func (c *mqlMuser) dict() (any, error) { "string2": "👋", }, nil } + +// This is an example of how we can override builtin functions today, this will have to change to provide +// a better mechanism to do so but for now, this pattern is adopted in multiple providers + +// The example 
overrides the `length` builtin function by creating a custom list resource which +// essentially defers the loading of the actual "groups" (for this example) and provides a new function +// `length` that returns the number of "groups" but in a more performant way. + +// groups() just initializes the custom list resource +func (c *mqlMos) groups() (*mqlCustomGroups, error) { + mqlResource, err := CreateResource(c.MqlRuntime, "customGroups", map[string]*llx.RawData{}) + return mqlResource.(*mqlCustomGroups), err +} + +// list() is where we actually load the real resources, which could be slow in big environments +func (c *mqlCustomGroups) list() ([]interface{}, error) { + res := []interface{}{} + for i := 0; i < 7; i++ { + group, err := CreateResource(c.MqlRuntime, "mgroup", map[string]*llx.RawData{ + "name": llx.StringData(fmt.Sprintf("group%d", i+1)), + }) + if err != nil { + return res, err + } + res = append(res, group) + } + return res, nil +} + +// length() overrides the builtin function, this should be a more performant way to count +// the "groups" +// +// NOTE this length here is different from the builtin one just for testing +func (c *mqlCustomGroups) length() (int64, error) { + // use `c.MqlRuntime.Connection` to get the provider connection + // make performant API call to count resources + return 5, nil +} diff --git a/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr b/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr index 56353f7354..549a363162 100644 --- a/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr +++ b/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr @@ -16,3 +16,16 @@ muser { mgroup { name string } + +mos { + // example override builtin func + groups() customGroups +} + +// definition of custom list resource +customGroups { + []mgroup + + // overrides builtin function + length() int +} diff --git a/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr.go 
b/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr.go index a64395ac89..ee3ab55ebb 100644 --- a/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr.go +++ b/providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr.go @@ -25,6 +25,14 @@ func init() { // to override args, implement: initMgroup(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[string]*llx.RawData, plugin.Resource, error) Create: createMgroup, }, + "mos": { + // to override args, implement: initMos(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[string]*llx.RawData, plugin.Resource, error) + Create: createMos, + }, + "customGroups": { + // to override args, implement: initCustomGroups(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[string]*llx.RawData, plugin.Resource, error) + Create: createCustomGroups, + }, } } @@ -114,6 +122,15 @@ var getDataFields = map[string]func(r plugin.Resource) *plugin.DataRes{ "mgroup.name": func(r plugin.Resource) *plugin.DataRes { return (r.(*mqlMgroup).GetName()).ToDataRes(types.String) }, + "mos.groups": func(r plugin.Resource) *plugin.DataRes { + return (r.(*mqlMos).GetGroups()).ToDataRes(types.Resource("customGroups")) + }, + "customGroups.length": func(r plugin.Resource) *plugin.DataRes { + return (r.(*mqlCustomGroups).GetLength()).ToDataRes(types.Int) + }, + "customGroups.list": func(r plugin.Resource) *plugin.DataRes { + return (r.(*mqlCustomGroups).GetList()).ToDataRes(types.Array(types.Resource("mgroup"))) + }, } func GetData(resource plugin.Resource, field string, args map[string]*llx.RawData) *plugin.DataRes { @@ -162,6 +179,26 @@ var setDataFields = map[string]func(r plugin.Resource, v *llx.RawData) bool { r.(*mqlMgroup).Name, ok = plugin.RawToTValue[string](v.Value, v.Error) return }, + "mos.__id": func(r plugin.Resource, v *llx.RawData) (ok bool) { + r.(*mqlMos).__id, ok = v.Value.(string) + return + }, + "mos.groups": func(r plugin.Resource, v *llx.RawData) (ok bool) { + 
r.(*mqlMos).Groups, ok = plugin.RawToTValue[*mqlCustomGroups](v.Value, v.Error) + return + }, + "customGroups.__id": func(r plugin.Resource, v *llx.RawData) (ok bool) { + r.(*mqlCustomGroups).__id, ok = v.Value.(string) + return + }, + "customGroups.length": func(r plugin.Resource, v *llx.RawData) (ok bool) { + r.(*mqlCustomGroups).Length, ok = plugin.RawToTValue[int64](v.Value, v.Error) + return + }, + "customGroups.list": func(r plugin.Resource, v *llx.RawData) (ok bool) { + r.(*mqlCustomGroups).List, ok = plugin.RawToTValue[[]interface{}](v.Value, v.Error) + return + }, } func SetData(resource plugin.Resource, field string, val *llx.RawData) error { @@ -348,3 +385,122 @@ func (c *mqlMgroup) MqlID() string { func (c *mqlMgroup) GetName() *plugin.TValue[string] { return &c.Name } + +// mqlMos for the mos resource +type mqlMos struct { + MqlRuntime *plugin.Runtime + __id string + // optional: if you define mqlMosInternal it will be used here + Groups plugin.TValue[*mqlCustomGroups] +} + +// createMos creates a new instance of this resource +func createMos(runtime *plugin.Runtime, args map[string]*llx.RawData) (plugin.Resource, error) { + res := &mqlMos{ + MqlRuntime: runtime, + } + + err := SetAllData(res, args) + if err != nil { + return res, err + } + + // to override __id implement: id() (string, error) + + if runtime.HasRecording { + args, err = runtime.ResourceFromRecording("mos", res.__id) + if err != nil || args == nil { + return res, err + } + return res, SetAllData(res, args) + } + + return res, nil +} + +func (c *mqlMos) MqlName() string { + return "mos" +} + +func (c *mqlMos) MqlID() string { + return c.__id +} + +func (c *mqlMos) GetGroups() *plugin.TValue[*mqlCustomGroups] { + return plugin.GetOrCompute[*mqlCustomGroups](&c.Groups, func() (*mqlCustomGroups, error) { + if c.MqlRuntime.HasRecording { + d, err := c.MqlRuntime.FieldResourceFromRecording("mos", c.__id, "groups") + if err != nil { + return nil, err + } + if d != nil { + return 
d.Value.(*mqlCustomGroups), nil + } + } + + return c.groups() + }) +} + +// mqlCustomGroups for the customGroups resource +type mqlCustomGroups struct { + MqlRuntime *plugin.Runtime + __id string + // optional: if you define mqlCustomGroupsInternal it will be used here + Length plugin.TValue[int64] + List plugin.TValue[[]interface{}] +} + +// createCustomGroups creates a new instance of this resource +func createCustomGroups(runtime *plugin.Runtime, args map[string]*llx.RawData) (plugin.Resource, error) { + res := &mqlCustomGroups{ + MqlRuntime: runtime, + } + + err := SetAllData(res, args) + if err != nil { + return res, err + } + + // to override __id implement: id() (string, error) + + if runtime.HasRecording { + args, err = runtime.ResourceFromRecording("customGroups", res.__id) + if err != nil || args == nil { + return res, err + } + return res, SetAllData(res, args) + } + + return res, nil +} + +func (c *mqlCustomGroups) MqlName() string { + return "customGroups" +} + +func (c *mqlCustomGroups) MqlID() string { + return c.__id +} + +func (c *mqlCustomGroups) GetLength() *plugin.TValue[int64] { + return plugin.GetOrCompute[int64](&c.Length, func() (int64, error) { + return c.length() + }) +} + +func (c *mqlCustomGroups) GetList() *plugin.TValue[[]interface{}] { + return plugin.GetOrCompute[[]interface{}](&c.List, func() ([]interface{}, error) { + if c.MqlRuntime.HasRecording { + d, err := c.MqlRuntime.FieldResourceFromRecording("customGroups", c.__id, "list") + if err != nil { + return nil, err + } + if d != nil { + return d.Value.([]interface{}), nil + } + } + + return c.list() + }) +} diff --git a/providers/extensible_schema.go b/providers/extensible_schema.go index 684325d1c3..f0b2a0a174 100644 --- a/providers/extensible_schema.go +++ b/providers/extensible_schema.go @@ -115,6 +115,25 @@ func (x *extensibleSchema) LookupField(resource string, field string) (*resource return x.roAggregate.LookupField(resource, field) } +func (x *extensibleSchema) 
FindField(resource *resources.ResourceInfo, field string) (resources.FieldPath, []*resources.Field, bool) { + x.sync.Lock() + defer x.sync.Unlock() + + filePath, fieldinfos, found := x.roAggregate.FindField(resource, field) + if found { + return filePath, fieldinfos, found + } + + if x.lastRefreshed >= LastProviderInstall { + return filePath, fieldinfos, found + } + + x.unsafeLoadAll() + x.unsafeRefresh() + + return x.roAggregate.FindField(resource, field) +} + // Prioritize the provider IDs in the order that is provided. Any other // provider comes later and in any random order. func (x *extensibleSchema) prioritizeIDs(prioritization ...string) { diff --git a/providers/extensible_schema_test.go b/providers/extensible_schema_test.go index 539954f9b1..8385bb7fe8 100644 --- a/providers/extensible_schema_test.go +++ b/providers/extensible_schema_test.go @@ -54,7 +54,14 @@ func TestExtensibleSchema(t *testing.T) { providers = []string{finfo.Provider, info.Others[0].Fields["iii"].Provider} assert.ElementsMatch(t, []string{"first", "second"}, providers) - _, finfo = s.LookupField("eternity", "v") + info, finfo = s.LookupField("eternity", "v") require.NotNil(t, info) assert.Equal(t, "first", finfo.Provider) + + // Find field from resource + filePath, fieldinfos, found := s.FindField(info, "v") + require.True(t, found) + require.Equal(t, resources.FieldPath{"v"}, filePath) + require.Len(t, fieldinfos, 1) + require.Equal(t, "first", fieldinfos[0].Provider) }