diff --git a/internal/server/control/control.go b/internal/server/control/control.go index 0437901..c7a5e7f 100644 --- a/internal/server/control/control.go +++ b/internal/server/control/control.go @@ -70,6 +70,10 @@ func listRoutes(res http.ResponseWriter, req *http.Request) { var funcAdd func(model.Route) model.Route = user.Routes.Append var idGenerator = uuid.NewUUID +var pathValidator func(string) error = func(path string) error { + return mux.NewRouter().NewRoute().BuildOnly().Path(path).GetError() +} + func addRoute(res http.ResponseWriter, req *http.Request) { var route model.Route @@ -88,6 +92,12 @@ func addRoute(res http.ResponseWriter, req *http.Request) { return } + err = pathValidator(route.Pattern) + if err != nil { + res.WriteHeader(http.StatusUnprocessableEntity) + return + } + id, err := idGenerator() if err != nil { res.WriteHeader(http.StatusInternalServerError) diff --git a/internal/server/control/control_test.go b/internal/server/control/control_test.go index b7b91e6..6a918c2 100644 --- a/internal/server/control/control_test.go +++ b/internal/server/control/control_test.go @@ -72,6 +72,22 @@ func TestConfigRouterHasRoutesWellConfigured(t *testing.T) { } } +func TestPathValidatorNoErrorWhenCorrectPath(t *testing.T) { + err := pathValidator("/routes/{routeID}") + + if err != nil { + t.Error(err) + } +} + +func TestPathValidatorErrorWhenInvalidPath(t *testing.T) { + err := pathValidator("/routes/{routeID{") + + if err == nil { + t.FailNow() + } +} + func TestAddRouteReturnsBadRequestWhenMalformedJSONBody(t *testing.T) { reqPayload := `{ method": "GET", @@ -181,6 +197,9 @@ func TestAddRouteGeneratesRouteID(t *testing.T) { input.Index = 0 return input } + origPathValidator := pathValidator + defer func() { pathValidator = origPathValidator }() + pathValidator = func(path string) error { return nil } handler.ServeHTTP(resp, req) @@ -200,6 +219,10 @@ func TestAddRoute500sWhenIDGeneratorFails(t *testing.T) { resp := httptest.NewRecorder() handler := http.HandlerFunc(addRoute) + origPathValidator := pathValidator + defer func() { pathValidator = origPathValidator }() + pathValidator = func(path string) error { return nil } + idGenOrig := idGenerator defer func() { idGenerator = idGenOrig }() idGenerator = func() (uuid.UUID, error) { @@ -237,6 +260,9 @@ func TestAddRouteReturnsCreated(t *testing.T) { return model.Route{} } + origPathValidator := pathValidator + defer func() { pathValidator = origPathValidator }() + pathValidator = func(path string) error { return nil } handler.ServeHTTP(resp, req) @@ -259,6 +285,27 @@ func TestAddRouteReturnsCreated(t *testing.T) { } } +func TestAddRoute422sWhenInvalidRoute(t *testing.T) { + reqPayload := `{ + "method": "GET", + "url_pattern": "/he{{o", + "entrypoint": "/bin/sh -c", + "command": "echo Hello World | kapow set /response/body" +}` + req := httptest.NewRequest(http.MethodPost, "/routes", strings.NewReader(reqPayload)) + resp := httptest.NewRecorder() + handler := http.HandlerFunc(addRoute) + origPathValidator := pathValidator + defer func() { pathValidator = origPathValidator }() + pathValidator = func(path string) error { return errors.New("Invalid route") } + + handler.ServeHTTP(resp, req) + + if resp.Code != http.StatusUnprocessableEntity { + t.Error("Invalid route registered") + } +} + func TestRemoveRouteReturnsNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/routes/ROUTE_XXXXXXXXXXXXXXXXXX", nil) resp := httptest.NewRecorder()