Commit 42891b02 authored by Matthew Sykes's avatar Matthew Sykes
Browse files

[FAB-12861] fix client cert middleware logic



Client certificate is required for logging and metrics when TLS enabled.
Client certificate is not required for healthz when TLS enabled.
Client certificate always required when RequireClientCert is true

Change-Id: I462647f0907efc2aaae9f1e9e6d2d3207da3ee57
Signed-off-by: default avatarMatthew Sykes <sykesmat@us.ibm.com>
parent d3360299
......@@ -49,24 +49,27 @@ func generateCertificates(tempDir string) {
Expect(err).NotTo(HaveOccurred())
}
func newHTTPClient(tlsDir string) *http.Client {
clientCert, err := tls.LoadX509KeyPair(
filepath.Join(tlsDir, "client-cert.pem"),
filepath.Join(tlsDir, "client-key.pem"),
)
Expect(err).NotTo(HaveOccurred())
func newHTTPClient(tlsDir string, withClientCert bool) *http.Client {
clientCertPool := x509.NewCertPool()
caCert, err := ioutil.ReadFile(filepath.Join(tlsDir, "server-ca.pem"))
Expect(err).NotTo(HaveOccurred())
clientCertPool.AppendCertsFromPEM(caCert)
tlsClientConfig := &tls.Config{
RootCAs: clientCertPool,
}
if withClientCert {
clientCert, err := tls.LoadX509KeyPair(
filepath.Join(tlsDir, "client-cert.pem"),
filepath.Join(tlsDir, "client-key.pem"),
)
Expect(err).NotTo(HaveOccurred())
tlsClientConfig.Certificates = []tls.Certificate{clientCert}
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{clientCert},
RootCAs: clientCertPool,
},
TLSClientConfig: tlsClientConfig,
},
}
}
......
......@@ -176,8 +176,7 @@ func (s *System) initializeMetricsProvider() error {
case "prometheus":
s.Provider = &prometheus.Provider{}
secure := s.options.TLS.Enabled && s.options.TLS.ClientCertRequired
s.mux.Handle(m.Prometheus.HandlerPath, s.handlerChain(prom.Handler(), secure))
s.mux.Handle(m.Prometheus.HandlerPath, s.handlerChain(prom.Handler(), s.options.TLS.Enabled))
return nil
default:
......@@ -191,8 +190,7 @@ func (s *System) initializeMetricsProvider() error {
}
func (s *System) initializeLoggingHandler() {
secure := s.options.TLS.Enabled && s.options.TLS.ClientCertRequired
s.mux.Handle("/logspec", s.handlerChain(httpadmin.NewSpecHandler(), secure))
s.mux.Handle("/logspec", s.handlerChain(httpadmin.NewSpecHandler(), s.options.TLS.Enabled))
}
func (s *System) initializeHealthCheckHandler() {
......
......@@ -36,9 +36,10 @@ var _ = Describe("System", func() {
fakeLogger *fakes.Logger
tempDir string
client *http.Client
options operations.Options
system *operations.System
client *http.Client
unauthClient *http.Client
options operations.Options
system *operations.System
)
BeforeEach(func() {
......@@ -47,7 +48,8 @@ var _ = Describe("System", func() {
Expect(err).NotTo(HaveOccurred())
generateCertificates(tempDir)
client = newHTTPClient(tempDir)
client = newHTTPClient(tempDir, true)
unauthClient = newHTTPClient(tempDir, false)
fakeLogger = &fakes.Logger{}
options = operations.Options{
......@@ -60,7 +62,7 @@ var _ = Describe("System", func() {
Enabled: true,
CertFile: filepath.Join(tempDir, "server-cert.pem"),
KeyFile: filepath.Join(tempDir, "server-key.pem"),
ClientCertRequired: true,
ClientCertRequired: false,
ClientCACertFiles: []string{filepath.Join(tempDir, "client-ca.pem")},
},
}
......@@ -79,10 +81,15 @@ var _ = Describe("System", func() {
err := system.Start()
Expect(err).NotTo(HaveOccurred())
resp, err := client.Get(fmt.Sprintf("https://%s/logspec", system.Addr()))
logspecURL := fmt.Sprintf("https://%s/logspec", system.Addr())
resp, err := client.Get(logspecURL)
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
resp.Body.Close()
resp, err = unauthClient.Get(logspecURL)
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
Context("when TLS is disabled", func() {
......@@ -102,6 +109,22 @@ var _ = Describe("System", func() {
})
})
Context("when ClientCertRequired is true", func() {
BeforeEach(func() {
options.TLS.ClientCertRequired = true
system = operations.NewSystem(options)
})
It("requires a client cert to connect", func() {
err := system.Start()
Expect(err).NotTo(HaveOccurred())
_, err = unauthClient.Get(fmt.Sprintf("https://%s/healthz", system.Addr()))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("remote error: tls: bad certificate"))
})
})
Context("when listen fails", func() {
var listener net.Listener
......@@ -223,13 +246,18 @@ var _ = Describe("System", func() {
err := system.Start()
Expect(err).NotTo(HaveOccurred())
resp, err := client.Get(fmt.Sprintf("https://%s/metrics", system.Addr()))
metricsURL := fmt.Sprintf("https://%s/metrics", system.Addr())
resp, err := client.Get(metricsURL)
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
Expect(body).To(ContainSubstring("# TYPE go_gc_duration_seconds summary"))
resp, err = unauthClient.Get(metricsURL)
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
})
})
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment