diff --git a/go.mod b/go.mod index 102f35342f..b1ba8e5ee6 100644 --- a/go.mod +++ b/go.mod @@ -148,6 +148,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect + github.com/pb33f/libopenapi-validator v0.0.42 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -165,6 +166,7 @@ require ( github.com/ryancurrah/gomodguard v1.3.0 // indirect github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect github.com/sanposhiho/wastedassign/v2 v2.0.7 // indirect + github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 // indirect github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.25.0 // indirect github.com/securego/gosec/v2 v2.19.0 // indirect diff --git a/go.sum b/go.sum index 5027a1ff12..d2b60a3c3e 100644 --- a/go.sum +++ b/go.sum @@ -479,6 +479,8 @@ github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT9 github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= github.com/pb33f/libopenapi v0.15.11 h1:xg2JWswvd6RUnV4akW8YOD5C4cA/GEnaZdY/B0MZBX8= github.com/pb33f/libopenapi v0.15.11/go.mod h1:PEXNwvtT4KNdjrwudp5OYnD1ryqK6uJ68aMNyWvoMuc= +github.com/pb33f/libopenapi-validator v0.0.42 h1:bfwPWlxUFHtvPNi0PH+EVpQBU2kA3Db9rVdFkfmUVac= +github.com/pb33f/libopenapi-validator v0.0.42/go.mod h1:kU1JYyXIRlpmsWx3NkL+drNNttLADMgdaNzJgXDhec0= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= @@ -534,6 +536,8 @@ github.com/ryanrolds/sqlclosecheck v0.5.1 h1:dibWW826u0P8jNLsLN+En7+RqWWTYrjCB9f github.com/ryanrolds/sqlclosecheck v0.5.1/go.mod h1:2g3dUjoS6AL4huFdv6wn55WpLIDjY7ZgUR4J8HOO/XQ= github.com/sanposhiho/wastedassign/v2 v2.0.7 h1:J+6nrY4VW+gC9xFzUc+XjPD3g3wF3je/NsJFwFK7Uxc= github.com/sanposhiho/wastedassign/v2 v2.0.7/go.mod h1:KyZ0MWTwxxBmfwn33zh3k1dmsbF2ud9pAAGfoLfjhtI= +github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= +github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY= github.com/sashamelentyev/interfacebloat v1.1.0 h1:xdRdJp0irL086OyW1H/RTZTr1h/tMEOsumirXcOJqAw= github.com/sashamelentyev/interfacebloat v1.1.0/go.mod h1:+Y9yU5YdTkrNvoX0xHc84dxiN1iBi9+G8zZIhPVoNjQ= github.com/sashamelentyev/usestdlibvars v1.25.0 h1:IK8SI2QyFzy/2OD2PYnhy84dpfNo9qADrRt6LH8vSzU= diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 1a93bc1b8f..b474c40eea 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -10,8 +10,8 @@ import ( "strings" "syscall" - libopenapi "github.com/pb33f/libopenapi/datamodel/high/base" "github.com/mitchellh/go-homedir" + libopenapi "github.com/pb33f/libopenapi/datamodel/high/base" "github.com/spf13/cobra" "github.com/vincent-petithory/dataurl" @@ -194,8 +194,8 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o } // Multiple outputs! if outputSchema.Type[0] == "array" && outputSchema.Items.A.Schema().Type != nil && outputSchema.Items.A.Schema().Type[0] == "string" && - outputSchema.Items.A.Schema().Format == "uri" { - return handleMultipleFileOutput(prediction, outputSchemaProxy) + outputSchema.Items.A.Schema().Format == "uri" { + return handleMultipleFileOutput(prediction, outputSchemaProxy) } if outputSchema.Type[0] == "string" && outputSchema.Format == "uri" { diff --git a/pkg/image/build.go b/pkg/image/build.go index 6748262a4a..fa183fbd44 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -9,6 +9,7 @@ import ( "path" "github.com/pb33f/libopenapi" + "github.com/pb33f/libopenapi-validator" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" @@ -139,6 +140,23 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, return fmt.Errorf("Model schema is invalid, %d errors reported\n\n%s\n\n%s", len(errors), errorString, string(schemaJSON)) } + docValidator, validatorErrs := validator.NewValidator(document) + if validatorErrs != nil { + return fmt.Errorf("Failed to create validator: %e", validatorErrs) + } + + valid, validationErrs := docValidator.ValidateDocument() + + if !valid { + errorString := "" + for _, e := range validationErrs { + errorString += fmt.Sprintf("Type: %s, Failure: %s\n", e.ValidationType, e.Message) + errorString += fmt.Sprintf("Fix: %s\n\n", e.HowToFix) + } + + return fmt.Errorf("Model schema doesn't match OpenAPI spec, %d errors reported\n\n%s\n\n%s", len(validationErrs), errorString, string(schemaJSON)) + } + console.Info("Adding labels to image...") // We used to set the cog_version and config labels in Dockerfile, because we didn't require running the diff --git a/pkg/image/openapi_schema.go b/pkg/image/openapi_schema.go index 333beb7eb6..2613cbe1a8 100644 --- a/pkg/image/openapi_schema.go +++ b/pkg/image/openapi_schema.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - "github.com/pb33f/libopenapi" + "github.com/pb33f/libopenapi" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/global" @@ -68,9 +68,9 @@ func GetOpenAPISchema(imageName string) (*libopenapi.Document, error) { if schemaString == "" { return nil, fmt.Errorf("Image %s does not appear to be a Cog model", imageName) } - document, error := libopenapi.NewDocument([]byte(schemaString)) - if error != nil { - return nil, error + document, err := libopenapi.NewDocument([]byte(schemaString)) + if err != nil { + return nil, err } _, errors := document.BuildV3Model() if len(errors) > 0 { diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index bfbda5ccd8..26fdbcce03 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -194,9 +194,9 @@ func (p *Predictor) GetSchema() (*libopenapi.Document, error) { if err != nil { return nil, err } - document, error := libopenapi.NewDocument(body) - if error != nil { - return nil, error + document, err := libopenapi.NewDocument(body) + if err != nil { + return nil, err } _, errors := document.BuildV3Model() if len(errors) > 0 {