Commit 5be77553 authored by Jason Yellick's avatar Jason Yellick Committed by Gerrit Code Review
Browse files

Merge "Allow statically configured root CAs for TLS" into release-1.4

parents 160a228c 86f1c990
......@@ -59,30 +59,6 @@ func GetCredentialSupport() *CredentialSupport {
return credSupport
}
// GetServerRootCAs returns the PEM-encoded root certificates for all of the
// application and orderer organizations defined for all chains. The root
// certificates returned should be used to set the trusted server roots for
// TLS clients.
func (cas *CASupport) GetServerRootCAs() (appRootCAs, ordererRootCAs [][]byte) {
cas.RLock()
defer cas.RUnlock()
appRootCAs = [][]byte{}
ordererRootCAs = [][]byte{}
for _, appRootCA := range cas.AppRootCAsByChain {
appRootCAs = append(appRootCAs, appRootCA...)
}
for _, ordererRootCA := range cas.OrdererRootCAsByChain {
ordererRootCAs = append(ordererRootCAs, ordererRootCA...)
}
// also need to append statically configured root certs
appRootCAs = append(appRootCAs, cas.ServerRootCAs...)
return appRootCAs, ordererRootCAs
}
// GetClientRootCAs returns the PEM-encoded root certificates for all of the
// application and orderer organizations defined for all chains. The root
// certificates returned should be used to set the trusted client roots for
......@@ -118,10 +94,14 @@ func (cs *CredentialSupport) GetClientCertificate() tls.Certificate {
return cs.clientCert
}
// GetDeliverServiceCredentials returns GRPC transport credentials for given channel to be used by GRPC
// clients which communicate with ordering service endpoints.
// If the channel isn't found, error is returned.
func (cs *CredentialSupport) GetDeliverServiceCredentials(channelID string) (credentials.TransportCredentials, error) {
// GetDeliverServiceCredentials returns gRPC transport credentials for given channel
// to be used by gRPC clients which communicate with ordering service endpoints.
// If appendStaticRoots is set to true, ServerRootCAs are also included in the
// credentials. If the channel isn't found, an error is returned.
func (cs *CredentialSupport) GetDeliverServiceCredentials(
channelID string,
appendStaticRoots bool,
) (credentials.TransportCredentials, error) {
cs.RLock()
defer cs.RUnlock()
......@@ -144,36 +124,58 @@ func (cs *CredentialSupport) GetDeliverServiceCredentials(channelID string) (cre
if err == nil {
certPool.AddCert(cert)
} else {
commLogger.Warningf("Failed to add root cert to credentials (%s)", err)
commLogger.Warningf("Failed to add root cert to credentials: %s", err)
}
} else {
commLogger.Warning("Failed to add root cert to credentials")
}
}
if appendStaticRoots {
for _, cert := range cs.ServerRootCAs {
block, _ := pem.Decode(cert)
if block != nil {
cert, err := x509.ParseCertificate(block.Bytes)
if err == nil {
certPool.AddCert(cert)
} else {
commLogger.Warningf("Failed to add root cert to credentials: %s", err)
}
} else {
commLogger.Warning("Failed to add root cert to credentials")
}
}
}
tlsConfig.RootCAs = certPool
creds = credentials.NewTLS(tlsConfig)
return creds, nil
}
// GetPeerCredentials returns GRPC transport credentials for use by GRPC
// GetPeerCredentials returns gRPC transport credentials for use by gRPC
// clients which communicate with remote peer endpoints.
func (cs *CredentialSupport) GetPeerCredentials() credentials.TransportCredentials {
var creds credentials.TransportCredentials
cs.RLock()
defer cs.RUnlock()
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cs.clientCert},
}
certPool := x509.NewCertPool()
// loop through the server root CAs
roots, _ := cs.GetServerRootCAs()
for _, root := range roots {
err := AddPemToCertPool(root, certPool)
appRootCAs := [][]byte{}
for _, appRootCA := range cs.AppRootCAsByChain {
appRootCAs = append(appRootCAs, appRootCA...)
}
// also need to append statically configured root certs
appRootCAs = append(appRootCAs, cs.ServerRootCAs...)
// loop through the app root CAs
for _, appRootCA := range appRootCAs {
err := AddPemToCertPool(appRootCA, certPool)
if err != nil {
commLogger.Warningf("Failed adding certificates to peer's client TLS trust pool: %s", err)
}
}
tlsConfig.RootCAs = certPool
creds = credentials.NewTLS(tlsConfig)
return creds
return credentials.NewTLS(tlsConfig)
}
func getEnv(key, def string) string {
......
......@@ -173,12 +173,6 @@ func TestCASupport(t *testing.T) {
cas.ServerRootCAs = [][]byte{rootCAs[5]}
cas.ClientRootCAs = [][]byte{rootCAs[5]}
appServerRoots, ordererServerRoots := cas.GetServerRootCAs()
t.Logf("%d appServerRoots | %d ordererServerRoots", len(appServerRoots),
len(ordererServerRoots))
assert.Equal(t, 4, len(appServerRoots), "Expected 4 app server root CAs")
assert.Equal(t, 2, len(ordererServerRoots), "Expected 2 orderer server root CAs")
appClientRoots, ordererClientRoots := cas.GetClientRootCAs()
t.Logf("%d appClientRoots | %d ordererClientRoots", len(appClientRoots),
len(ordererClientRoots))
......@@ -213,19 +207,13 @@ func TestCredentialSupport(t *testing.T) {
cs.ServerRootCAs = [][]byte{rootCAs[5]}
cs.ClientRootCAs = [][]byte{rootCAs[5]}
appServerRoots, ordererServerRoots := cs.GetServerRootCAs()
t.Logf("%d appServerRoots | %d ordererServerRoots", len(appServerRoots),
len(ordererServerRoots))
assert.Equal(t, 4, len(appServerRoots), "Expected 4 app server root CAs")
assert.Equal(t, 2, len(ordererServerRoots), "Expected 2 orderer server root CAs")
appClientRoots, ordererClientRoots := cs.GetClientRootCAs()
t.Logf("%d appClientRoots | %d ordererClientRoots", len(appClientRoots),
len(ordererClientRoots))
assert.Equal(t, 4, len(appClientRoots), "Expected 4 app client root CAs")
assert.Equal(t, 2, len(ordererClientRoots), "Expected 4 orderer client root CAs")
creds, _ := cs.GetDeliverServiceCredentials("channel1")
creds, _ := cs.GetDeliverServiceCredentials("channel1", false)
assert.Equal(t, "1.2", creds.Info().SecurityVersion,
"Expected Security version to be 1.2")
creds = cs.GetPeerCredentials()
......@@ -235,7 +223,7 @@ func TestCredentialSupport(t *testing.T) {
// append some bad certs and make sure things still work
cs.ServerRootCAs = append(cs.ServerRootCAs, []byte("badcert"))
cs.ServerRootCAs = append(cs.ServerRootCAs, []byte(badPEM))
creds, _ = cs.GetDeliverServiceCredentials("channel1")
creds, _ = cs.GetDeliverServiceCredentials("channel1", false)
assert.Equal(t, "1.2", creds.Info().SecurityVersion,
"Expected Security version to be 1.2")
creds = cs.GetPeerCredentials()
......@@ -330,16 +318,18 @@ func TestImpersonation(t *testing.T) {
OrdererRootCAsByChain: make(map[string][][]byte),
},
}
_, err := cs.GetDeliverServiceCredentials("C")
_, err := cs.GetDeliverServiceCredentials("C", false)
assert.Error(t, err)
cs.OrdererRootCAsByChain["A"] = [][]byte{osA.caCert}
cs.OrdererRootCAsByChain["B"] = [][]byte{osB.caCert}
cs.ServerRootCAs = append(cs.ServerRootCAs, osB.caCert)
testInvoke(t, "A", osA, cs, true)
testInvoke(t, "B", osB, cs, true)
testInvoke(t, "A", osB, cs, false)
testInvoke(t, "B", osA, cs, false)
testInvoke(t, "A", osA, cs, false, true)
testInvoke(t, "B", osB, cs, false, true)
testInvoke(t, "A", osB, cs, false, false)
testInvoke(t, "B", osA, cs, false, false)
testInvoke(t, "B", osA, cs, true, false)
}
......@@ -348,9 +338,10 @@ func testInvoke(
channelID string,
s *srv,
cs *CredentialSupport,
staticRoots bool,
shouldSucceed bool) {
creds, err := cs.GetDeliverServiceCredentials(channelID)
creds, err := cs.GetDeliverServiceCredentials(channelID, staticRoots)
assert.NoError(t, err)
endpoint := s.address
......
......@@ -44,6 +44,10 @@ func getReConnectBackoffThreshold() float64 {
return util.GetFloat64OrDefault("peer.deliveryclient.reConnectBackoffThreshold", defaultReConnectBackoffThreshold)
}
func staticRootsEnabled() bool {
return viper.GetBool("peer.deliveryclient.staticRootsEnabled")
}
// DeliverService used to communicate with orderers to obtain
// new blocks and send them to the committer service
type DeliverService interface {
......@@ -251,7 +255,7 @@ func DefaultConnectionFactory(channelID string) func(endpoint string) (*grpc.Cli
dialOpts = append(dialOpts, comm.ClientKeepaliveOptions(kaOpts)...)
if viper.GetBool("peer.tls.enabled") {
creds, err := comm.GetCredentialSupport().GetDeliverServiceCredentials(channelID)
creds, err := comm.GetCredentialSupport().GetDeliverServiceCredentials(channelID, staticRootsEnabled())
if err != nil {
return nil, fmt.Errorf("failed obtaining credentials for channel %s: %v", channelID, err)
}
......
......@@ -174,6 +174,30 @@ func GetServerConfig() (comm.ServerConfig, error) {
return serverConfig, nil
}
// GetServerRootCAs returns the root certificates which will be trusted for
// gRPC client connections to peers and orderers.
func GetServerRootCAs() ([][]byte, error) {
var rootCAs [][]byte
if config.GetPath("peer.tls.rootcert.file") != "" {
rootCert, err := ioutil.ReadFile(config.GetPath("peer.tls.rootcert.file"))
if err != nil {
return nil, fmt.Errorf("error loading TLS root certificate (%s)", err)
}
rootCAs = append(rootCAs, rootCert)
}
for _, file := range viper.GetStringSlice("peer.tls.serverRootCAs.files") {
rootCert, err := ioutil.ReadFile(
config.TranslatePath(filepath.Dir(viper.ConfigFileUsed()), file))
if err != nil {
return nil,
fmt.Errorf("error loading server root CAs: %s", err)
}
rootCAs = append(rootCAs, rootCert)
}
return rootCAs, nil
}
// GetClientCertificate returns the TLS certificate to use for gRPC client
// connections
func GetClientCertificate() (tls.Certificate, error) {
......
......@@ -178,6 +178,74 @@ func TestGetServerConfig(t *testing.T) {
}
func TestGetServerRootCAs(t *testing.T) {
var tests = []struct {
name string
rootCert string
serverRootCAs []string
count int
shouldFail bool
}{
{
name: "no roots",
rootCert: "",
serverRootCAs: []string{},
count: 0,
},
{
name: "rootCert only",
rootCert: filepath.Join("testdata", "Org1-cert.pem"),
serverRootCAs: []string{},
count: 1,
},
{
name: "serverRootCAs only",
rootCert: "",
serverRootCAs: []string{
filepath.Join("testdata", "Org2-cert.pem"),
filepath.Join("testdata", "Org3-cert.pem"),
},
count: 2,
},
{
name: "rootCert and serverRootCAs",
rootCert: filepath.Join("testdata", "Org1-cert.pem"),
serverRootCAs: []string{
filepath.Join("testdata", "Org2-cert.pem"),
filepath.Join("testdata", "Org3-cert.pem"),
},
count: 3,
},
{
name: "bad rootCert",
rootCert: filepath.Join("testdata", "Org11-cert.pem"),
serverRootCAs: []string{},
shouldFail: true,
},
{
name: "bad serverRootCAs",
rootCert: "",
serverRootCAs: []string{filepath.Join("testdata", "Org11-cert.pem")},
shouldFail: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
viper.Set("peer.tls.rootcert.file", test.rootCert)
viper.Set("peer.tls.serverRootCAs.files", test.serverRootCAs)
roots, err := GetServerRootCAs()
if test.shouldFail {
assert.Error(t, err, "Expected an error")
} else {
assert.NoError(t, err, "Error should not have occurred")
assert.Equal(t, test.count, len(roots))
}
})
}
}
func TestGetClientCertificate(t *testing.T) {
viper.Set("peer.tls.key.file", "")
viper.Set("peer.tls.cert.file", "")
......
......@@ -239,12 +239,16 @@ func serve(args []string) error {
logger.Info("Starting peer with TLS enabled")
// set up credential support
cs := comm.GetCredentialSupport()
cs.ServerRootCAs = serverConfig.SecOpts.ServerRootCAs
roots, err := peer.GetServerRootCAs()
if err != nil {
logger.Fatalf("Failed to set TLS server root CAs: %s", err)
}
cs.ServerRootCAs = roots
// set the cert to use if client auth is requested by remote endpoints
clientCert, err := peer.GetClientCertificate()
if err != nil {
logger.Fatalf("Failed to set TLS client certificate (%s)", err)
logger.Fatalf("Failed to set TLS client certificate: %s", err)
}
comm.GetCredentialSupport().SetClientCertificate(clientCert)
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment