Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: aggregating proofs #191

Merged
merged 15 commits into from
Nov 15, 2024
61 changes: 52 additions & 9 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ func (a *Aggregator) tryAggregateProofs(ctx context.Context, prover ProverInterf
tmpLogger.Infof("Proof ID for aggregated proof: %v", *proof.ProofID)
tmpLogger = tmpLogger.WithFields("proofId", *proof.ProofID)

recursiveProof, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
recursiveProof, _, _, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get aggregated proof from prover, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
Expand Down Expand Up @@ -1121,7 +1121,7 @@ func (a *Aggregator) getAndLockBatchToProve(
// Not found, so it it not possible to verify the batch yet
if sequence == nil || errors.Is(err, entities.ErrNotFound) {
tmpLogger.Infof("Sequencing event for batch %d has not been synced yet, "+
"so it is not possible to verify it yet. Waiting...", batchNumberToVerify)
"so it is not possible to verify it yet. Waiting ...", batchNumberToVerify)

return nil, nil, nil, state.ErrNotFound
}
Expand All @@ -1138,7 +1138,7 @@ func (a *Aggregator) getAndLockBatchToProve(
return nil, nil, nil, err
} else if errors.Is(err, entities.ErrNotFound) {
a.logger.Infof("Virtual batch %d has not been synced yet, "+
"so it is not possible to verify it yet. Waiting...", batchNumberToVerify)
"so it is not possible to verify it yet. Waiting ...", batchNumberToVerify)
return nil, nil, nil, state.ErrNotFound
}

Expand All @@ -1163,21 +1163,43 @@ func (a *Aggregator) getAndLockBatchToProve(
virtualBatch.L1InfoRoot = &l1InfoRoot
}

// Ensure the old acc input hash is in memory
oldAccInputHash := a.getAccInputHash(batchNumberToVerify - 1)
if oldAccInputHash == (common.Hash{}) && batchNumberToVerify > 1 {
tmpLogger.Warnf("AccInputHash for batch - 1 (%d) is not in memory. Waiting ...", batchNumberToVerify-1)
return nil, nil, nil, state.ErrNotFound
}

forcedBlockHashL1 := rpcBatch.ForcedBlockHashL1()
l1InfoRoot = *virtualBatch.L1InfoRoot

if batchNumberToVerify == 1 {
l1Block, err := a.l1Syncr.GetL1BlockByNumber(ctx, virtualBatch.BlockNumber)
if err != nil {
a.logger.Errorf("Error getting l1 block: %v", err)
return nil, nil, nil, err
}

forcedBlockHashL1 = l1Block.ParentHash
l1InfoRoot = rpcBatch.GlobalExitRoot()
}

// Calculate acc input hash as the RPC is not returning the correct one at the moment
accInputHash := cdkcommon.CalculateAccInputHash(
a.logger,
a.getAccInputHash(batchNumberToVerify-1),
oldAccInputHash,
virtualBatch.BatchL2Data,
*virtualBatch.L1InfoRoot,
l1InfoRoot,
uint64(sequence.Timestamp.Unix()),
rpcBatch.LastCoinbase(),
rpcBatch.ForcedBlockHashL1(),
forcedBlockHashL1,
)
// Store the acc input hash
a.setAccInputHash(batchNumberToVerify, accInputHash)

// Log params to calculate acc input hash
a.logger.Debugf("Calculated acc input hash for batch %d: %v", batchNumberToVerify, accInputHash)
a.logger.Debugf("OldAccInputHash: %v", oldAccInputHash)
a.logger.Debugf("L1InfoRoot: %v", virtualBatch.L1InfoRoot)
// a.logger.Debugf("LastL2BLockTimestamp: %v", rpcBatch.LastL2BLockTimestamp())
a.logger.Debugf("TimestampLimit: %v", uint64(sequence.Timestamp.Unix()))
Expand All @@ -1196,7 +1218,7 @@ func (a *Aggregator) getAndLockBatchToProve(
AccInputHash: accInputHash,
L1InfoTreeIndex: rpcBatch.L1InfoTreeIndex(),
L1InfoRoot: *virtualBatch.L1InfoRoot,
Timestamp: time.Unix(int64(rpcBatch.LastL2BLockTimestamp()), 0),
Timestamp: sequence.Timestamp,
GlobalExitRoot: rpcBatch.GlobalExitRoot(),
ChainID: a.cfg.ChainID,
ForkID: a.cfg.ForkId,
Expand Down Expand Up @@ -1325,7 +1347,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt

tmpLogger = tmpLogger.WithFields("proofId", *proof.ProofID)

resGetProof, stateRoot, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
resGetProof, stateRoot, accInputHash, err := prover.WaitRecursiveProof(ctx, *proof.ProofID)
if err != nil {
err = fmt.Errorf("failed to get proof from prover, %w", err)
tmpLogger.Error(FirstToUpper(err.Error()))
Expand All @@ -1346,6 +1368,20 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover ProverInt
tmpLogger.Infof("State root sanity check for batch %d passed", batchToProve.BatchNumber)
}

// Sanity Check: acc input hash from the proof must match the one from the batch
if a.cfg.BatchProofSanityCheckEnabled && (accInputHash != common.Hash{}) &&
(accInputHash != batchToProve.AccInputHash) {
for {
tmpLogger.Errorf("Acc input hash from the proof does not match the expected for "+
"batch %d: Proof = [%s] Expected = [%s]",
batchToProve.BatchNumber, accInputHash.String(), batchToProve.AccInputHash.String(),
)
time.Sleep(a.cfg.RetryTime.Duration)
}
} else {
tmpLogger.Infof("Acc input hash sanity check for batch %d passed", batchToProve.BatchNumber)
}

proof.Proof = resGetProof

// NOTE(pg): the defer func is useless from now on, use a different variable
Expand Down Expand Up @@ -1505,10 +1541,17 @@ func (a *Aggregator) buildInputProver(
}
}

// Ensure the old acc input hash is in memory
oldAccInputHash := a.getAccInputHash(batchToVerify.BatchNumber - 1)
if oldAccInputHash == (common.Hash{}) && batchToVerify.BatchNumber > 1 {
a.logger.Warnf("AccInputHash for batch - 1 (%d) is not in memory. Waiting ...", batchToVerify.BatchNumber-1)
return nil, fmt.Errorf("acc input hash for batch - 1 (%d) is not in memory", batchToVerify.BatchNumber-1)
}

inputProver := &prover.StatelessInputProver{
PublicInputs: &prover.StatelessPublicInputs{
Witness: witness,
OldAccInputHash: a.getAccInputHash(batchToVerify.BatchNumber - 1).Bytes(),
OldAccInputHash: oldAccInputHash.Bytes(),
OldBatchNum: batchToVerify.BatchNumber - 1,
ChainId: batchToVerify.ChainID,
ForkId: batchToVerify.ForkID,
Expand Down
18 changes: 9 additions & 9 deletions aggregator/aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ func Test_tryAggregateProofs(t *testing.T) {
Return(nil).
Once()
m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once()
m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin)
m.stateMock.
On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx).
Expand Down Expand Up @@ -1172,7 +1172,7 @@ func Test_tryAggregateProofs(t *testing.T) {
Return(nil).
Once()
m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once()
m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin)
m.stateMock.
On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), &proof1, dbTx).
Expand Down Expand Up @@ -1220,7 +1220,7 @@ func Test_tryAggregateProofs(t *testing.T) {
Return(nil).
Once()
m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once()
m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(errTest).Once()
dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once()
m.stateMock.On("BeginStateTransaction", mock.MatchedBy(matchAggregatorCtxFn)).Return(dbTx, nil).Once().NotBefore(lockProofsTxBegin)
Expand Down Expand Up @@ -1280,7 +1280,7 @@ func Test_tryAggregateProofs(t *testing.T) {
Return(nil).
Once()
m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once()
m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once()
m.stateMock.On("AddGeneratedProof", mock.MatchedBy(matchProverCtxFn), mock.Anything, dbTx).Return(errTest).Once()
dbTx.On("Rollback", mock.MatchedBy(matchProverCtxFn)).Return(nil).Once()
Expand Down Expand Up @@ -1343,7 +1343,7 @@ func Test_tryAggregateProofs(t *testing.T) {
Once()

m.proverMock.On("AggregatedProof", proof1.Proof, proof2.Proof).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once()
m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchProverCtxFn), proof1.BatchNumber, proof2.BatchNumberFinal, dbTx).Return(nil).Once()
expectedInputProver := map[string]interface{}{
"recursive_proof_1": proof1.Proof,
Expand Down Expand Up @@ -1642,7 +1642,7 @@ func Test_tryGenerateBatchProof(t *testing.T) {
require.NoError(err)

m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once()
m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(nil).Once()
},
asserts: func(result bool, a *Aggregator, err error) {
Expand Down Expand Up @@ -1716,7 +1716,7 @@ func Test_tryGenerateBatchProof(t *testing.T) {
m.synchronizerMock.On("GetVirtualBatchByBatchNumber", mock.Anything, lastVerifiedBatchNum+1).Return(&virtualBatch, nil).Once()

m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, errTest).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return("", common.Hash{}, common.Hash{}, errTest).Once()
m.stateMock.On("DeleteGeneratedProofs", mock.MatchedBy(matchAggregatorCtxFn), batchToProve.BatchNumber, batchToProve.BatchNumber, nil).Return(errTest).Once()
},
asserts: func(result bool, a *Aggregator, err error) {
Expand Down Expand Up @@ -1792,7 +1792,7 @@ func Test_tryGenerateBatchProof(t *testing.T) {
require.NoError(err)

m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once()
m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run(
func(args mock.Arguments) {
proof, ok := args[1].(*state.Proof)
Expand Down Expand Up @@ -1885,7 +1885,7 @@ func Test_tryGenerateBatchProof(t *testing.T) {
require.NoError(err)

m.proverMock.On("BatchProof", expectedInputProver).Return(&proofID, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, nil).Once()
m.proverMock.On("WaitRecursiveProof", mock.MatchedBy(matchProverCtxFn), proofID).Return(recursiveProof, common.Hash{}, common.Hash{}, nil).Once()
m.etherman.On("GetLatestVerifiedBatchNum").Return(uint64(42), errTest).Once()
m.stateMock.On("UpdateGeneratedProof", mock.MatchedBy(matchAggregatorCtxFn), mock.Anything, nil).Run(
func(args mock.Arguments) {
Expand Down
2 changes: 1 addition & 1 deletion aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type ProverInterface interface {
BatchProof(input *prover.StatelessInputProver) (*string, error)
AggregatedProof(inputProof1, inputProof2 string) (*string, error)
FinalProof(inputProof string, aggregatorAddr string) (*string, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error)
WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, common.Hash, error)
WaitFinalProof(ctx context.Context, proofID string) (*prover.FinalProof, error)
}

Expand Down
21 changes: 15 additions & 6 deletions aggregator/mocks/mock_prover.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 18 additions & 13 deletions aggregator/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
)

const (
stateRootStartIndex = 19
stateRootFinalIndex = stateRootStartIndex + 8
stateRootStartIndex = 19
stateRootFinalIndex = stateRootStartIndex + 8
accInputHashStartIndex = 27
accInputHashFinalIndex = accInputHashStartIndex + 8
)

var (
Expand Down Expand Up @@ -282,30 +284,36 @@ func (p *Prover) CancelProofRequest(proofID string) error {

// WaitRecursiveProof waits for a recursive proof to be generated by the prover
// and returns it.
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, error) {
func (p *Prover) WaitRecursiveProof(ctx context.Context, proofID string) (string, common.Hash, common.Hash, error) {
res, err := p.waitProof(ctx, proofID)
if err != nil {
return "", common.Hash{}, err
return "", common.Hash{}, common.Hash{}, err
}

resProof, ok := res.Proof.(*GetProofResponse_RecursiveProof)
if !ok {
return "", common.Hash{}, fmt.Errorf(
return "", common.Hash{}, common.Hash{}, fmt.Errorf(
"%w, wanted %T, got %T",
ErrBadProverResponse, &GetProofResponse_RecursiveProof{}, res.Proof,
)
}

sr, err := GetStateRootFromProof(p.logger, resProof.RecursiveProof)
sr, err := GetSanityCheckHashFromProof(p.logger, resProof.RecursiveProof, stateRootStartIndex, stateRootFinalIndex)
if err != nil && sr != (common.Hash{}) {
p.logger.Errorf("Error getting state root from proof: %v", err)
}

accInputHash, err := GetSanityCheckHashFromProof(p.logger, resProof.RecursiveProof,
accInputHashStartIndex, accInputHashFinalIndex)
if err != nil && accInputHash != (common.Hash{}) {
p.logger.Errorf("Error getting acc input hash from proof: %v", err)
}

if sr == (common.Hash{}) {
p.logger.Info("Recursive proof does not contain state root. Possibly mock prover is in use.")
}

return resProof.RecursiveProof, sr, nil
return resProof.RecursiveProof, sr, accInputHash, nil
}

// WaitFinalProof waits for the final proof to be generated by the prover and
Expand Down Expand Up @@ -395,11 +403,8 @@ func (p *Prover) call(req *AggregatorMessage) (*ProverMessage, error) {
return res, nil
}

// GetStateRootFromProof returns the state root from the proof.
func GetStateRootFromProof(logger *log.Logger, proof string) (common.Hash, error) {
// Log received proof
logger.Debugf("Received proof to get SR from: %s", proof)

// GetSanityCheckHashFromProof returns info from the proof
func GetSanityCheckHashFromProof(logger *log.Logger, proof string, startIndex, endIndex int) (common.Hash, error) {
type Publics struct {
Publics []string `mapstructure:"publics"`
}
Expand All @@ -420,7 +425,7 @@ func GetStateRootFromProof(logger *log.Logger, proof string) (common.Hash, error
v [8]uint64
j = 0
)
for i := stateRootStartIndex; i < stateRootFinalIndex; i++ {
for i := startIndex; i < endIndex; i++ {
u64, err := strconv.ParseInt(publics.Publics[i], 10, 64)
if err != nil {
logger.Fatal(err)
Expand Down
6 changes: 4 additions & 2 deletions aggregator/prover/prover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
)

const (
dir = "../../test/vectors/proofs"
dir = "../../test/vectors/proofs"
stateRootStartIndex = 19
stateRootFinalIndex = stateRootStartIndex + 8
)

type TestStateRoot struct {
Expand Down Expand Up @@ -40,7 +42,7 @@ func TestCalculateStateRoots(t *testing.T) {
require.NoError(t, err)

// Get the state root from the batch proof
fileStateRoot, err := prover.GetStateRootFromProof(log.GetDefaultLogger(), string(data))
fileStateRoot, err := prover.GetSanityCheckHashFromProof(log.GetDefaultLogger(), string(data), stateRootStartIndex, stateRootFinalIndex)
require.NoError(t, err)

// Get the expected state root
Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/0xPolygon/cdk-data-availability v0.0.10
github.com/0xPolygon/cdk-rpc v0.0.0-20241004114257-6c3cb6eebfb6
github.com/0xPolygon/zkevm-ethtx-manager v0.2.1
github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.5
github.com/0xPolygonHermez/zkevm-synchronizer-l1 v1.0.6
github.com/ethereum/go-ethereum v1.14.8
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
github.com/hermeznetwork/tracerr v0.3.2
Expand Down Expand Up @@ -40,7 +40,6 @@ require (
)

require (
github.com/0xPolygonHermez/zkevm-data-streamer v0.2.7 // indirect
github.com/DataDog/zstd v1.5.6 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/StackExchange/wmi v1.2.1 // indirect
Expand Down
Loading
Loading