diff --git a/cmd/handle_http_traffic.go b/cmd/handle_http_traffic.go index f45b300..502fc19 100644 --- a/cmd/handle_http_traffic.go +++ b/cmd/handle_http_traffic.go @@ -27,12 +27,27 @@ func handleHttpTraffic(wiretapConfig *shared.WiretapConfiguration, wtService *da wtService.HandleHttpRequest(requestModel) } + handleWebsocket := func(w http.ResponseWriter, r *http.Request) { + id, _ := uuid.NewUUID() + requestModel := &model.Request{ + Id: &id, + HttpRequest: r, + HttpResponseWriter: w, + } + wtService.HandleWebsocketRequest(requestModel) + } + // create a new mux. mux := http.NewServeMux() // handle the index mux.HandleFunc("/", handleTraffic) + // Handle Websockets + for websocket := range wiretapConfig.WebsocketConfigs { + mux.HandleFunc(websocket, handleWebsocket) + } + pterm.Info.Println(pterm.LightMagenta(fmt.Sprintf("API Gateway UI booting on port %s...", wiretapConfig.Port))) var httpErr error diff --git a/cmd/root_command.go b/cmd/root_command.go index a47d497..9a93c82 100644 --- a/cmd/root_command.go +++ b/cmd/root_command.go @@ -359,6 +359,16 @@ var ( printLoadedValidationAllowList(config.ValidationAllowList) } + if len(config.WebsocketConfigs) > 0 { + for _, config := range config.WebsocketConfigs { + if config.VerifyCert == nil { + config.VerifyCert = func() *bool { b := true; return &b }() + } + } + + printLoadedWebsockets(config.WebsocketConfigs) + } + // static headers if config.Headers != nil && len(config.Headers.DropHeaders) > 0 { pterm.Info.Printf("Dropping the following %d %s globally:\n", len(config.Headers.DropHeaders), @@ -635,8 +645,7 @@ func Execute(version, commit, date string, fs embed.FS) { rootCmd.Flags().IntP("hard-validation-code", "q", 400, "Set a custom http error code for non-compliant requests when using the hard-error flag") rootCmd.Flags().IntP("hard-validation-return-code", "y", 502, "Set a custom http error code for non-compliant responses when using the hard-error flag") rootCmd.Flags().BoolP("mock-mode", "x", false, "Run in mock mode, responses are mocked and no traffic is sent to the target API (requires OpenAPI spec)") - rootCmd.Flags().StringP("config", "c", "", - "Location of wiretap configuration file to use (default is .wiretap in current directory)") + rootCmd.Flags().StringP("config", "c", "", "Location of wiretap configuration file to use (default is .wiretap in current directory)") rootCmd.Flags().StringP("base", "b", "", "Set a base path to resolve relative file references from, or a overriding base URL to resolve remote references from") rootCmd.Flags().BoolP("debug", "l", false, "Enable debug logging") rootCmd.Flags().StringP("har", "z", "", "Load a HAR file instead of sniffing traffic") @@ -716,7 +725,7 @@ func printLoadedIgnoreRedirectPaths(ignoreRedirects []string) { } func printLoadedRedirectAllowList(allowRedirects []string) { - pterm.Info.Printf("Loaded %d allow listed redirect %s :\n", len(allowRedirects), + pterm.Info.Printf("Loaded %d allow listed redirect %s:\n", len(allowRedirects), shared.Pluralize(len(allowRedirects), "path", "paths")) for _, x := range allowRedirects { @@ -725,6 +734,15 @@ func printLoadedRedirectAllowList(allowRedirects []string) { pterm.Println() } +func printLoadedWebsockets(websockets map[string]*shared.WiretapWebsocketConfig) { + pterm.Info.Printf("Loaded %d %s: \n", len(websockets), shared.Pluralize(len(websockets), "websocket", "websockets")) + + for websocket := range websockets { + pterm.Printf("🔌 Paths prefixed '%s' will be managed as a websocket\n", pterm.LightCyan(websocket)) + } + pterm.Println() +} + func printLoadedIgnoreValidationPaths(ignoreValidations []string) { pterm.Info.Printf("Loaded %d %s to ignore validation:\n", len(ignoreValidations), shared.Pluralize(len(ignoreValidations), "path", "paths")) diff --git a/daemon/handle_request.go b/daemon/handle_request.go index a1fc6fb..7e97412 100644 --- a/daemon/handle_request.go +++ b/daemon/handle_request.go @@ -4,8 +4,10 @@ package daemon import ( + "crypto/tls" _ "embed" "fmt" + "github.com/gorilla/websocket" "io" "net/http" "os" @@ -99,32 +101,7 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { } } - var dropHeaders []string - var injectHeaders map[string]string - - // add global headers with injection. - if config.Headers != nil { - dropHeaders = config.Headers.DropHeaders - injectHeaders = config.Headers.InjectHeaders - } - - // now add path specific headers. - matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) - auth := "" - if len(matchedPaths) > 0 { - for _, path := range matchedPaths { - auth = path.Auth - if path.Headers != nil { - dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) - newInjectHeaders := path.Headers.InjectHeaders - for key := range injectHeaders { - newInjectHeaders[key] = injectHeaders[key] - } - injectHeaders = newInjectHeaders - } - break - } - } + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) newReq := CloneExistingRequest(CloneRequest{ Request: request.HttpRequest, @@ -235,8 +212,210 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { _, _ = request.HttpResponseWriter.Write(body) } +var gorillaDropHeaders = []string{ + // Gorilla fills in the following headers, and complains if they are already present + "Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol", + "Sec-Websocket-Extensions", +} + +func (ws *WiretapService) handleWebsocketRequest(request *model.Request) { + + configStore, _ := ws.controlsStore.Get(shared.ConfigKey) + config := configStore.(*shared.WiretapConfiguration) + + // Get the Websocket Configuration + websocketUrl := request.HttpRequest.URL.String() + websocketConfig, ok := config.WebsocketConfigs[websocketUrl] + if !ok { + ws.config.Logger.Error(fmt.Sprintf("Unable to find websocket config for URL: %s", websocketUrl)) + } + + // There's nothing to do if we're in mock mode + if config.MockMode { + return + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // Upgrade the connection from the client to open a websocket connection + clientConn, err := upgrader.Upgrade(request.HttpResponseWriter, request.HttpRequest, nil) + if err != nil { + ws.config.Logger.Error("Unable to upgrade websocket connection") + return + } + defer func(clientConn *websocket.Conn) { + _ = clientConn.Close() + }(clientConn) + + if config.Headers == nil || len(config.Headers.DropHeaders) == 0 { + config.Headers = &shared.WiretapHeaderConfig{ + DropHeaders: []string{}, + } + } + + // Get the updated headers and auth + dropHeaders, injectHeaders, auth := ws.getHeadersAndAuth(config, request) + + dropHeaders = append(dropHeaders, gorillaDropHeaders...) + dropHeaders = append(dropHeaders, websocketConfig.DropHeaders...) + + // Determine the correct websocket protocol based on redirect protocol + var protocol string + if config.RedirectProtocol == "https" { + protocol = "wss" + } else if config.RedirectProtocol == "http" { + protocol = "ws" + } else if config.RedirectProtocol != "wss" && config.RedirectProtocol != "ws" { + config.Logger.Error(fmt.Sprintf("Unsupported Redirect Protocol: %s", config.RedirectProtocol)) + return + } + + // Create a new request, which fills in the URL and other information + newRequest := CloneExistingRequest(CloneRequest{ + Request: request.HttpRequest, + Protocol: protocol, + Host: config.RedirectHost, + BasePath: config.RedirectBasePath, + Port: config.RedirectPort, + DropHeaders: dropHeaders, + InjectHeaders: injectHeaders, + Auth: auth, + Variables: config.CompiledVariables, + }) + + // Open a new websocket connection with the server + dialer := *websocket.DefaultDialer + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: !*websocketConfig.VerifyCert} + serverConn, _, err := dialer.Dial(newRequest.URL.String(), newRequest.Header) + if err != nil { + ws.config.Logger.Error(fmt.Sprintf("Unable to connect to remote server; websocket connection failed: %s", err)) + return + } + defer func(serverConn *websocket.Conn) { + _ = serverConn.Close() + }(serverConn) + + // Create sentinel channels + clientSentinel := make(chan struct{}) + serverSentinel := make(chan struct{}) + + // Go-Routine for communication between Client -> Server + go func() { + defer close(clientSentinel) + + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = serverConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Go-Routine for communication between Server -> Client + go func() { + defer close(serverSentinel) + + for { + messageType, message, err := serverConn.ReadMessage() + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + _ = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + err = clientConn.WriteMessage(messageType, message) + if err != nil { + closeCode, isUnexpected := getCloseCode(err) + logWebsocketClose(config, closeCode, isUnexpected) + return + } + } + }() + + // Loop until at least one of our sentinel channels have been closed + for { + select { + case <-clientSentinel: + return + case <-serverSentinel: + return + } + } +} + func setCORSHeaders(headers map[string][]string) { headers["Access-Control-Allow-Headers"] = []string{"*"} headers["Access-Control-Allow-Origin"] = []string{"*"} headers["Access-Control-Allow-Methods"] = []string{"OPTIONS,POST,GET,DELETE,PATCH,PUT"} } + +func getCloseCode(err error) (int, bool) { + unexpectedClose := websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure, + ) + + if ce, ok := err.(*websocket.CloseError); ok { + return ce.Code, unexpectedClose + } + return -1, unexpectedClose +} + +func logWebsocketClose(config *shared.WiretapConfiguration, closeCode int, isUnexpected bool) { + if isUnexpected { + config.Logger.Warn(fmt.Sprintf("Websocket closed unexepectedly with code: %d", closeCode)) + } else { + config.Logger.Info(fmt.Sprintf("Websocket closed expectedly with code: %d", closeCode)) + } +} + +func (ws *WiretapService) getHeadersAndAuth(config *shared.WiretapConfiguration, request *model.Request) ([]string, map[string]string, string) { + var dropHeaders []string + var injectHeaders map[string]string + + // add global headers with injection. + if config.Headers != nil { + dropHeaders = config.Headers.DropHeaders + injectHeaders = config.Headers.InjectHeaders + } + + // now add path specific headers. + matchedPaths := configModel.FindPaths(request.HttpRequest.URL.Path, config) + auth := "" + if len(matchedPaths) > 0 { + for _, path := range matchedPaths { + auth = path.Auth + if path.Headers != nil { + dropHeaders = append(dropHeaders, path.Headers.DropHeaders...) + newInjectHeaders := path.Headers.InjectHeaders + for key := range injectHeaders { + newInjectHeaders[key] = injectHeaders[key] + } + injectHeaders = newInjectHeaders + } + break + } + } + + return dropHeaders, injectHeaders, auth +} diff --git a/daemon/wiretap_service.go b/daemon/wiretap_service.go index 416bfe0..1b1d7d0 100644 --- a/daemon/wiretap_service.go +++ b/daemon/wiretap_service.go @@ -95,6 +95,9 @@ func (ws *WiretapService) HandleServiceRequest(request *model.Request, core serv } func (ws *WiretapService) HandleHttpRequest(request *model.Request) { - ws.handleHttpRequest(request) } + +func (ws *WiretapService) HandleWebsocketRequest(request *model.Request) { + ws.handleWebsocketRequest(request) +} diff --git a/mock/mock_engine.go b/mock/mock_engine.go index 226cd40..7c9c132 100644 --- a/mock/mock_engine.go +++ b/mock/mock_engine.go @@ -248,7 +248,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, // check the request is valid against security requirements. err = rme.ValidateSecurity(request, operation) if err != nil { - mt, _ := rme.lookForResponseCodes(operation, request, []string{"401"}) + mt, _ := rme.findBestMediaTypeMatch(operation, request, []string{"401"}) if mt != nil { mock, mockErr := rme.mockEngine.GenerateMock(mt, rme.extractPreferred(request)) if mockErr != nil { @@ -275,7 +275,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, // validate the request against the document. _, validationErrors := rme.validator.ValidateHttpRequest(request) if len(validationErrors) > 0 { - mt, _ := rme.lookForResponseCodes(operation, request, []string{"422", "400"}) + mt, _ := rme.findBestMediaTypeMatch(operation, request, []string{"422", "400"}) if mt == nil { // no default, no valid response, inform use with a 500 return rme.buildErrorWithPayload( @@ -297,11 +297,25 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, } - // get the lowest success code - lo := rme.findLowestSuccessCode(operation) - - // find the lowest success code. - mt, noMT := rme.lookForResponseCodes(operation, request, []string{lo}) + preferred := rme.extractPreferred(request) + + var lo string + var mt *v3.MediaType + var noMT bool = true + + if preferred != "" { + // If an explicit preferred header is present, let it have a chance to take precedence + // This allows a developer to cause a 3xx, 4xx, or 5xx mocked response by passing + // the appropriate example header value. + mt, lo, noMT = rme.findMediaTypeContainingNamedExample(operation, request, preferred) + } + + if (noMT) { + // When no preferred header is passed, or preferred header did not match a named example + lo = rme.findLowestSuccessCode(operation) + mt, noMT = rme.findBestMediaTypeMatch(operation, request, []string{lo}) + } + if mt == nil && noMT { mtString := rme.extractMediaTypeHeader(request) return rme.buildError( @@ -312,7 +326,7 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, ), 415, nil } - mock, mockErr := rme.mockEngine.GenerateMock(mt, rme.extractPreferred(request)) + mock, mockErr := rme.mockEngine.GenerateMock(mt, preferred) if mockErr != nil { return rme.buildError( 422, @@ -326,6 +340,37 @@ func (rme *ResponseMockEngine) runWorkflow(request *http.Request) ([]byte, int, return mock, c, nil } +func (rme *ResponseMockEngine) findMediaTypeContainingNamedExample( + operation *v3.Operation, + request *http.Request, + preferredExample string) (*v3.MediaType, string, bool) { + + mediaTypeString := rme.extractMediaTypeHeader(request) + + for codePairs := operation.Responses.Codes.First(); codePairs != nil; codePairs = codePairs.Next() { + resp := codePairs.Value() + + if resp.Content != nil { + responseBody := resp.Content.GetOrZero(mediaTypeString) + if responseBody == nil { + responseBody = resp.Content.GetOrZero("application/json") + } + + if responseBody == nil { + continue; + } + + _, present := responseBody.Examples.Get(preferredExample) + + if present { + return responseBody, codePairs.Key(), false + } + } + } + + return nil, "", true +} + func (rme *ResponseMockEngine) findLowestSuccessCode(operation *v3.Operation) string { var lowestCode = 299 @@ -341,14 +386,15 @@ func (rme *ResponseMockEngine) findLowestSuccessCode(operation *v3.Operation) st return fmt.Sprintf("%d", lowestCode) } -func (rme *ResponseMockEngine) lookForResponseCodes( +func (rme *ResponseMockEngine) findBestMediaTypeMatch( op *v3.Operation, request *http.Request, resultCodes []string) (*v3.MediaType, bool) { mediaTypeString := rme.extractMediaTypeHeader(request) - // check if the media type exists in the response. + // Try to find a matching media type in responses matching + // parameterized result codes for _, code := range resultCodes { resp := op.Responses.Codes.GetOrZero(code) @@ -370,6 +416,8 @@ func (rme *ResponseMockEngine) lookForResponseCodes( } } + // As a last resort, check if a default response is specified and attempt + // to use that if op.Responses.Default != nil && op.Responses.Default.Content != nil { if op.Responses.Default.Content.GetOrZero(mediaTypeString) != nil { return op.Responses.Default.Content.GetOrZero(mediaTypeString), false diff --git a/mock/mock_engine_test.go b/mock/mock_engine_test.go index 7641f34..41bf3f9 100644 --- a/mock/mock_engine_test.go +++ b/mock/mock_engine_test.go @@ -932,3 +932,183 @@ components: assert.NotEmpty(t, decoded["description"]) } + +func TestNewMockEngine_UseExamples_Preferred_From_400(t *testing.T) { + + spec := `openapi: 3.1.0 +paths: + /test: + get: + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Thing' + examples: + happyDays: + value: + name: happy days + description: a terrible show from a time that never existed. + robocop: + value: + name: robocop + description: perhaps the best cyberpunk movie ever made. + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorThing' + examples: + sadErrorDays: + value: + name: sad error days + description: a sad error prone show + sadcop: + value: + name: sad cop + description: perhaps the saddest cyberpunk movie ever made. +components: + schemas: + Thing: + type: object + properties: + name: + type: string + example: nameExample + description: + type: string + example: descriptionExample + ErrorThing: + type: object + properties: + name: + type: string + example: errorNameExample + description: + type: string + example: errorDescriptionExample +` + + d, _ := libopenapi.NewDocument([]byte(spec)) + doc, _ := d.BuildV3Model() + + me := NewMockEngine(&doc.Model, false) + + request, _ := http.NewRequest(http.MethodGet, "https://api.pb33f.io/test", nil) + request.Header.Set(helpers.Preferred, "sadcop") + + b, status, err := me.GenerateResponse(request) + + assert.NoError(t, err) + assert.Equal(t, 400, status) + + var decoded map[string]any + _ = json.Unmarshal(b, &decoded) + + assert.Equal(t, "sad cop", decoded["name"]) + assert.Equal(t, "perhaps the saddest cyberpunk movie ever made.", decoded["description"]) +} + +func TestNewMockEngine_UseExamples_Preferred_200_Not_Json(t *testing.T) { +// A little far-fetched for an API to behave this way, +// where lowest 2xx response is html and second is json, +// including the test case just in case + spec := `openapi: 3.1.0 +paths: + /test: + get: + responses: + '200': + content: + text/html: + schema: + $ref: '#/components/schemas/HtmlThing' + examples: + happyHtmlDays: + value: