From 520a6b05a18cf8add215ecca1d66dfead82b15ea Mon Sep 17 00:00:00 2001
From: Alex Ellis <alexellis2@gmail.com>
Date: Fri, 20 Oct 2017 19:01:45 +0200
Subject: [PATCH] Commit unit tests for content-type order

Signed-off-by: Alex Ellis <alexellis2@gmail.com>
---
 gateway/handlers/proxy.go   | 24 +++++++++++--------
 gateway/tests/proxy_test.go | 46 +++++++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 9 deletions(-)
 create mode 100644 gateway/tests/proxy_test.go

diff --git a/gateway/handlers/proxy.go b/gateway/handlers/proxy.go
index 78c1069b..ee968c75 100644
--- a/gateway/handlers/proxy.go
+++ b/gateway/handlers/proxy.go
@@ -168,23 +168,29 @@ func invokeService(w http.ResponseWriter, r *http.Request, metrics metrics.Metri
 	clientHeader := w.Header()
 	copyHeaders(&clientHeader, &response.Header)
 
-	responseHeader := response.Header.Get("Content-Type")
-	requestHeader := r.Header.Get("Content-Type")
 	defaultHeader := "text/plain"
-	contentTypeField := "Content-Type"
 
-	fmt.Printf("Req %s Res %s\n", requestHeader, responseHeader)
+	w.Header().Set("Content-Type", GetContentType(response.Header, r.Header, defaultHeader))
 
+	writeHead(service, metrics, response.StatusCode, w)
+	w.Write(responseBody)
+}
+
+// GetContentType resolves the correct Content-Tyoe for a proxied function
+func GetContentType(request http.Header, proxyResponse http.Header, defaultValue string) string {
+	responseHeader := proxyResponse.Get("Content-Type")
+	requestHeader := request.Get("Content-Type")
+
+	var headerContentType string
 	if len(responseHeader) > 0 {
-		w.Header().Set(contentTypeField, responseHeader)
+		headerContentType = responseHeader
 	} else if len(requestHeader) > 0 {
-		w.Header().Set(contentTypeField, requestHeader)
+		headerContentType = requestHeader
 	} else {
-		w.Header().Set(contentTypeField, defaultHeader)
+		headerContentType = defaultValue
 	}
 
-	writeHead(service, metrics, response.StatusCode, w)
-	w.Write(responseBody)
+	return headerContentType
 }
 
 func copyHeaders(destination *http.Header, source *http.Header) {
diff --git a/gateway/tests/proxy_test.go b/gateway/tests/proxy_test.go
new file mode 100644
index 00000000..deba790b
--- /dev/null
+++ b/gateway/tests/proxy_test.go
@@ -0,0 +1,46 @@
+package tests
+
+import (
+	"net/http"
+	"testing"
+
+	"github.com/openfaas/faas/gateway/handlers"
+)
+
+func Test_GetContentType_UsesResponseValue(t *testing.T) {
+	request := http.Header{}
+	request.Add("Content-Type", "text/plain")
+	response := http.Header{}
+	response.Add("Content-Type", "text/html")
+
+	contentType := handlers.GetContentType(request, response, "default")
+	if contentType != response.Get("Content-Type") {
+		t.Errorf("Got: %s, want: %s", contentType, response.Get("Content-Type"))
+	}
+}
+
+func Test_GetContentType_UsesRequest_WhenResponseEmpty(t *testing.T) {
+	request := http.Header{}
+	request.Add("Content-Type", "text/plain")
+	response := http.Header{}
+	response.Add("Content-Type", "")
+
+	contentType := handlers.GetContentType(request, response, "default")
+	if contentType != request.Get("Content-Type") {
+		t.Errorf("Got: %s, want: %s", contentType, request.Get("Content-Type"))
+	}
+
+}
+
+func Test_GetContentType_UsesDefaultWhenRequestResponseEmpty(t *testing.T) {
+	request := http.Header{}
+	request.Add("Content-Type", "")
+	response := http.Header{}
+	response.Add("Content-Type", "")
+
+	contentType := handlers.GetContentType(request, response, "default")
+	if contentType != "default" {
+		t.Errorf("Got: %s, want: %s", contentType, "default")
+	}
+
+}
-- 
GitLab