diff --git a/gateway/handlers/function_cache.go b/gateway/handlers/function_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..1128f8172681c673459d533f039c9c689b29801a --- /dev/null +++ b/gateway/handlers/function_cache.go @@ -0,0 +1,59 @@ +// Copyright (c) OpenFaaS Project. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +package handlers + +import ( + "sync" + "time" +) + +// FunctionMeta holds the last refresh and any other +// meta-data needed for caching. +type FunctionMeta struct { + LastRefresh time.Time + Replicas uint64 +} + +// Expired find out whether the cache item has expired with +// the given expiry duration from when it was stored. +func (fm *FunctionMeta) Expired(expiry time.Duration) bool { + return time.Now().After(fm.LastRefresh.Add(expiry)) +} + +// FunctionCache provides a cache of Function replica counts +type FunctionCache struct { + Cache map[string]*FunctionMeta + Expiry time.Duration + Sync sync.Mutex +} + +// Set replica count for functionName +func (fc *FunctionCache) Set(functionName string, replicas uint64) { + fc.Sync.Lock() + + if _, exists := fc.Cache[functionName]; !exists { + fc.Cache[functionName] = &FunctionMeta{} + } + + entry := fc.Cache[functionName] + entry.LastRefresh = time.Now() + entry.Replicas = replicas + + fc.Sync.Unlock() +} + +// Get replica count for functionName +func (fc *FunctionCache) Get(functionName string) (uint64, bool) { + replicas := uint64(0) + hit := false + fc.Sync.Lock() + + if val, exists := fc.Cache[functionName]; exists { + replicas = val.Replicas + hit = !val.Expired(fc.Expiry) + } + + fc.Sync.Unlock() + return replicas, hit +} diff --git a/gateway/handlers/function_cache_test.go b/gateway/handlers/function_cache_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c7100c3f1f4b7a78e9d9e8bf11e3d068189dd6fe --- /dev/null +++ b/gateway/handlers/function_cache_test.go @@ -0,0 +1,72 @@ +package handlers + +import ( + "testing" + "time" +) + +func Test_LastRefreshSet(t *testing.T) { + before := time.Now() + + fnName := "echo" + + cache := FunctionCache{ + Cache: make(map[string]*FunctionMeta), + Expiry: time.Millisecond * 1, + } + + if cache.Cache == nil { + t.Errorf("Expected cache map to be initialized") + t.Fail() + } + + cache.Set(fnName, 1) + + if _, exists := cache.Cache[fnName]; !exists { + t.Errorf("Expected entry to exist after setting %s", fnName) + t.Fail() + } + + if cache.Cache[fnName].LastRefresh.Before(before) { + t.Errorf("Expected LastRefresh for function to have been after start of test") + t.Fail() + } +} + +func Test_CacheExpiresIn1MS(t *testing.T) { + fnName := "echo" + + cache := FunctionCache{ + Cache: make(map[string]*FunctionMeta), + Expiry: time.Millisecond * 1, + } + + cache.Set(fnName, 1) + time.Sleep(time.Millisecond * 2) + + _, hit := cache.Get(fnName) + + wantHit := false + + if hit != wantHit { + t.Errorf("hit, want: %v, got %v", wantHit, hit) + } +} + +func Test_CacheGivesHitWithLongExpiry(t *testing.T) { + fnName := "echo" + + cache := FunctionCache{ + Cache: make(map[string]*FunctionMeta), + Expiry: time.Millisecond * 500, + } + + cache.Set(fnName, 1) + + _, hit := cache.Get(fnName) + wantHit := true + + if hit != wantHit { + t.Errorf("hit, want: %v, got %v", wantHit, hit) + } +} diff --git a/gateway/handlers/scaling.go b/gateway/handlers/scaling.go new file mode 100644 index 0000000000000000000000000000000000000000..adb29f54ced36ef8f48ebcdb4164139070e07678 --- /dev/null +++ b/gateway/handlers/scaling.go @@ -0,0 +1,147 @@ +// Copyright (c) OpenFaaS Project. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +package handlers + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "time" + + "github.com/openfaas/faas/gateway/requests" +) + +// ScalingConfig for scaling behaviours +type ScalingConfig struct { + MaxPollCount uint + FunctionPollInterval time.Duration + CacheExpiry time.Duration +} + +// MakeScalingHandler creates handler which can scale a function from +// zero to 1 replica(s). +func MakeScalingHandler(next http.HandlerFunc, upstream http.HandlerFunc, config ScalingConfig) http.HandlerFunc { + cache := FunctionCache{ + Cache: make(map[string]*FunctionMeta), + Expiry: config.CacheExpiry, + } + + return func(w http.ResponseWriter, r *http.Request) { + + functionName := getServiceName(r.URL.String()) + + if replicas, hit := cache.Get(functionName); hit && replicas > 0 { + next.ServeHTTP(w, r) + return + } + + replicas, code, err := getReplicas(functionName, upstream) + cache.Set(functionName, replicas) + + if err != nil { + var errStr string + if code == http.StatusNotFound { + errStr = fmt.Sprintf("unable to find function: %s", functionName) + + } else { + errStr = fmt.Sprintf("error finding function %s: %s", functionName, err.Error()) + } + + log.Printf(errStr) + w.WriteHeader(code) + w.Write([]byte(errStr)) + return + } + + if replicas == 0 { + minReplicas := uint64(1) + + err := scaleFunction(functionName, minReplicas, upstream) + if err != nil { + errStr := fmt.Errorf("unable to scale function [%s], err: %s", functionName, err) + log.Printf(errStr.Error()) + + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errStr.Error())) + return + } + + for i := 0; i < int(config.MaxPollCount); i++ { + replicas, _, err := getReplicas(functionName, upstream) + cache.Set(functionName, replicas) + + if err != nil { + errStr := fmt.Sprintf("error: %s", err.Error()) + log.Printf(errStr) + + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errStr)) + return + } + + if replicas > 0 { + break + } + + time.Sleep(config.FunctionPollInterval) + } + } + + next.ServeHTTP(w, r) + } +} + +func getReplicas(functionName string, upstream http.HandlerFunc) (uint64, int, error) { + + replicasQuery, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("/system/function/%s", functionName), nil) + rr := httptest.NewRecorder() + + upstream.ServeHTTP(rr, replicasQuery) + if rr.Code != 200 { + log.Printf("error, query replicas status: %d", rr.Code) + + var errBody string + if rr.Body != nil { + errBody = string(rr.Body.String()) + } + + return 0, rr.Code, fmt.Errorf("unable to query function: %s", string(errBody)) + } + + replicaBytes, _ := ioutil.ReadAll(rr.Body) + replicaResult := requests.Function{} + json.Unmarshal(replicaBytes, &replicaResult) + + return replicaResult.AvailableReplicas, rr.Code, nil +} + +func scaleFunction(functionName string, minReplicas uint64, upstream http.HandlerFunc) error { + scaleReq := ScaleServiceRequest{ + Replicas: minReplicas, + ServiceName: functionName, + } + + scaleBytesOut, _ := json.Marshal(scaleReq) + scaleBytesOutBody := bytes.NewBuffer(scaleBytesOut) + setReplicasReq, _ := http.NewRequest(http.MethodPost, fmt.Sprintf("/system/scale-function/%s", functionName), scaleBytesOutBody) + + rr := httptest.NewRecorder() + upstream.ServeHTTP(rr, setReplicasReq) + + if rr.Code != 200 { + return fmt.Errorf("scale to 1 replica status: %d", rr.Code) + } + + return nil +} + +// ScaleServiceRequest request to scale a function +type ScaleServiceRequest struct { + ServiceName string `json:"serviceName"` + Replicas uint64 `json:"replicas"` +} diff --git a/gateway/server.go b/gateway/server.go index 99707142a4572a3a03452d5d48a68f85ce6c7a08..62a6a0e29680564b2835a155718e2f5d40791d5e 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -113,10 +113,18 @@ func main() { } r := mux.NewRouter() + // max wait time to start a function = maxPollCount * functionPollInterval + scalingConfig := handlers.ScalingConfig{ + MaxPollCount: uint(1000), + FunctionPollInterval: time.Millisecond * 10, + CacheExpiry: time.Second * 5, // freshness of replica values before going stale + } + + scalingProxy := handlers.MakeScalingHandler(faasHandlers.Proxy, queryFunction, scalingConfig) // r.StrictSlash(false) // This didn't work, so register routes twice. - r.HandleFunc("/function/{name:[-a-zA-Z_0-9]+}", faasHandlers.Proxy) - r.HandleFunc("/function/{name:[-a-zA-Z_0-9]+}/", faasHandlers.Proxy) + r.HandleFunc("/function/{name:[-a-zA-Z_0-9]+}", scalingProxy) + r.HandleFunc("/function/{name:[-a-zA-Z_0-9]+}/", scalingProxy) r.HandleFunc("/system/info", handlers.MakeInfoHandler(handlers.MakeForwardingProxyHandler( reverseProxy, forwardingNotifiers, urlResolver))).Methods(http.MethodGet)