Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use structured transaction in store and mempool #1146

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions core/crypto/auth/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (
"fmt"
"io"

"github.com/kwilteam/kwil-db/core/crypto"
"golang.org/x/crypto/sha3"

"github.com/kwilteam/kwil-db/core/crypto"
"github.com/kwilteam/kwil-db/core/utils"
)

// Signature is a signature with a designated AuthType, which should
// be used to determine how to verify the signature.
// It seems a bit weird to have a field "Signature" inside a struct called "Signature",
// but I am keeping it like this for compatibility with the old code.
type Signature struct {
// Data is the raw signature bytes
Data []byte `json:"sig"`
Expand All @@ -23,25 +23,38 @@ type Signature struct {
Type string `json:"type"`
}

func (s Signature) MarshalBinary() ([]byte, error) {
buf := new(bytes.Buffer)
if err := binary.Write(buf, binary.LittleEndian, uint32(len(s.Data))); err != nil {
return nil, fmt.Errorf("failed to write signature length: %w", err)
var _ io.WriterTo = Signature{}

func (s Signature) WriteTo(w io.Writer) (int64, error) {
cw := utils.NewCountingWriter(w)
if err := binary.Write(cw, binary.LittleEndian, uint32(len(s.Data))); err != nil {
return cw.Written(), fmt.Errorf("failed to write signature length: %w", err)
}
if err := binary.Write(buf, binary.LittleEndian, s.Data); err != nil {
return nil, fmt.Errorf("failed to write signature data: %w", err)
if err := binary.Write(cw, binary.LittleEndian, s.Data); err != nil {
return cw.Written(), fmt.Errorf("failed to write signature data: %w", err)
}

if err := binary.Write(buf, binary.LittleEndian, uint32(len(s.Type))); err != nil {
return nil, fmt.Errorf("failed to write signature type length: %w", err)
if err := binary.Write(cw, binary.LittleEndian, uint32(len(s.Type))); err != nil {
return cw.Written(), fmt.Errorf("failed to write signature type length: %w", err)
}
if err := binary.Write(buf, binary.LittleEndian, []byte(s.Type)); err != nil {
return nil, fmt.Errorf("failed to write signature type: %w", err)
if err := binary.Write(cw, binary.LittleEndian, []byte(s.Type)); err != nil {
return cw.Written(), fmt.Errorf("failed to write signature type: %w", err)
}

return cw.Written(), nil
}

func (s Signature) MarshalBinary() ([]byte, error) {
buf := new(bytes.Buffer)
s.WriteTo(buf) // does not error with a bytes.Buffer as the Writer
return buf.Bytes(), nil
}

func (s Signature) Bytes() []byte {
b, _ := s.MarshalBinary() // does not error
return b
}

func (s *Signature) UnmarshalBinary(data []byte) error {
r := bytes.NewReader(data)
n, err := s.ReadFrom(r)
Expand Down
4 changes: 2 additions & 2 deletions core/types/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ type BlockHeader struct {

type Block struct {
Header *BlockHeader
Txns [][]byte
Signature []byte // Signature is the block producer's signature (leader in our model)
Txns [][]byte // TODO: convert to []*Transaction
Signature []byte // Signature is the block producer's signature (leader in our model)
}

func (b *Block) Hash() Hash {
Expand Down
150 changes: 83 additions & 67 deletions core/types/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/kwilteam/kwil-db/core/crypto"
"github.com/kwilteam/kwil-db/core/crypto/auth"
"github.com/kwilteam/kwil-db/core/utils"
)

type ResultBroadcastTx struct {
Expand Down Expand Up @@ -48,12 +49,23 @@ type Transaction struct {
Sender HexBytes `json:"sender"`

strictUnmarshal bool
// cachedHash *Hash // maybe maybe maybe... this would require a mutex or careful use
}

func (t *Transaction) StrictUnmarshal() {
t.strictUnmarshal = true
}

// Hash gives the hash of the transaction that is the unique identifier for the
// transaction.
func (t *Transaction) Hash() (Hash, error) {
raw, err := t.MarshalBinary()
if err != nil {
return Hash{}, err
}
return HashBytes(raw), nil
}

// TransactionBody is the body of a transaction that gets included in the
// signature. This type implements json.Marshaler and json.Unmarshaler to ensure
// that the Fee field is represented as a string in JSON rather than a number.
Expand Down Expand Up @@ -259,6 +271,14 @@ func (t *TransactionBody) SerializeMsg(mst SignedMsgSerializationType) ([]byte,
return nil, errors.New("invalid serialization type")
}

var _ io.WriterTo = (*Transaction)(nil)

func (t *Transaction) WriteTo(w io.Writer) (int64, error) {
cw := utils.NewCountingWriter(w)
err := t.serialize(cw)
return cw.Written(), err
}

var _ encoding.BinaryMarshaler = (*Transaction)(nil)

// MarshalBinary produces the full binary serialization of the transaction,
Expand All @@ -276,9 +296,9 @@ var _ io.ReaderFrom = (*Transaction)(nil)
func (t *Transaction) ReadFrom(r io.Reader) (int64, error) {
n, err := t.deserialize(r)
if err != nil {
return int64(n), err
return n, err
}
return int64(n), nil
return n, nil
}

var _ encoding.BinaryUnmarshaler = (*Transaction)(nil)
Expand All @@ -292,7 +312,7 @@ func (t *Transaction) UnmarshalBinary(data []byte) error {
if !t.strictUnmarshal {
return nil
}
if n != len(data) {
if n != int64(len(data)) {
return errors.New("failed to read all")
}
if r.Len() != 0 {
Expand All @@ -301,96 +321,102 @@ func (t *Transaction) UnmarshalBinary(data []byte) error {
return nil
}

func (tb TransactionBody) Bytes() []byte {
b, _ := tb.MarshalBinary() // does not error
return b
}

var _ encoding.BinaryMarshaler = (*TransactionBody)(nil)

func (tb *TransactionBody) MarshalBinary() ([]byte, error) {
func (tb TransactionBody) MarshalBinary() ([]byte, error) {
buf := new(bytes.Buffer)
tb.WriteTo(buf) // no error with bytes.Buffer
return buf.Bytes(), nil
}

var _ io.WriterTo = TransactionBody{}

func (tb TransactionBody) WriteTo(w io.Writer) (int64, error) {
cw := utils.NewCountingWriter(w)
// Description Length + Description
if err := writeString(buf, tb.Description); err != nil {
return nil, fmt.Errorf("failed to write transaction body description: %w", err)
if err := writeString(cw, tb.Description); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction body description: %w", err)
}

// serialized Payload
if err := writeBytes(buf, tb.Payload); err != nil {
return nil, fmt.Errorf("failed to write transaction body payload: %w", err)
if err := writeBytes(cw, tb.Payload); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction body payload: %w", err)
}

// PayloadType
payloadType := tb.PayloadType.String()
if err := writeString(buf, payloadType); err != nil {
return nil, fmt.Errorf("failed to write transaction body payload type: %w", err)
if err := writeString(cw, payloadType); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction body payload type: %w", err)
}

// Fee (big.Int)
if err := writeBigInt(buf, tb.Fee); err != nil {
return nil, fmt.Errorf("failed to write transaction fee: %w", err)
if err := writeBigInt(cw, tb.Fee); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction fee: %w", err)
}

// Nonce
if err := binary.Write(buf, binary.LittleEndian, tb.Nonce); err != nil {
return nil, fmt.Errorf("failed to write transaction body nonce: %w", err)
if err := binary.Write(cw, binary.LittleEndian, tb.Nonce); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction body nonce: %w", err)
}

// ChainID
if err := writeString(buf, tb.ChainID); err != nil {
return nil, fmt.Errorf("failed to write transaction body chain ID: %w", err)
if err := writeString(cw, tb.ChainID); err != nil {
return cw.Written(), fmt.Errorf("failed to write transaction body chain ID: %w", err)
}

return buf.Bytes(), nil
return cw.Written(), nil
}

var _ io.ReaderFrom = (*TransactionBody)(nil)

func (tb *TransactionBody) ReadFrom(r io.Reader) (int64, error) {
var n int
cr := utils.NewCountingReader(r)

// Description Length + Description
desc, err := readString(r)
desc, err := readString(cr)
if err != nil {
return int64(n), fmt.Errorf("failed to read transaction body description: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body description: %w", err)
}
tb.Description = desc
n += 4 + len(desc)

// serialized Payload
payload, err := readBytes(r)
payload, err := readBytes(cr)
if err != nil {
return int64(n), fmt.Errorf("failed to read transaction body payload: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body payload: %w", err)
}
tb.Payload = payload
n += 4 + len(payload)

// PayloadType
payloadType, err := readString(r)
payloadType, err := readString(cr)
if err != nil {
return int64(n), fmt.Errorf("failed to read transaction body payload type: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body payload type: %w", err)
}
tb.PayloadType = PayloadType(payloadType)
n += 4 + len(payloadType)

// Fee (big.Int)
b, ni, err := readBigInt(r)
b, _, err := readBigInt(cr)
if err != nil {
return int64(n), fmt.Errorf("failed to read transaction body fee: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body fee: %w", err)
}
tb.Fee = b // may be nil
n += ni

// Nonce
if err := binary.Read(r, binary.LittleEndian, &tb.Nonce); err != nil {
return int64(n), fmt.Errorf("failed to read transaction body nonce: %w", err)
if err := binary.Read(cr, binary.LittleEndian, &tb.Nonce); err != nil {
return cr.ReadCount(), fmt.Errorf("failed to read transaction body nonce: %w", err)
}
n += 8

// ChainID
chainID, err := readString(r)
chainID, err := readString(cr)
if err != nil {
return int64(n), fmt.Errorf("failed to read transaction body chain ID: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body chain ID: %w", err)
}
tb.ChainID = chainID
n += 4 + len(chainID)

return int64(n), nil
return cr.ReadCount(), nil
}

var _ encoding.BinaryUnmarshaler = (*TransactionBody)(nil)
Expand All @@ -417,26 +443,21 @@ func (tb *TransactionBody) UnmarshalBinary(data []byte) error {
}

func (t *Transaction) serialize(w io.Writer) (err error) {
if t.Body == nil {
return errors.New("missing transaction body")
}

// Tx Signature
var sigBytes []byte
if t.Signature != nil {
if sigBytes, err = t.Signature.MarshalBinary(); err != nil {
return fmt.Errorf("failed to marshal transaction signature: %w", err)
}
sigBytes = t.Signature.Bytes()
}
if err := writeBytes(w, sigBytes); err != nil {
return fmt.Errorf("failed to write transaction signature: %w", err)
}

// Tx Body
if t.Body == nil {
return errors.New("missing transaction body")
}
txBodyBytes, err := t.Body.MarshalBinary()
if err != nil {
return fmt.Errorf("failed to marshal transaction body: %w", err)
}
if _, err := w.Write(txBodyBytes); err != nil {
if _, err := t.Body.WriteTo(w); err != nil {
return fmt.Errorf("failed to write transaction body: %w", err)
}
/*var txBodyBytes []byte
Expand All @@ -463,38 +484,35 @@ func (t *Transaction) serialize(w io.Writer) (err error) {
return nil
}

func (t *Transaction) deserialize(r io.Reader) (int, error) {
var n int
func (t *Transaction) deserialize(r io.Reader) (int64, error) {
cr := utils.NewCountingReader(r)

// Signature
sigBytes, err := readBytes(r)
sigBytes, err := readBytes(cr)
if err != nil {
return n, fmt.Errorf("failed to read transaction signature: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction signature: %w", err)
}
n += 4 + len(sigBytes)

if len(sigBytes) != 0 {
var signature auth.Signature
if err = signature.UnmarshalBinary(sigBytes); err != nil {
return 0, fmt.Errorf("failed to unmarshal transaction signature: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to unmarshal transaction signature: %w", err)
}
t.Signature = &signature
}

// TxBody
var body TransactionBody
bodyLen, err := body.ReadFrom(r)
_, err = body.ReadFrom(cr)
if err != nil {
return 0, fmt.Errorf("failed to read transaction body: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction body: %w", err)
}
t.Body = &body
n += int(bodyLen)
/* if we need to support nil body...
bodyBytes, err := readBytes(r)
bodyBytes, err := readBytes(cr)
if err != nil {
return 0, fmt.Errorf("failed to read transaction body: %w", err)
}
n += 4 + len(bodyBytes)
if len(bodyBytes) != 0 {
var body TransactionBody
body.StrictUnmarshal()
Expand All @@ -505,22 +523,20 @@ func (t *Transaction) deserialize(r io.Reader) (int, error) {
}*/

// SerializationType
serType, err := readString(r)
serType, err := readString(cr)
if err != nil {
return 0, fmt.Errorf("failed to read transaction serialization type: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction serialization type: %w", err)
}
n += 4 + len(serType)
t.Serialization = SignedMsgSerializationType(serType)

// Sender
senderBytes, err := readBytes(r)
senderBytes, err := readBytes(cr)
if err != nil {
return 0, fmt.Errorf("failed to read transaction sender: %w", err)
return cr.ReadCount(), fmt.Errorf("failed to read transaction sender: %w", err)
}
n += 4 + len(senderBytes)
t.Sender = senderBytes

return n, nil
return cr.ReadCount(), nil
}

func writeBytes(w io.Writer, data []byte) error {
Expand Down
Loading
Loading