diff --git a/core/chaincode/handler.go b/core/chaincode/handler.go index a9204c671eb1d510161f5c9f52043cd5216187ca..b96c30de7f3be3f749cd1ef1fad795bd1d7b6ecd 100644 --- a/core/chaincode/handler.go +++ b/core/chaincode/handler.go @@ -47,12 +47,18 @@ type transactionContext struct { responseNotifier chan *pb.ChaincodeMessage // tracks open iterators used for range queries - queryIteratorMap map[string]commonledger.ResultsIterator + queryIteratorMap map[string]commonledger.ResultsIterator + pendingQueryResults map[string]*pendingQueryResult txsimulator ledger.TxSimulator historyQueryExecutor ledger.HistoryQueryExecutor } +type pendingQueryResult struct { + batch []*pb.QueryResultBytes + count int +} + type nextStateInfo struct { msg *pb.ChaincodeMessage sendToCC bool @@ -181,7 +187,8 @@ func (handler *Handler) createTxContext(ctxt context.Context, chainID string, tx } txctx := &transactionContext{chainID: chainID, signedProp: signedProp, proposal: prop, responseNotifier: make(chan *pb.ChaincodeMessage, 1), - queryIteratorMap: make(map[string]commonledger.ResultsIterator)} + queryIteratorMap: make(map[string]commonledger.ResultsIterator), + pendingQueryResults: make(map[string]*pendingQueryResult)} handler.txCtxs[txCtxID] = txctx txctx.txsimulator = getTxSimulator(ctxt) txctx.historyQueryExecutor = getHistoryQueryExecutor(ctxt) @@ -205,11 +212,12 @@ func (handler *Handler) deleteTxContext(chainID, txid string) { } } -func (handler *Handler) putQueryIterator(txContext *transactionContext, queryID string, +func (handler *Handler) initializeQueryContext(txContext *transactionContext, queryID string, queryIterator commonledger.ResultsIterator) { handler.Lock() defer handler.Unlock() txContext.queryIteratorMap[queryID] = queryIterator + txContext.pendingQueryResults[queryID] = &pendingQueryResult{batch: make([]*pb.QueryResultBytes, 0)} } func (handler *Handler) getQueryIterator(txContext *transactionContext, queryID string) commonledger.ResultsIterator { @@ -218,10 +226,12 @@ func (handler *Handler) getQueryIterator(txContext *transactionContext, queryID return txContext.queryIteratorMap[queryID] } -func (handler *Handler) deleteQueryIterator(txContext *transactionContext, queryID string) { +func (handler *Handler) cleanupQueryContext(txContext *transactionContext, queryID string) { handler.Lock() defer handler.Unlock() + txContext.queryIteratorMap[queryID].Close() delete(txContext.queryIteratorMap, queryID) + delete(txContext.pendingQueryResults, queryID) } // Check if the transactor is allow to call this chaincode on this channel @@ -710,8 +720,7 @@ func (handler *Handler) handleGetStateByRange(msg *pb.ChaincodeMessage) { errHandler := func(err error, iter commonledger.ResultsIterator, errFmt string, errArgs ...interface{}) { if iter != nil { - iter.Close() - handler.deleteQueryIterator(txContext, iterID) + handler.cleanupQueryContext(txContext, iterID) } payload := []byte(err.Error()) chaincodeLogger.Errorf(errFmt, errArgs...) @@ -730,7 +739,8 @@ func (handler *Handler) handleGetStateByRange(msg *pb.ChaincodeMessage) { return } - handler.putQueryIterator(txContext, iterID, rangeIter) + handler.initializeQueryContext(txContext, iterID, rangeIter) + var payload *pb.QueryResponse payload, err = getQueryResponse(handler, txContext, rangeIter, iterID) if err != nil { @@ -755,39 +765,52 @@ const maxResultLimit = 100 //getQueryResponse takes an iterator and fetch state to construct QueryResponse func getQueryResponse(handler *Handler, txContext *transactionContext, iter commonledger.ResultsIterator, iterID string) (*pb.QueryResponse, error) { - - var err error - var queryResult commonledger.QueryResult - var queryResultsBytes []*pb.QueryResultBytes - - for i := 0; i < maxResultLimit; i++ { - queryResult, err = iter.Next() - if err != nil { + pendingQueryResults := txContext.pendingQueryResults[iterID] + for { + queryResult, err := iter.Next() + switch { + case err != nil: chaincodeLogger.Errorf("Failed to get query result from iterator") - break - } - if queryResult == nil { - break - } - var resultBytes []byte - resultBytes, err = proto.Marshal(queryResult.(proto.Message)) - if err != nil { - chaincodeLogger.Errorf("Failed to get encode query result as bytes") - break + handler.cleanupQueryContext(txContext, iterID) + return nil, err + case queryResult == nil: + // nil response from iterator indicates end of query results + batch := pendingQueryResults.cut() + handler.cleanupQueryContext(txContext, iterID) + return &pb.QueryResponse{Results: batch, HasMore: false, Id: iterID}, nil + case pendingQueryResults.count == maxResultLimit: + // max number of results queued up, cut batch, then add current result to pending batch + batch := pendingQueryResults.cut() + if err := pendingQueryResults.add(queryResult); err != nil { + handler.cleanupQueryContext(txContext, iterID) + return nil, err + } + return &pb.QueryResponse{Results: batch, HasMore: true, Id: iterID}, nil + default: + if err := pendingQueryResults.add(queryResult); err != nil { + handler.cleanupQueryContext(txContext, iterID) + return nil, err + } } - - qresultBytes := pb.QueryResultBytes{ResultBytes: resultBytes} - queryResultsBytes = append(queryResultsBytes, &qresultBytes) } +} - if queryResult == nil || err != nil { - iter.Close() - handler.deleteQueryIterator(txContext, iterID) - if err != nil { - return nil, err - } +func (p *pendingQueryResult) cut() []*pb.QueryResultBytes { + batch := p.batch + p.batch = nil + p.count = 0 + return batch +} + +func (p *pendingQueryResult) add(queryResult commonledger.QueryResult) error { + queryResultBytes, err := proto.Marshal(queryResult.(proto.Message)) + if err != nil { + chaincodeLogger.Errorf("Failed to get encode query result as bytes") + return err } - return &pb.QueryResponse{Results: queryResultsBytes, HasMore: queryResult != nil, Id: iterID}, nil + p.batch = append(p.batch, &pb.QueryResultBytes{ResultBytes: queryResultBytes}) + p.count = len(p.batch) + return nil } // afterQueryStateNext handles a QUERY_STATE_NEXT request from the chaincode. @@ -831,8 +854,7 @@ func (handler *Handler) handleQueryStateNext(msg *pb.ChaincodeMessage) { errHandler := func(payload []byte, iter commonledger.ResultsIterator, errFmt string, errArgs ...interface{}) { if iter != nil { - iter.Close() - handler.deleteQueryIterator(txContext, queryStateNext.Id) + handler.cleanupQueryContext(txContext, queryStateNext.Id) } chaincodeLogger.Errorf(errFmt, errArgs...) serialSendMsg = &pb.ChaincodeMessage{Type: pb.ChaincodeMessage_ERROR, Payload: payload, Txid: msg.Txid, ChannelId: msg.ChannelId} @@ -931,8 +953,7 @@ func (handler *Handler) handleQueryStateClose(msg *pb.ChaincodeMessage) { iter := handler.getQueryIterator(txContext, queryStateClose.Id) if iter != nil { - iter.Close() - handler.deleteQueryIterator(txContext, queryStateClose.Id) + handler.cleanupQueryContext(txContext, queryStateClose.Id) } payload := &pb.QueryResponse{HasMore: false, Id: queryStateClose.Id} @@ -989,8 +1010,7 @@ func (handler *Handler) handleGetQueryResult(msg *pb.ChaincodeMessage) { errHandler := func(payload []byte, iter commonledger.ResultsIterator, errFmt string, errArgs ...interface{}) { if iter != nil { - iter.Close() - handler.deleteQueryIterator(txContext, iterID) + handler.cleanupQueryContext(txContext, iterID) } chaincodeLogger.Errorf(errFmt, errArgs...) serialSendMsg = &pb.ChaincodeMessage{Type: pb.ChaincodeMessage_ERROR, Payload: payload, Txid: msg.Txid, ChannelId: msg.ChannelId} @@ -1025,7 +1045,8 @@ func (handler *Handler) handleGetQueryResult(msg *pb.ChaincodeMessage) { return } - handler.putQueryIterator(txContext, iterID, executeIter) + handler.initializeQueryContext(txContext, iterID, executeIter) + var payload *pb.QueryResponse payload, err = getQueryResponse(handler, txContext, executeIter, iterID) if err != nil { @@ -1087,8 +1108,7 @@ func (handler *Handler) handleGetHistoryForKey(msg *pb.ChaincodeMessage) { errHandler := func(payload []byte, iter commonledger.ResultsIterator, errFmt string, errArgs ...interface{}) { if iter != nil { - iter.Close() - handler.deleteQueryIterator(txContext, iterID) + handler.cleanupQueryContext(txContext, iterID) } chaincodeLogger.Errorf(errFmt, errArgs...) serialSendMsg = &pb.ChaincodeMessage{Type: pb.ChaincodeMessage_ERROR, Payload: payload, Txid: msg.Txid, ChannelId: msg.ChannelId} @@ -1115,7 +1135,7 @@ func (handler *Handler) handleGetHistoryForKey(msg *pb.ChaincodeMessage) { return } - handler.putQueryIterator(txContext, iterID, historyIter) + handler.initializeQueryContext(txContext, iterID, historyIter) var payload *pb.QueryResponse payload, err = getQueryResponse(handler, txContext, historyIter, iterID) diff --git a/core/chaincode/handler_test.go b/core/chaincode/handler_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a358e9614b3e9fc18abbef40eb58c8328e18d340 --- /dev/null +++ b/core/chaincode/handler_test.go @@ -0,0 +1,109 @@ +/* +Copyright IBM Corp. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package chaincode + +import ( + "fmt" + "math" + "testing" + + "github.com/hyperledger/fabric/common/ledger" + "github.com/hyperledger/fabric/protos/ledger/queryresult" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetQueryResponse(t *testing.T) { + + queryResult := &queryresult.KV{ + Key: "key", + Namespace: "namespace", + Value: []byte("value"), + } + + // test various boundry cases around maxResultLimit + testCases := []struct { + expectedResultCount int + expectedHasMoreCount int + }{ + {0, 0}, + {1, 0}, + {10, 0}, + {maxResultLimit - 2, 0}, + {maxResultLimit - 1, 0}, + {maxResultLimit, 0}, + {maxResultLimit + 1, 1}, + {maxResultLimit + 2, 1}, + {int(math.Floor(maxResultLimit * 1.5)), 1}, + {maxResultLimit * 2, 1}, + {10*maxResultLimit - 2, 9}, + {10*maxResultLimit - 1, 9}, + {10 * maxResultLimit, 9}, + {10*maxResultLimit + 1, 10}, + {10*maxResultLimit + 2, 10}, + } + + for _, tc := range testCases { + handler := &Handler{} + transactionContext := &transactionContext{ + queryIteratorMap: make(map[string]ledger.ResultsIterator), + pendingQueryResults: make(map[string]*pendingQueryResult), + } + queryID := "test" + t.Run(fmt.Sprintf("%d", tc.expectedResultCount), func(t *testing.T) { + resultsIterator := &MockResultsIterator{} + handler.initializeQueryContext(transactionContext, queryID, resultsIterator) + if tc.expectedResultCount > 0 { + resultsIterator.On("Next").Return(queryResult, nil).Times(tc.expectedResultCount) + } + resultsIterator.On("Next").Return(nil, nil).Once() + resultsIterator.On("Close").Return().Once() + totalResultCount := 0 + for hasMoreCount := 0; hasMoreCount <= tc.expectedHasMoreCount; hasMoreCount++ { + queryResponse, _ := getQueryResponse(handler, transactionContext, resultsIterator, queryID) + assert.NotNil(t, queryResponse.GetResults()) + if queryResponse.GetHasMore() { + t.Logf("Got %d results and more are expected.", len(queryResponse.GetResults())) + } else { + t.Logf("Got %d results and no more are expected.", len(queryResponse.GetResults())) + } + + switch { + case hasMoreCount < tc.expectedHasMoreCount: + // max limit sized batch retrieved, more expected + assert.True(t, queryResponse.GetHasMore()) + assert.Len(t, queryResponse.GetResults(), maxResultLimit) + default: + // remainder retrieved, no more expected + assert.Len(t, queryResponse.GetResults(), tc.expectedResultCount-totalResultCount) + assert.False(t, queryResponse.GetHasMore()) + + } + + totalResultCount += len(queryResponse.GetResults()) + } + resultsIterator.AssertExpectations(t) + }) + } + +} + +type MockResultsIterator struct { + mock.Mock +} + +func (m *MockResultsIterator) Next() (ledger.QueryResult, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(ledger.QueryResult), args.Error(1) +} + +func (m *MockResultsIterator) Close() { + m.Called() +}