diff --git a/core/crypto/auth/signer.go b/core/crypto/auth/signer.go index 9875d404f..6d435beae 100644 --- a/core/crypto/auth/signer.go +++ b/core/crypto/auth/signer.go @@ -7,14 +7,14 @@ import ( "fmt" "io" - "github.com/kwilteam/kwil-db/core/crypto" "golang.org/x/crypto/sha3" + + "github.com/kwilteam/kwil-db/core/crypto" + "github.com/kwilteam/kwil-db/core/utils" ) // Signature is a signature with a designated AuthType, which should // be used to determine how to verify the signature. -// It seems a bit weird to have a field "Signature" inside a struct called "Signature", -// but I am keeping it like this for compatibility with the old code. type Signature struct { // Data is the raw signature bytes Data []byte `json:"sig"` @@ -23,25 +23,38 @@ type Signature struct { Type string `json:"type"` } -func (s Signature) MarshalBinary() ([]byte, error) { - buf := new(bytes.Buffer) - if err := binary.Write(buf, binary.LittleEndian, uint32(len(s.Data))); err != nil { - return nil, fmt.Errorf("failed to write signature length: %w", err) +var _ io.WriterTo = Signature{} + +func (s Signature) WriteTo(w io.Writer) (int64, error) { + cw := utils.NewCountingWriter(w) + if err := binary.Write(cw, binary.LittleEndian, uint32(len(s.Data))); err != nil { + return cw.Written(), fmt.Errorf("failed to write signature length: %w", err) } - if err := binary.Write(buf, binary.LittleEndian, s.Data); err != nil { - return nil, fmt.Errorf("failed to write signature data: %w", err) + if err := binary.Write(cw, binary.LittleEndian, s.Data); err != nil { + return cw.Written(), fmt.Errorf("failed to write signature data: %w", err) } - if err := binary.Write(buf, binary.LittleEndian, uint32(len(s.Type))); err != nil { - return nil, fmt.Errorf("failed to write signature type length: %w", err) + if err := binary.Write(cw, binary.LittleEndian, uint32(len(s.Type))); err != nil { + return cw.Written(), fmt.Errorf("failed to write signature type length: %w", err) } - if err := binary.Write(buf, binary.LittleEndian, []byte(s.Type)); err != nil { - return nil, fmt.Errorf("failed to write signature type: %w", err) + if err := binary.Write(cw, binary.LittleEndian, []byte(s.Type)); err != nil { + return cw.Written(), fmt.Errorf("failed to write signature type: %w", err) } + return cw.Written(), nil +} + +func (s Signature) MarshalBinary() ([]byte, error) { + buf := new(bytes.Buffer) + s.WriteTo(buf) // does not error with a bytes.Buffer as the Writer return buf.Bytes(), nil } +func (s Signature) Bytes() []byte { + b, _ := s.MarshalBinary() // does not error + return b +} + func (s *Signature) UnmarshalBinary(data []byte) error { r := bytes.NewReader(data) n, err := s.ReadFrom(r) diff --git a/core/types/block.go b/core/types/block.go index 2b0b65b88..cc1014ec6 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -62,8 +62,8 @@ type BlockHeader struct { type Block struct { Header *BlockHeader - Txns [][]byte - Signature []byte // Signature is the block producer's signature (leader in our model) + Txns [][]byte // TODO: convert to []*Transaction + Signature []byte // Signature is the block producer's signature (leader in our model) } func (b *Block) Hash() Hash { diff --git a/core/types/transaction.go b/core/types/transaction.go index 95c460dc6..00eebe3d0 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -12,6 +12,7 @@ import ( "github.com/kwilteam/kwil-db/core/crypto" "github.com/kwilteam/kwil-db/core/crypto/auth" + "github.com/kwilteam/kwil-db/core/utils" ) type ResultBroadcastTx struct { @@ -48,12 +49,23 @@ type Transaction struct { Sender HexBytes `json:"sender"` strictUnmarshal bool + // cachedHash *Hash // maybe maybe maybe... this would require a mutex or careful use } func (t *Transaction) StrictUnmarshal() { t.strictUnmarshal = true } +// Hash gives the hash of the transaction that is the unique identifier for the +// transaction. +func (t *Transaction) Hash() (Hash, error) { + raw, err := t.MarshalBinary() + if err != nil { + return Hash{}, err + } + return HashBytes(raw), nil +} + // TransactionBody is the body of a transaction that gets included in the // signature. This type implements json.Marshaler and json.Unmarshaler to ensure // that the Fee field is represented as a string in JSON rather than a number. @@ -259,6 +271,14 @@ func (t *TransactionBody) SerializeMsg(mst SignedMsgSerializationType) ([]byte, return nil, errors.New("invalid serialization type") } +var _ io.WriterTo = (*Transaction)(nil) + +func (t *Transaction) WriteTo(w io.Writer) (int64, error) { + cw := utils.NewCountingWriter(w) + err := t.serialize(cw) + return cw.Written(), err +} + var _ encoding.BinaryMarshaler = (*Transaction)(nil) // MarshalBinary produces the full binary serialization of the transaction, @@ -276,9 +296,9 @@ var _ io.ReaderFrom = (*Transaction)(nil) func (t *Transaction) ReadFrom(r io.Reader) (int64, error) { n, err := t.deserialize(r) if err != nil { - return int64(n), err + return n, err } - return int64(n), nil + return n, nil } var _ encoding.BinaryUnmarshaler = (*Transaction)(nil) @@ -292,7 +312,7 @@ func (t *Transaction) UnmarshalBinary(data []byte) error { if !t.strictUnmarshal { return nil } - if n != len(data) { + if n != int64(len(data)) { return errors.New("failed to read all") } if r.Len() != 0 { @@ -301,96 +321,102 @@ func (t *Transaction) UnmarshalBinary(data []byte) error { return nil } +func (tb TransactionBody) Bytes() []byte { + b, _ := tb.MarshalBinary() // does not error + return b +} + var _ encoding.BinaryMarshaler = (*TransactionBody)(nil) -func (tb *TransactionBody) MarshalBinary() ([]byte, error) { +func (tb TransactionBody) MarshalBinary() ([]byte, error) { buf := new(bytes.Buffer) + tb.WriteTo(buf) // no error with bytes.Buffer + return buf.Bytes(), nil +} + +var _ io.WriterTo = TransactionBody{} +func (tb TransactionBody) WriteTo(w io.Writer) (int64, error) { + cw := utils.NewCountingWriter(w) // Description Length + Description - if err := writeString(buf, tb.Description); err != nil { - return nil, fmt.Errorf("failed to write transaction body description: %w", err) + if err := writeString(cw, tb.Description); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction body description: %w", err) } // serialized Payload - if err := writeBytes(buf, tb.Payload); err != nil { - return nil, fmt.Errorf("failed to write transaction body payload: %w", err) + if err := writeBytes(cw, tb.Payload); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction body payload: %w", err) } // PayloadType payloadType := tb.PayloadType.String() - if err := writeString(buf, payloadType); err != nil { - return nil, fmt.Errorf("failed to write transaction body payload type: %w", err) + if err := writeString(cw, payloadType); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction body payload type: %w", err) } // Fee (big.Int) - if err := writeBigInt(buf, tb.Fee); err != nil { - return nil, fmt.Errorf("failed to write transaction fee: %w", err) + if err := writeBigInt(cw, tb.Fee); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction fee: %w", err) } // Nonce - if err := binary.Write(buf, binary.LittleEndian, tb.Nonce); err != nil { - return nil, fmt.Errorf("failed to write transaction body nonce: %w", err) + if err := binary.Write(cw, binary.LittleEndian, tb.Nonce); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction body nonce: %w", err) } // ChainID - if err := writeString(buf, tb.ChainID); err != nil { - return nil, fmt.Errorf("failed to write transaction body chain ID: %w", err) + if err := writeString(cw, tb.ChainID); err != nil { + return cw.Written(), fmt.Errorf("failed to write transaction body chain ID: %w", err) } - - return buf.Bytes(), nil + return cw.Written(), nil } var _ io.ReaderFrom = (*TransactionBody)(nil) func (tb *TransactionBody) ReadFrom(r io.Reader) (int64, error) { - var n int + cr := utils.NewCountingReader(r) + // Description Length + Description - desc, err := readString(r) + desc, err := readString(cr) if err != nil { - return int64(n), fmt.Errorf("failed to read transaction body description: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body description: %w", err) } tb.Description = desc - n += 4 + len(desc) // serialized Payload - payload, err := readBytes(r) + payload, err := readBytes(cr) if err != nil { - return int64(n), fmt.Errorf("failed to read transaction body payload: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body payload: %w", err) } tb.Payload = payload - n += 4 + len(payload) // PayloadType - payloadType, err := readString(r) + payloadType, err := readString(cr) if err != nil { - return int64(n), fmt.Errorf("failed to read transaction body payload type: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body payload type: %w", err) } tb.PayloadType = PayloadType(payloadType) - n += 4 + len(payloadType) // Fee (big.Int) - b, ni, err := readBigInt(r) + b, _, err := readBigInt(cr) if err != nil { - return int64(n), fmt.Errorf("failed to read transaction body fee: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body fee: %w", err) } tb.Fee = b // may be nil - n += ni // Nonce - if err := binary.Read(r, binary.LittleEndian, &tb.Nonce); err != nil { - return int64(n), fmt.Errorf("failed to read transaction body nonce: %w", err) + if err := binary.Read(cr, binary.LittleEndian, &tb.Nonce); err != nil { + return cr.ReadCount(), fmt.Errorf("failed to read transaction body nonce: %w", err) } - n += 8 // ChainID - chainID, err := readString(r) + chainID, err := readString(cr) if err != nil { - return int64(n), fmt.Errorf("failed to read transaction body chain ID: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body chain ID: %w", err) } tb.ChainID = chainID - n += 4 + len(chainID) - return int64(n), nil + return cr.ReadCount(), nil } var _ encoding.BinaryUnmarshaler = (*TransactionBody)(nil) @@ -417,26 +443,21 @@ func (tb *TransactionBody) UnmarshalBinary(data []byte) error { } func (t *Transaction) serialize(w io.Writer) (err error) { + if t.Body == nil { + return errors.New("missing transaction body") + } + // Tx Signature var sigBytes []byte if t.Signature != nil { - if sigBytes, err = t.Signature.MarshalBinary(); err != nil { - return fmt.Errorf("failed to marshal transaction signature: %w", err) - } + sigBytes = t.Signature.Bytes() } if err := writeBytes(w, sigBytes); err != nil { return fmt.Errorf("failed to write transaction signature: %w", err) } // Tx Body - if t.Body == nil { - return errors.New("missing transaction body") - } - txBodyBytes, err := t.Body.MarshalBinary() - if err != nil { - return fmt.Errorf("failed to marshal transaction body: %w", err) - } - if _, err := w.Write(txBodyBytes); err != nil { + if _, err := t.Body.WriteTo(w); err != nil { return fmt.Errorf("failed to write transaction body: %w", err) } /*var txBodyBytes []byte @@ -463,38 +484,35 @@ func (t *Transaction) serialize(w io.Writer) (err error) { return nil } -func (t *Transaction) deserialize(r io.Reader) (int, error) { - var n int +func (t *Transaction) deserialize(r io.Reader) (int64, error) { + cr := utils.NewCountingReader(r) // Signature - sigBytes, err := readBytes(r) + sigBytes, err := readBytes(cr) if err != nil { - return n, fmt.Errorf("failed to read transaction signature: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction signature: %w", err) } - n += 4 + len(sigBytes) if len(sigBytes) != 0 { var signature auth.Signature if err = signature.UnmarshalBinary(sigBytes); err != nil { - return 0, fmt.Errorf("failed to unmarshal transaction signature: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to unmarshal transaction signature: %w", err) } t.Signature = &signature } // TxBody var body TransactionBody - bodyLen, err := body.ReadFrom(r) + _, err = body.ReadFrom(cr) if err != nil { - return 0, fmt.Errorf("failed to read transaction body: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction body: %w", err) } t.Body = &body - n += int(bodyLen) /* if we need to support nil body... - bodyBytes, err := readBytes(r) + bodyBytes, err := readBytes(cr) if err != nil { return 0, fmt.Errorf("failed to read transaction body: %w", err) } - n += 4 + len(bodyBytes) if len(bodyBytes) != 0 { var body TransactionBody body.StrictUnmarshal() @@ -505,22 +523,20 @@ func (t *Transaction) deserialize(r io.Reader) (int, error) { }*/ // SerializationType - serType, err := readString(r) + serType, err := readString(cr) if err != nil { - return 0, fmt.Errorf("failed to read transaction serialization type: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction serialization type: %w", err) } - n += 4 + len(serType) t.Serialization = SignedMsgSerializationType(serType) // Sender - senderBytes, err := readBytes(r) + senderBytes, err := readBytes(cr) if err != nil { - return 0, fmt.Errorf("failed to read transaction sender: %w", err) + return cr.ReadCount(), fmt.Errorf("failed to read transaction sender: %w", err) } - n += 4 + len(senderBytes) t.Sender = senderBytes - return n, nil + return cr.ReadCount(), nil } func writeBytes(w io.Writer, data []byte) error { diff --git a/core/utils/rw.go b/core/utils/rw.go new file mode 100644 index 000000000..dbd024eee --- /dev/null +++ b/core/utils/rw.go @@ -0,0 +1,47 @@ +package utils + +import "io" + +// CountingWriter wraps an io.Writer, adding a Written method to get the total +// bytes written over multiple calls to Write. This is helpful if the Writer +// passes through other functions that do not return the bytes written. +type CountingWriter struct { + w io.Writer + c int64 +} + +func NewCountingWriter(w io.Writer) *CountingWriter { + return &CountingWriter{w: w} +} + +func (cw *CountingWriter) Write(p []byte) (int, error) { + n, err := cw.w.Write(p) + cw.c += int64(n) + return n, err +} + +func (cw *CountingWriter) Written() int64 { + return cw.c +} + +// CountingReader wraps an io.Reader, adding a ReadCount method to get the total +// bytes read over multiple calls to Read. This is helpful if the Reader passes +// through other functions that do not return the bytes read. +type CountingReader struct { + r io.Reader + c int64 +} + +func NewCountingReader(r io.Reader) *CountingReader { + return &CountingReader{r: r} +} + +func (cr *CountingReader) Read(p []byte) (int, error) { + n, err := cr.r.Read(p) + cr.c += int64(n) + return n, err +} + +func (cr *CountingReader) ReadCount() int64 { + return cr.c +} diff --git a/node/block_processor/processor.go b/node/block_processor/processor.go index 2e9e5e8bd..7dc41b63b 100644 --- a/node/block_processor/processor.go +++ b/node/block_processor/processor.go @@ -3,7 +3,6 @@ package blockprocessor import ( "bytes" "context" - "crypto/sha256" "encoding/binary" "encoding/hex" "errors" @@ -136,15 +135,15 @@ func (bp *BlockProcessor) Rollback(ctx context.Context, height int64, appHash kt return nil } -func (bp *BlockProcessor) CheckTx(ctx context.Context, incomingTx []byte, recheck bool) error { - var err error - tx := &ktypes.Transaction{} - if err = tx.UnmarshalBinary(incomingTx); err != nil { - bp.log.Debug("Failed to unmarshal the transaction", "err", err) - return fmt.Errorf("failed to unmarshal the transaction: %w", err) +func (bp *BlockProcessor) CheckTx(ctx context.Context, tx *ktypes.Transaction, recheck bool) error { + rawTx, err := tx.MarshalBinary() + if err != nil { + return fmt.Errorf("invalid transaction: %v", err) // e.g. missing fields } + txHash := types.HashBytes(rawTx) - bp.log.Info("Check transaction", "Recheck", recheck, "Sender", hex.EncodeToString(tx.Sender), "PayloadType", tx.Body.PayloadType.String(), "Nonce", tx.Body.Nonce, "TxFee", tx.Body.Fee.String()) + bp.log.Info("Check transaction", "Recheck", recheck, "Hash", txHash, "Sender", hex.EncodeToString(tx.Sender), + "PayloadType", tx.Body.PayloadType.String(), "Nonce", tx.Body.Nonce, "TxFee", tx.Body.Fee.String()) if !recheck { // Verify the correct chain ID is set, if it is set. @@ -176,7 +175,6 @@ func (bp *BlockProcessor) CheckTx(ctx context.Context, incomingTx []byte, rechec return fmt.Errorf("failed to get identifier: %w", err) } - txHash := sha256.Sum256(incomingTx) err = bp.txapp.ApplyMempool(&common.TxContext{ Ctx: ctx, BlockContext: &common.BlockContext{ diff --git a/node/consensus/block.go b/node/consensus/block.go index b92983b6e..fc4ba0a7b 100644 --- a/node/consensus/block.go +++ b/node/consensus/block.go @@ -41,7 +41,7 @@ func (ce *ConsensusEngine) validateBlock(blk *ktypes.Block) error { return nil } -func (ce *ConsensusEngine) CheckTx(ctx context.Context, tx []byte) error { +func (ce *ConsensusEngine) CheckTx(ctx context.Context, tx *ktypes.Transaction) error { ce.mempoolMtx.Lock() defer ce.mempoolMtx.Unlock() @@ -107,7 +107,7 @@ func (ce *ConsensusEngine) commit(ctx context.Context) error { // remove transactions from the mempool for _, txn := range blkProp.blk.Txns { txHash := types.HashBytes(txn) // TODO: can this be saved instead of recalculating? - ce.mempool.Store(txHash, nil) + ce.mempool.Remove(txHash) } // TODO: reapply existing transaction (checkTX) diff --git a/node/consensus/interfaces.go b/node/consensus/interfaces.go index 7f464a453..aa711f6cc 100644 --- a/node/consensus/interfaces.go +++ b/node/consensus/interfaces.go @@ -20,7 +20,7 @@ type DB interface { type Mempool interface { PeekN(maxSize int) []types.NamedTx - Store(txid types.Hash, tx []byte) + Remove(txid types.Hash) } // BlockStore includes both txns and blocks @@ -42,7 +42,7 @@ type BlockProcessor interface { Rollback(ctx context.Context, height int64, appHash ktypes.Hash) error Close() error - CheckTx(ctx context.Context, tx []byte, recheck bool) error + CheckTx(ctx context.Context, tx *ktypes.Transaction, recheck bool) error GetValidators() []*ktypes.Validator } diff --git a/node/consensus/leader.go b/node/consensus/leader.go index 3c75c6e7e..ab82e2723 100644 --- a/node/consensus/leader.go +++ b/node/consensus/leader.go @@ -101,7 +101,15 @@ func (ce *ConsensusEngine) createBlockProposal() (*blockProposal, error) { nTxs := ce.mempool.PeekN(blockTxCount) var txns [][]byte for _, namedTx := range nTxs { - txns = append(txns, namedTx.Tx) + rawTx, err := namedTx.Tx.MarshalBinary() + if err != nil { // this is a bug + ce.log.Errorf("invalid transaction from mempool rejected", + "hash", namedTx.Hash, "error", err) + ce.mempool.Remove(namedTx.Hash) + continue + // return nil, fmt.Errorf("invalid transaction: %v", err) // e.g. nil/missing body + } + txns = append(txns, rawTx) } blk := ktypes.NewBlock(ce.state.lc.height+1, ce.state.lc.blkHash, ce.state.lc.appHash, ce.ValidatorSetHash(), time.Now(), txns) diff --git a/node/interfaces.go b/node/interfaces.go index 7336958da..8262d6d99 100644 --- a/node/interfaces.go +++ b/node/interfaces.go @@ -29,7 +29,7 @@ type ConsensusEngine interface { blkAnnouncer consensus.BlkAnnouncer, ackBroadcaster consensus.AckBroadcaster, blkRequester consensus.BlkRequester, stateResetter consensus.ResetStateBroadcaster, discoveryBroadcaster consensus.DiscoveryReqBroadcaster) error - CheckTx(ctx context.Context, tx []byte) error + CheckTx(ctx context.Context, tx *ktypes.Transaction) error } type SnapshotStore interface { diff --git a/node/mempool/mempool.go b/node/mempool/mempool.go index 7e8ed5a9b..c8d3ad7a8 100644 --- a/node/mempool/mempool.go +++ b/node/mempool/mempool.go @@ -1,8 +1,10 @@ package mempool import ( + "slices" "sync" + ktypes "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/node/types" ) @@ -10,13 +12,15 @@ import ( type Mempool struct { mtx sync.RWMutex - txns map[types.Hash][]byte + txns map[types.Hash]*ktypes.Transaction + txQ []types.NamedTx fetching map[types.Hash]bool + // acctTxns map[string][]types.NamedTx } func New() *Mempool { return &Mempool{ - txns: make(map[types.Hash][]byte), + txns: make(map[types.Hash]*ktypes.Transaction), fetching: make(map[types.Hash]bool), } } @@ -28,15 +32,38 @@ func (mp *Mempool) Have(txid types.Hash) bool { // this is racy return have } -func (mp *Mempool) Store(txid types.Hash, raw []byte) { +func (mp *Mempool) Remove(txid types.Hash) { mp.mtx.Lock() defer mp.mtx.Unlock() - if raw == nil { - delete(mp.txns, txid) - delete(mp.fetching, txid) + mp.remove(txid) +} + +func (mp *Mempool) remove(txid types.Hash) { + idx := slices.IndexFunc(mp.txQ, func(a types.NamedTx) bool { + return a.Hash == txid + }) + if idx == -1 { + return + } + mp.txQ = slices.Delete(mp.txQ, idx, idx+1) // remove txQ[idx] + delete(mp.txns, txid) +} + +func (mp *Mempool) Store(txid types.Hash, tx *ktypes.Transaction) { + mp.mtx.Lock() + defer mp.mtx.Unlock() + delete(mp.fetching, txid) + + if tx == nil { // legacy semantics for removal + mp.remove(txid) return } - mp.txns[txid] = raw + + mp.txns[txid] = tx + mp.txQ = append(mp.txQ, types.NamedTx{ + Hash: txid, + Tx: tx, + }) } func (mp *Mempool) PreFetch(txid types.Hash) bool { // probably make node business @@ -57,43 +84,33 @@ func (mp *Mempool) PreFetch(txid types.Hash) bool { // probably make node busine func (mp *Mempool) Size() int { mp.mtx.RLock() defer mp.mtx.RUnlock() - return len(mp.txns) + return len(mp.txQ) } -func (mp *Mempool) Get(txid types.Hash) []byte { +func (mp *Mempool) Get(txid types.Hash) *ktypes.Transaction { mp.mtx.RLock() defer mp.mtx.RUnlock() return mp.txns[txid] } -func (mp *Mempool) ReapN(n int) ([]types.Hash, [][]byte) { +// ReapN extracts the first n transactions in the queue +func (mp *Mempool) ReapN(n int) []types.NamedTx { mp.mtx.Lock() defer mp.mtx.Unlock() - n = min(n, len(mp.txns)) - txids := make([]types.Hash, 0, n) - txns := make([][]byte, 0, n) - for txid, rawTx := range mp.txns { - delete(mp.txns, txid) - txids = append(txids, txid) - txns = append(txns, rawTx) - if len(txids) == cap(txids) { - break - } + n = min(n, len(mp.txQ)) + txns := slices.Clone(mp.txQ[:n]) + mp.txQ = mp.txQ[n:] + for _, tx := range txns { + delete(mp.txns, tx.Hash) } - return txids, txns + return txns } func (mp *Mempool) PeekN(n int) []types.NamedTx { mp.mtx.RLock() defer mp.mtx.RUnlock() n = min(n, len(mp.txns)) - txns := make([]types.NamedTx, 0, n) - for txid, rawTx := range mp.txns { - txns = append(txns, types.NamedTx{Hash: txid, Tx: rawTx}) - if len(txns) == n { - break - } - } - + txns := make([]types.NamedTx, n) + copy(txns, mp.txQ) return txns } diff --git a/node/mempool/mempool_test.go b/node/mempool/mempool_test.go new file mode 100644 index 000000000..6b27b5c29 --- /dev/null +++ b/node/mempool/mempool_test.go @@ -0,0 +1,122 @@ +package mempool + +import ( + "math/big" + "testing" + + "github.com/kwilteam/kwil-db/core/crypto/auth" + ktypes "github.com/kwilteam/kwil-db/core/types" + "github.com/kwilteam/kwil-db/node/types" + "github.com/stretchr/testify/assert" +) + +func newTx(nonce uint64, sender string) *ktypes.Transaction { + return &ktypes.Transaction{ + Signature: &auth.Signature{}, + Body: &ktypes.TransactionBody{ + Description: "test", + Payload: []byte(`random payload`), + Fee: big.NewInt(0), + Nonce: nonce, + }, + Sender: []byte(sender), + } +} + +func Test_MempoolRemove(t *testing.T) { + m := New() + + // Setup test transactions + tx1 := types.NamedTx{ + Hash: types.Hash{1, 2, 3}, + Tx: newTx(1, "A"), + } + tx2 := types.NamedTx{ + Hash: types.Hash{4, 5, 6}, + Tx: newTx(2, "B"), + } + + // Add transactions to mempool + m.Store(tx1.Hash, tx1.Tx) + m.Store(tx2.Hash, tx2.Tx) + + // Test removing existing transaction + m.Remove(tx1.Hash) + assert.Len(t, m.txQ, 1) + assert.Len(t, m.txns, 1) + assert.Equal(t, m.txQ[0].Hash, tx2.Hash) + _, exists := m.txns[tx1.Hash] + assert.False(t, exists) + + // Test removing non-existent transaction + nonExistentHash := types.Hash{9} + m.Remove(nonExistentHash) + assert.Len(t, m.txQ, 1) + assert.Len(t, m.txns, 1) + assert.Equal(t, m.txQ[0].Hash, tx2.Hash) + + // Test removing last transaction + m.Remove(tx2.Hash) + assert.Empty(t, m.txQ) + assert.Empty(t, m.txns) +} + +func Test_MempoolReapN(t *testing.T) { + m := New() + + // Setup test transactions + tx1 := types.NamedTx{ + Hash: types.Hash{1, 2, 3}, + Tx: newTx(1, "A"), + } + tx2 := types.NamedTx{ + Hash: types.Hash{4, 5, 6}, + Tx: newTx(2, "B"), + } + tx3 := types.NamedTx{ + Hash: types.Hash{7, 8, 9}, + Tx: newTx(3, "C"), + } + + // Test reaping from empty mempool + emptyReap := m.ReapN(1) + assert.Empty(t, emptyReap) + + // Add transactions to mempool + m.Store(tx1.Hash, tx1.Tx) + m.Store(tx2.Hash, tx2.Tx) + m.Store(tx3.Hash, tx3.Tx) + + // Test reaping more transactions than available + overReap := m.ReapN(5) + assert.Len(t, overReap, 3) + assert.Equal(t, overReap[0].Hash, tx1.Hash) + assert.Equal(t, overReap[1].Hash, tx2.Hash) + assert.Equal(t, overReap[2].Hash, tx3.Hash) + assert.Empty(t, m.txQ) + assert.Empty(t, m.txns) + + // Refill mempool + m.Store(tx1.Hash, tx1.Tx) + m.Store(tx2.Hash, tx2.Tx) + m.Store(tx3.Hash, tx3.Tx) + + // Test partial reaping + partialReap := m.ReapN(2) + assert.Len(t, partialReap, 2) + assert.Equal(t, partialReap[0].Hash, tx1.Hash) + assert.Equal(t, partialReap[1].Hash, tx2.Hash) + assert.Len(t, m.txQ, 1) + assert.Len(t, m.txns, 1) + + // Test reaping remaining transaction + finalReap := m.ReapN(1) + assert.Len(t, finalReap, 1) + assert.Equal(t, finalReap[0].Hash, tx3.Hash) + assert.Empty(t, m.txQ) + assert.Empty(t, m.txns) + + // Test reaping with zero count + zeroReap := m.ReapN(0) + assert.Empty(t, zeroReap) +} diff --git a/node/node.go b/node/node.go index 109fb9651..7a0f31f84 100644 --- a/node/node.go +++ b/node/node.go @@ -471,14 +471,10 @@ func (n *Node) Status(ctx context.Context) (*adminTypes.Status, error) { } func (n *Node) TxQuery(ctx context.Context, hash types.Hash, prove bool) (*ktypes.TxQueryResponse, error) { - raw, height, blkHash, blkIdx, err := n.bki.GetTx(hash) + tx, height, blkHash, blkIdx, err := n.bki.GetTx(hash) if err != nil { return nil, err } - var tx ktypes.Transaction - if err = tx.UnmarshalBinary(raw); err != nil { - return nil, err - } blkResults, err := n.bki.Results(blkHash) if err != nil { return nil, err @@ -488,7 +484,7 @@ func (n *Node) TxQuery(ctx context.Context, hash types.Hash, prove bool) (*ktype } res := blkResults[blkIdx] return &ktypes.TxQueryResponse{ - Tx: &tx, + Tx: tx, Hash: hash, Height: height, Result: &res, @@ -499,12 +495,11 @@ func (n *Node) BroadcastTx(ctx context.Context, tx *ktypes.Transaction, _ /*sync rawTx, _ := tx.MarshalBinary() txHash := types.HashBytes(rawTx) - // TODO: checkTx before accepting the Tx - if err := n.ce.CheckTx(ctx, rawTx); err != nil { + if err := n.ce.CheckTx(ctx, tx); err != nil { return nil, err } - n.mp.Store(txHash, rawTx) + n.mp.Store(txHash, tx) n.log.Infof("broadcasting new tx %v", txHash) n.announceTx(ctx, txHash, rawTx, n.host.ID()) diff --git a/node/node_test.go b/node/node_test.go index 1f00ae30f..ade99a5ed 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -188,7 +188,7 @@ func (ce *dummyCE) Role() types.Role { return types.RoleLeader } -func (ce *dummyCE) CheckTx(ctx context.Context, tx []byte) error { +func (ce *dummyCE) CheckTx(ctx context.Context, tx *ktypes.Transaction) error { return nil } diff --git a/node/nogossip.go b/node/nogossip.go index ace2e0ee7..44eed40da 100644 --- a/node/nogossip.go +++ b/node/nogossip.go @@ -9,6 +9,7 @@ import ( "github.com/kwilteam/kwil-db/core/crypto" "github.com/kwilteam/kwil-db/core/crypto/auth" + ktypes "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/node/types" "github.com/libp2p/go-libp2p/core/network" @@ -60,6 +61,12 @@ func (n *Node) txAnnStreamHandler(s network.Stream) { } } + var tx ktypes.Transaction + if err = tx.UnmarshalBinary(rawTx); err != nil { + n.log.Errorf("invalid transaction received %v: %v", txHash, err) + return + } + // n.log.Infof("obtained content for tx %q in %v", txid, time.Since(t0)) // here we could check tx index again in case a block was mined with it @@ -67,10 +74,10 @@ func (n *Node) txAnnStreamHandler(s network.Stream) { // store in mempool since it was not in tx index and thus not confirmed ctx := context.Background() - if err := n.ce.CheckTx(ctx, rawTx); err != nil { + if err := n.ce.CheckTx(ctx, &tx); err != nil { n.log.Warnf("tx %v failed check: %v", txHash, err) } else { - n.mp.Store(txHash, rawTx) + n.mp.Store(txHash, &tx) fetched = true // re-announce @@ -220,7 +227,12 @@ func (n *Node) startTxAnns(ctx context.Context, reannouncePeriod time.Duration) n.log.Infof("re-announcing %d unconfirmed txns", len(txns)) for _, nt := range txns { - n.announceTx(ctx, nt.Hash, nt.Tx, n.host.ID()) // response handling is async + rawTx, err := nt.Tx.MarshalBinary() + if err != nil { + n.log.Errorf("Failed to marshal transaction %v: %v", nt.Hash, err) + continue + } + n.announceTx(ctx, nt.Hash, rawTx, n.host.ID()) // response handling is async if ctx.Err() != nil { n.log.Warn("interrupting long re-broadcast") break diff --git a/node/store/memstore/memstore.go b/node/store/memstore/memstore.go index f82edcab8..23349d267 100644 --- a/node/store/memstore/memstore.go +++ b/node/store/memstore/memstore.go @@ -158,7 +158,7 @@ func (bs *MemBS) PreFetch(blkid types.Hash) (bool, func()) { func (bs *MemBS) Close() error { return nil } -func (bs *MemBS) GetTx(txHash types.Hash) (raw []byte, height int64, hash types.Hash, idx uint32, err error) { +func (bs *MemBS) GetTx(txHash types.Hash) (tx *ktypes.Transaction, height int64, hash types.Hash, idx uint32, err error) { bs.mtx.RLock() defer bs.mtx.RUnlock() // check the tx index, pull the block and then search for the tx with the expected hash @@ -170,8 +170,12 @@ func (bs *MemBS) GetTx(txHash types.Hash) (raw []byte, height int64, hash types. if !have { return nil, 0, types.Hash{}, 0, types.ErrNotFound } - for idx, tx := range blk.Txns { - if types.HashBytes(tx) == txHash { + for idx, rawTx := range blk.Txns { + if types.HashBytes(rawTx) == txHash { + tx := new(ktypes.Transaction) + if err = tx.UnmarshalBinary(rawTx); err != nil { + return nil, 0, types.Hash{}, 0, err + } return tx, blk.Header.Height, blk.Hash(), uint32(idx), nil } } diff --git a/node/store/memstore/memstore_test.go b/node/store/memstore/memstore_test.go index 6e7d9636c..5cbd2804a 100644 --- a/node/store/memstore/memstore_test.go +++ b/node/store/memstore/memstore_test.go @@ -1,12 +1,14 @@ package memstore import ( + "bytes" "encoding/binary" - "strconv" + "math/big" "strings" "testing" "time" + "github.com/kwilteam/kwil-db/core/crypto/auth" ktypes "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/node/types" ) @@ -15,20 +17,40 @@ func fakeAppHash(height int64) types.Hash { return types.HashBytes(binary.LittleEndian.AppendUint64(nil, uint64(height))) } -func createTestBlock(height int64, numTxns int) (*ktypes.Block, types.Hash) { +func newTx(nonce uint64, sender string) *ktypes.Transaction { + return &ktypes.Transaction{ + Signature: &auth.Signature{}, + Body: &ktypes.TransactionBody{ + Description: "test", + Payload: []byte(`random payload`), + Fee: big.NewInt(0), + Nonce: nonce, + }, + Sender: []byte(sender), + } +} + +func createTestBlock(height int64, numTxns int) (*ktypes.Block, types.Hash, []*ktypes.Transaction) { + txs := make([]*ktypes.Transaction, numTxns) txns := make([][]byte, numTxns) for i := range numTxns { - txns[i] = []byte(strconv.FormatInt(height, 10) + strconv.Itoa(i) + - strings.Repeat("data", 1000)) + tx := newTx(uint64(i)+uint64(height), "sender") + tx.Body.Payload = []byte(strings.Repeat("data", 1000)) + rawTx, err := tx.MarshalBinary() + if err != nil { + panic(err) + } + txs[i] = tx + txns[i] = rawTx } return ktypes.NewBlock(height, types.Hash{2, 3, 4}, types.Hash{6, 7, 8}, types.Hash{5, 5, 5}, - time.Unix(1729723553+height, 0), txns), fakeAppHash(height) + time.Unix(1729723553+height, 0), txns), fakeAppHash(height), txs } func TestMemBS_StoreAndGet(t *testing.T) { bs := NewMemBS() - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) err := bs.Store(block, appHash) if err != nil { @@ -125,7 +147,7 @@ func TestMemBS_StoreAndGetTx(t *testing.T) { // tx2 := []byte("tx2") // txns := [][]byte{tx1, tx2} // block := types.NewBlock(1, prevHash, appHash, valSetHash, time.Unix(123456789, 0), txns) - block, _ := createTestBlock(1, 2) + block, _, _ := createTestBlock(1, 2) tx1 := block.Txns[0] if err := bs.Store(block, types.Hash{1, 2, 3}); err != nil { @@ -142,8 +164,12 @@ func TestMemBS_StoreAndGetTx(t *testing.T) { t.Errorf("got height %d, want 1", height) } - if string(gotTx) != string(tx1) { - t.Errorf("got tx %s, want %s", string(gotTx), string(tx1)) + gotRawTx, err := gotTx.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(gotRawTx, tx1) { + t.Errorf("got tx %x, want %x", gotRawTx, tx1) } if blkHash := block.Hash(); hash != blkHash { diff --git a/node/store/store.go b/node/store/store.go index ca4d9a252..16a5cb350 100644 --- a/node/store/store.go +++ b/node/store/store.go @@ -469,7 +469,8 @@ func (bki *BlockStore) HaveTx(txHash types.Hash) bool { // GetTx returns the raw bytes of the transaction, and information on the block // containing the transaction. -func (bki *BlockStore) GetTx(txHash types.Hash) (raw []byte, height int64, blkHash types.Hash, blkIdx uint32, err error) { +func (bki *BlockStore) GetTx(txHash types.Hash) (tx *ktypes.Transaction, height int64, blkHash types.Hash, blkIdx uint32, err error) { + var raw []byte err = bki.db.View(func(txn *badger.Txn) error { // Get block info from the tx index key := slices.Concat(nsTxn, txHash[:]) // tdb["t:txHash"] => blk info @@ -515,5 +516,13 @@ func (bki *BlockStore) GetTx(txHash types.Hash) (raw []byte, height int64, blkHa if errors.Is(err, badger.ErrKeyNotFound) { err = types.ErrNotFound } + + if len(raw) == 0 { + return + } + + tx = new(ktypes.Transaction) + err = tx.UnmarshalBinary(raw) + return } diff --git a/node/store/store_test.go b/node/store/store_test.go index 969d9304b..d3b2f6159 100644 --- a/node/store/store_test.go +++ b/node/store/store_test.go @@ -5,15 +5,16 @@ import ( crand "crypto/rand" "encoding/binary" "fmt" + "math/big" "math/rand/v2" "os" "path/filepath" - "strconv" "strings" "testing" "text/tabwriter" "time" + "github.com/kwilteam/kwil-db/core/crypto/auth" ktypes "github.com/kwilteam/kwil-db/core/types" "github.com/kwilteam/kwil-db/node/types" ) @@ -67,20 +68,27 @@ func fakeAppHash(height int64) types.Hash { return types.HashBytes(binary.LittleEndian.AppendUint64(nil, uint64(height))) } -func createTestBlock(height int64, numTxns int) (*ktypes.Block, types.Hash) { +func createTestBlock(height int64, numTxns int) (*ktypes.Block, types.Hash, []*ktypes.Transaction) { + txs := make([]*ktypes.Transaction, numTxns) txns := make([][]byte, numTxns) for i := range numTxns { - txns[i] = []byte(strconv.FormatInt(height, 10) + strconv.Itoa(i) + - strings.Repeat("data", 1000)) + tx := newTx(uint64(i)+uint64(height), "sender") + tx.Body.Payload = []byte(strings.Repeat("data", 1000)) + rawTx, err := tx.MarshalBinary() + if err != nil { + panic(err) + } + txs[i] = tx + txns[i] = rawTx } return ktypes.NewBlock(height, types.Hash{2, 3, 4}, types.Hash{6, 7, 8}, types.Hash{5, 5, 5}, - time.Unix(1729723553+height, 0), txns), fakeAppHash(height) + time.Unix(1729723553+height, 0), txns), fakeAppHash(height), txs } func TestBlockStore_StoreAndGet(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) err := bs.Store(block, appHash) if err != nil { t.Fatal(err) @@ -121,7 +129,7 @@ func TestBlockStore_StoreAndGet(t *testing.T) { func TestBlockStore_GetByHeight(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) bs.Store(block, appHash) gotHash, blk, gotAppHash, err := bs.GetByHeight(1) @@ -148,7 +156,7 @@ func TestBlockStore_GetByHeight(t *testing.T) { func TestBlockStore_Have(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) hash := block.Hash() if bs.Have(hash) { @@ -167,12 +175,12 @@ func TestBlockStore_Have(t *testing.T) { func TestBlockStore_GetTx(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 3) + block, appHash, _ := createTestBlock(1, 3) bs.Store(block, appHash) for i := range block.Txns { txHash := types.HashBytes(block.Txns[i]) - txData, height, _, _, err := bs.GetTx(txHash) + tx, height, _, _, err := bs.GetTx(txHash) if err != nil { t.Fatal(err) } @@ -181,6 +189,11 @@ func TestBlockStore_GetTx(t *testing.T) { t.Errorf("Expected tx height %d, got %d", block.Header.Height, height) } + txData, err := tx.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(txData, block.Txns[i]) { t.Error("Retrieved transaction data doesn't match original") } @@ -190,7 +203,7 @@ func TestBlockStore_GetTx(t *testing.T) { func TestBlockStore_HaveTx(t *testing.T) { bs, dir := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 6) + block, appHash, _ := createTestBlock(1, 6) txHash := types.HashBytes(block.Txns[0]) if bs.HaveTx(txHash) { @@ -273,7 +286,7 @@ func TestBlockStore_StoreConcurrent(t *testing.T) { for i := range 3 { go func(start int) { for j := range blockCount { - block, appHash := createTestBlock(int64(start*blockCount+j), 2) + block, appHash, _ := createTestBlock(int64(start*blockCount+j), 2) err := bs.Store(block, appHash) if err != nil { t.Error(err) @@ -306,7 +319,7 @@ func TestBlockStore_StoreConcurrent(t *testing.T) { func TestBlockStore_StoreDuplicateBlock(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) err := bs.Store(block, appHash) if err != nil { @@ -332,17 +345,30 @@ func TestBlockStore_StoreDuplicateBlock(t *testing.T) { func TestBlockStore_StoreWithLargeTransactions(t *testing.T) { bs, _ := setupTestBlockStore(t, true) - largeTx := make([]byte, 1<<20) // 1MB transaction - for i := range largeTx { - largeTx[i] = byte(i % 256) + largeTxPayload := make([]byte, 1<<20) // 1MB transaction + for i := range largeTxPayload { + largeTxPayload[i] = byte(i % 256) + } + + largeTx := newTx(2, "moo") + largeTx.Body.Payload = largeTxPayload + otherTx := newTx(1, "Adsf") + otherTx.Body.Payload = []byte{1, 2, 3} + + largeTxRaw, err := largeTx.MarshalBinary() + if err != nil { + t.Fatal(err) + } + otherTxRaw, err := otherTx.MarshalBinary() + if err != nil { + t.Fatal(err) } - otherTx := []byte{1, 2, 3} block := ktypes.NewBlock(1, types.Hash{2, 3, 4}, types.Hash{6, 7, 8}, types.Hash{}, - time.Unix(1729723553, 0), [][]byte{largeTx, otherTx}) + time.Unix(1729723553, 0), [][]byte{largeTxRaw, otherTxRaw}) appHash := fakeAppHash(1) - err := bs.Store(block, appHash) + err = bs.Store(block, appHash) if err != nil { t.Fatal(err) } @@ -359,13 +385,17 @@ func TestBlockStore_StoreWithLargeTransactions(t *testing.T) { t.Fatal("apphash mismatch") } - for _, tx := range [][]byte{largeTx, otherTx} { - txHash := types.HashBytes(tx) - txData, _, _, _, err := bs.GetTx(txHash) + for _, rawTx := range [][]byte{largeTxRaw, otherTxRaw} { + txHash := types.HashBytes(rawTx) + tx, _, _, _, err := bs.GetTx(txHash) if err != nil { t.Fatal(err) } - if !bytes.Equal(txData, tx) { + txData, err := tx.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(txData, rawTx) { t.Error("Retrieved transaction data doesn't match original") } } @@ -379,6 +409,19 @@ func TestBlockStore_StoreWithLargeTransactions(t *testing.T) { // } } +func newTx(nonce uint64, sender string) *ktypes.Transaction { + return &ktypes.Transaction{ + Signature: &auth.Signature{}, + Body: &ktypes.TransactionBody{ + Description: "test", + Payload: []byte(`random payload`), + Fee: big.NewInt(0), + Nonce: nonce, + }, + Sender: []byte(sender), + } +} + func TestLargeBlockStore(t *testing.T) { // This test demonstrates that zstd level 1 compression is faster than no // compression for reasonably compressible data. @@ -403,9 +446,9 @@ func TestLargeBlockStore(t *testing.T) { rng := rand.New(rngSrc) // Patterned tx body to make it compressible - txBody := make([]byte, txSize-8) - for i := range txBody { - txBody[i] = byte(i % 16) + txPayload := make([]byte, txSize-8) + for i := range txPayload { + txPayload[i] = byte(i % 16) } // Create blocks with random transactions @@ -413,10 +456,13 @@ func TestLargeBlockStore(t *testing.T) { // Generate random transactions txs := make([][]byte, txsPerBlock) for i := range txs { - tx := make([]byte, txSize) - rngSrc.Read(tx[:8]) // like a nonce, ensures txs are unique - copy(tx[8:], txBody) - txs[i] = tx + tx := newTx(uint64(i), "sendername") + tx.Body.Payload = make([]byte, txSize) + copy(tx.Body.Payload, txPayload) + txs[i], err = tx.MarshalBinary() + if err != nil { + t.Fatal(err) + } } // Create and store block @@ -449,14 +495,18 @@ func TestLargeBlockStore(t *testing.T) { txIdx := rng.IntN(len(txs)) txHash := types.HashBytes(txs[txIdx]) - gotTx, gotHeight, _, _, err := bs.GetTx(txHash) + tx, gotHeight, _, _, err := bs.GetTx(txHash) if err != nil { t.Errorf("Failed to get tx at height %d, idx %d: %v", height, txIdx, err) } + txData, err := tx.MarshalBinary() + if err != nil { + t.Fatal(err) + } if gotHeight != height { t.Errorf("Wrong tx height. Got %d, want %d", gotHeight, height) } - if !bytes.Equal(gotTx, txs[txIdx]) { + if !bytes.Equal(txData, txs[txIdx]) { t.Error("Retrieved tx data mismatch") } } @@ -492,7 +542,7 @@ func getDirSize(path string) int64 { func TestBlockStore_StoreAndGetResults(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 3) + block, appHash, _ := createTestBlock(1, 3) err := bs.Store(block, appHash) if err != nil { t.Fatal(err) @@ -569,7 +619,7 @@ func TestBlockStore_ResultsNonExistentBlock(t *testing.T) { func TestBlockStore_StoreResultsLargeData(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) err := bs.Store(block, appHash) if err != nil { t.Fatal(err) @@ -605,7 +655,7 @@ func TestBlockStore_StoreResultsLargeData(t *testing.T) { func TestBlockStore_StoreResultsMismatchedCount(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 2) + block, appHash, _ := createTestBlock(1, 2) err := bs.Store(block, appHash) if err != nil { t.Fatal(err) @@ -632,7 +682,7 @@ func TestBlockStore_StoreResultsMismatchedCount(t *testing.T) { func TestBlockStore_Result(t *testing.T) { bs, _ := setupTestBlockStore(t) - block, appHash := createTestBlock(1, 3) + block, appHash, _ := createTestBlock(1, 3) err := bs.Store(block, appHash) if err != nil { t.Fatal(err) diff --git a/node/tx.go b/node/tx.go index e03844dee..bc0dbc7e3 100644 --- a/node/tx.go +++ b/node/tx.go @@ -103,23 +103,23 @@ func (n *Node) txGetStreamHandler(s network.Stream) { } // first check mempool - rawTx := n.mp.Get(req.Hash) - if rawTx != nil { - s.Write(rawTx) + tx := n.mp.Get(req.Hash) + if tx != nil { + tx.WriteTo(s) return } // this is racy, and should be different in product // then confirmed tx index - rawTx, _, _, _, err := n.bki.GetTx(req.Hash) + tx, _, _, _, err := n.bki.GetTx(req.Hash) if err != nil { if !errors.Is(err, types.ErrNotFound) { n.log.Errorf("unexpected GetTx error: %v", err) } s.Write(noData) // don't have it } else { - s.Write(rawTx) + tx.WriteTo(s) } // NOTE: response could also include conf/unconf or block height (-1 or N) diff --git a/node/types/interfaces.go b/node/types/interfaces.go index f75445d50..9c63f76d7 100644 --- a/node/types/interfaces.go +++ b/node/types/interfaces.go @@ -1,8 +1,6 @@ package types import ( - "context" - "github.com/kwilteam/kwil-db/core/types" ) @@ -48,15 +46,16 @@ type BlockResultsStorer interface { } type TxGetter interface { - GetTx(txHash types.Hash) (raw []byte, height int64, blkHash types.Hash, blkIdx uint32, err error) + GetTx(txHash types.Hash) (raw *types.Transaction, height int64, blkHash types.Hash, blkIdx uint32, err error) HaveTx(Hash) bool } type MemPool interface { Size() int - ReapN(int) ([]Hash, [][]byte) // Reap(n int, maxBts int) ([]Hash, [][]byte) - Get(Hash) []byte - Store(Hash, []byte) + ReapN(int) []NamedTx + Get(Hash) *types.Transaction + Remove(Hash) + Store(Hash, *types.Transaction) PeekN(n int) []NamedTx // Check([]byte) PreFetch(txid Hash) bool // should be app level instead @@ -69,11 +68,7 @@ type QualifiedBlock struct { // basically just caches the hash AppHash *Hash } -type Execution interface { - ExecBlock(blk *types.Block) (commit func(context.Context, bool) error, appHash Hash, res []types.TxResult, err error) -} - type NamedTx struct { Hash Hash - Tx []byte + Tx *types.Transaction }