Skip to content

Commit

Permalink
vmproxy support for custom Host header
Browse files Browse the repository at this point in the history
fixes #3626
  • Loading branch information
r3code committed Feb 12, 2025
1 parent 8b9432a commit c6ae4a0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 37 deletions.
2 changes: 2 additions & 0 deletions vmproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type flags struct {
ListenPort int `default:"1280" help:"Listen port for proxy"`
ListenAddress string `default:"127.0.0.1" help:"Listen address for proxy"`
HeaderName string `default:"X-Proxy-Filter" help:"Header name to read filter configuration from. The content of the header shall be a base64 encoded JSON array with strings. Each string is a filter. Multiple filters are joined with a logical OR."` //nolint:lll
HostHeader string `default:"" help:"Optional Host header value to set in the request, overrides existing"`
}

func main() {
Expand Down Expand Up @@ -67,6 +68,7 @@ func runProxy(opts flags, proxyFn func(cfg proxy.Config) error) error {
HeaderName: opts.HeaderName,
ListenAddress: net.JoinHostPort(opts.ListenAddress, strconv.Itoa(opts.ListenPort)),
TargetURL: opts.TargetURL,
HostHeader: opts.HostHeader,
})

if !errors.Is(err, http.ErrServerClosed) {
Expand Down
74 changes: 43 additions & 31 deletions vmproxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type Config struct {
ListenAddress string
// Target URL to forward requests to
TargetURL *url.URL
// Optional Host header value to set in the request
HostHeader string
}

// RunProxy starts proxy which adds extra filters based on configuration.
Expand All @@ -51,7 +53,7 @@ func RunProxy(cfg Config) error {

func getHandler(cfg Config) http.HandlerFunc {
rProxy := &httputil.ReverseProxy{
Director: director(cfg.TargetURL, cfg.HeaderName),
Director: director(cfg.TargetURL, cfg.HeaderName, strings.TrimSpace(cfg.HostHeader)),
}

return func(rw http.ResponseWriter, req *http.Request) {
Expand All @@ -78,47 +80,57 @@ func failOnInvalidHeader(rw http.ResponseWriter, req *http.Request, headerName s
return false
}

func director(target *url.URL, headerName string) func(*http.Request) {
return func(req *http.Request) {
now := time.Now()
func prepareRequest(req *http.Request, target *url.URL, headerName string, hostHeader string) {
now := time.Now()

req.URL.Scheme = target.Scheme
req.URL.Host = target.Host

// Update or add Host header if hostHeader is provided
if hostHeader != "" {
req.Host = hostHeader
req.Header.Set("Host", hostHeader)
}

req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
rp, err := target.Parse(strings.TrimPrefix(req.URL.Path, "/"))
if err != nil {
logrus.Error(err)
}
req.URL.Path = rp.Path

// Replace extra filters if present
if filters := req.Header.Get(headerName); filters != "" {
q := req.URL.Query()
q.Del("extra_filters[]")

rp, err := target.Parse(strings.TrimPrefix(req.URL.Path, "/"))
parsed, err := parseFilters(filters)
if err != nil {
logrus.Error(err)
}
req.URL.Path = rp.Path

// Replace extra filters if present
if filters := req.Header.Get(headerName); filters != "" {
q := req.URL.Query()
q.Del("extra_filters[]")

parsed, err := parseFilters(filters)
if err != nil {
logrus.Error(err)
}
for _, f := range parsed {
q.Add("extra_filters[]", f)
}

for _, f := range parsed {
q.Add("extra_filters[]", f)
}
req.URL.RawQuery = q.Encode()

req.URL.RawQuery = q.Encode()
logrus.Debugf(
"Received filters: %s, Parsed filters: %#v, Query: %s, Target URL: %s, Time spent: %s",
filters, parsed, req.URL.RawQuery, req.URL, time.Since(now))
}

logrus.Debugf(
"Received filters: %s, Parsed filters: %#v, Query: %s, Target URL: %s, Time spent: %s",
filters, parsed, req.URL.RawQuery, req.URL, time.Since(now))
}
// Do not trust the client
req.Header.Del("X-Forwarded-For")

// Do not trust the client
req.Header.Del("X-Forwarded-For")
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}

if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
func director(target *url.URL, headerName string, hostHeader string) func(*http.Request) {
return func(req *http.Request) {
prepareRequest(req, target, headerName, hostHeader)
}
}

Expand Down
35 changes: 29 additions & 6 deletions vmproxy/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const (
func TestProxy(t *testing.T) {
t.Parallel()

setup := func(t *testing.T, filters []string) http.HandlerFunc {
setup := func(t *testing.T, filters []string, hostHeader string) http.HandlerFunc {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if filters != nil {
Expand All @@ -50,18 +50,23 @@ func TestProxy(t *testing.T) {
handler := getHandler(Config{
HeaderName: headerName,
TargetURL: testURL,
HostHeader: hostHeader,
})

return handler
}

handler := setup(t, nil)
handler := setup(t, nil, "")

t.Run("shall proxy request", func(t *testing.T) {
t.Parallel()

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, targetURL, nil)
hostHeader := ""
url, err := url.Parse(targetURL)
require.NoError(t, err)
prepareRequest(req, url, headerName, hostHeader)

handler.ServeHTTP(rec, req)
resp := rec.Result()
Expand Down Expand Up @@ -125,15 +130,20 @@ func TestProxy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

url := targetURL
url, err := url.Parse(targetURL)
require.NoError(t, err)

if tc.targetURL != "" {
url = tc.targetURL
url, err = url.Parse(tc.targetURL)
require.NoError(t, err)
}

handler := setup(t, tc.expectedFilters)
handler := setup(t, tc.expectedFilters, "")

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, url, nil)
req := httptest.NewRequest(http.MethodGet, targetURL, nil)
hostHeader := ""
prepareRequest(req, url, headerName, hostHeader)
req.Header.Set(headerName, tc.headerContent)

handler.ServeHTTP(rec, req)
Expand All @@ -144,4 +154,17 @@ func TestProxy(t *testing.T) {
})
}
})

t.Run("prepareRequest: set manual Host header", func(t *testing.T) {
t.Parallel()

hostHeader := "test.example.org"
req := httptest.NewRequest(http.MethodGet, targetURL, nil)
url, err := url.Parse(targetURL)
require.NoError(t, err)

prepareRequest(req, url, headerName, hostHeader)

require.Equal(t, hostHeader, req.Header["Host"][0])
})
}

0 comments on commit c6ae4a0

Please sign in to comment.