Skip to content

Commit

Permalink
refactor(contracts/ScrollChain): stack too deep; fix ChunkCode.lastAp…
Browse files Browse the repository at this point in the history
…pliedL1BlockInBlock to return uint64 instead of uint256
  • Loading branch information
failfmi committed Nov 30, 2023
1 parent e99e26f commit 6507e1d
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 94 deletions.
234 changes: 141 additions & 93 deletions contracts/src/L1/rollup/ScrollChain.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
/// @notice The address of L1ViewOracle.
address public l1ViewOracle;

// stack too deep
struct CommitChunksResult {
bytes32 dataHash;
uint256 totalL1MessagesPoppedOverall;
uint256 totalL1MessagesPoppedInBatch;
uint64 lastAppliedL1Block;
bytes32 l1BlockRangeHashInBatch;
}

// stack too deep
struct ChunkResult {
// _totalNumL1MessagesInChunk The total number of L1 messages popped in current chunk
uint256 _totalNumL1MessagesInChunk;
// _lastAppliedL1BlockInChunk The last applied L1 Block Number in current chunk
uint64 _lastAppliedL1BlockInChunk;
// _l1BlockRangeHashInChunk The keccak256 of all the l1 block range hashes in current chunk
bytes32 _l1BlockRangeHashInChunk;
}

/**********************
* Function Modifiers *
**********************/
Expand Down Expand Up @@ -177,8 +196,7 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
require(_version == 0, "invalid version");

// check whether the batch is empty
uint256 _chunksLength = _chunks.length;
require(_chunksLength > 0, "batch is empty");
require(_chunks.length > 0, "batch is empty");

// The overall memory layout in this function is organized as follows
// +---------------------+-------------------+------------------+
Expand All @@ -202,87 +220,39 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
require(committedBatches[_batchIndex] == _parentBatchHash, "incorrect parent batch hash");
require(committedBatches[_batchIndex + 1] == 0, "batch already committed");

// load `dataPtr` and reserve the memory region for chunk data hashes
uint256 dataPtr;
assembly {
dataPtr := mload(0x40)
mstore(0x40, add(dataPtr, mul(_chunksLength, 32)))
}

uint64 _lastAppliedL1Block;
uint256 _totalNumL1MessagesInChunk;
uint64 _lastAppliedL1BlockInChunk;
bytes32 _l1BlockRangeHashInChunk;

// compute the data hash for each chunk
uint256 _totalL1MessagesPoppedInBatch;
bytes32[] memory _l1BlockRangeHashes = new bytes32[](_chunksLength);
for (uint256 i = 0; i < _chunksLength; i++) {
(_totalNumL1MessagesInChunk, _lastAppliedL1BlockInChunk, _l1BlockRangeHashInChunk) = _commitChunk(
dataPtr,
_chunks[i],
_totalL1MessagesPoppedInBatch,
_totalL1MessagesPoppedOverall,
_skippedL1MessageBitmap
);

if (_prevLastAppliedL1Block != 0) {
bytes32 _l1BlockRangeHash = IL1ViewOracle(l1ViewOracle).blockRangeHash(
_prevLastAppliedL1Block + 1,
_lastAppliedL1BlockInChunk
);

require(_l1BlockRangeHash == _l1BlockRangeHashInChunk, "incorrect l1 block range hash");
_l1BlockRangeHashes[i] = _l1BlockRangeHashInChunk;
_prevLastAppliedL1Block = _lastAppliedL1BlockInChunk;
}

// if it is the last chunk, update the last applied L1 block
if (i == _chunksLength - 1) {
_lastAppliedL1Block = _lastAppliedL1BlockInChunk;
}

unchecked {
_totalL1MessagesPoppedInBatch += _totalNumL1MessagesInChunk;
_totalL1MessagesPoppedOverall += _totalNumL1MessagesInChunk;
dataPtr += 32;
}
}

// check the length of bitmap
unchecked {
require(
((_totalL1MessagesPoppedInBatch + 255) / 256) * 32 == _skippedL1MessageBitmap.length,
"wrong bitmap length"
);
}
CommitChunksResult memory chunksResult = _commitChunks(
_chunks,
_totalL1MessagesPoppedOverall,
_skippedL1MessageBitmap,
_prevLastAppliedL1Block
);

// compute the data hash for current batch
bytes32 _dataHash;
assembly {
let dataLen := mul(_chunksLength, 0x20)
_dataHash := keccak256(sub(dataPtr, dataLen), dataLen)

batchPtr := mload(0x40) // reset batchPtr
_batchIndex := add(_batchIndex, 1) // increase batch index
}

bytes32 _l1BlockRangeHashInBatch = keccak256(abi.encodePacked(_l1BlockRangeHashes));
uint256 _skippedL1MessageBitmapLength = _skippedL1MessageBitmap.length;

// store entries, the order matters
BatchHeaderV0Codec.storeVersion(batchPtr, _version);
BatchHeaderV0Codec.storeBatchIndex(batchPtr, _batchIndex);
BatchHeaderV0Codec.storeL1MessagePopped(batchPtr, _totalL1MessagesPoppedInBatch);
BatchHeaderV0Codec.storeTotalL1MessagePopped(batchPtr, _totalL1MessagesPoppedOverall);
BatchHeaderV0Codec.storeDataHash(batchPtr, _dataHash);
BatchHeaderV0Codec.storeL1MessagePopped(batchPtr, chunksResult.totalL1MessagesPoppedInBatch);
BatchHeaderV0Codec.storeTotalL1MessagePopped(batchPtr, chunksResult.totalL1MessagesPoppedOverall);
BatchHeaderV0Codec.storeDataHash(batchPtr, chunksResult.dataHash);
BatchHeaderV0Codec.storeParentBatchHash(batchPtr, _parentBatchHash);
BatchHeaderV0Codec.storeSkippedBitmap(batchPtr, _skippedL1MessageBitmap);
BatchHeaderV0Codec.storeLastAppliedL1Block(batchPtr, _skippedL1MessageBitmapLength, _lastAppliedL1Block);
BatchHeaderV0Codec.storeL1BlockRangeHash(batchPtr, _skippedL1MessageBitmapLength, _l1BlockRangeHashInBatch);
BatchHeaderV0Codec.storeLastAppliedL1Block(
batchPtr,
_skippedL1MessageBitmap.length,
chunksResult.lastAppliedL1Block
);
BatchHeaderV0Codec.storeL1BlockRangeHash(
batchPtr,
_skippedL1MessageBitmap.length,
chunksResult.l1BlockRangeHashInBatch
);

// compute batch hash
bytes32 _batchHash = BatchHeaderV0Codec.computeBatchHash(batchPtr, 129 + _skippedL1MessageBitmapLength);
bytes32 _batchHash = BatchHeaderV0Codec.computeBatchHash(batchPtr, 129 + _skippedL1MessageBitmap.length);

committedBatches[_batchIndex] = _batchHash;
emit CommitBatch(_batchIndex, _batchHash);
Expand Down Expand Up @@ -486,28 +456,96 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
_batchHash = BatchHeaderV0Codec.computeBatchHash(memPtr, _length);
}

function _commitChunks(
bytes[] memory _chunks,
uint256 _totalL1MessagesPoppedOverall,
bytes calldata _skippedL1MessageBitmap,
uint64 _prevLastAppliedL1Block
) internal view returns (CommitChunksResult memory) {
uint256 _chunksLength = _chunks.length;
// load `dataPtr` and reserve the memory region for chunk data hashes
uint256 dataPtr;
assembly {
dataPtr := mload(0x40)
mstore(0x40, add(dataPtr, mul(_chunksLength, 32)))
}

uint256 _totalL1MessagesPoppedInBatch;
uint64 _lastAppliedL1Block;
bytes32[] memory _l1BlockRangeHashes = new bytes32[](_chunksLength);

for (uint256 i = 0; i < _chunksLength; i++) {
ChunkResult memory chunkResult = _commitChunk(
dataPtr,
_chunks[i],
_totalL1MessagesPoppedInBatch,
_totalL1MessagesPoppedOverall,
_skippedL1MessageBitmap
);

if (_prevLastAppliedL1Block != 0) {
bytes32 _l1BlockRangeHash = IL1ViewOracle(l1ViewOracle).blockRangeHash(
_prevLastAppliedL1Block + 1,
chunkResult._lastAppliedL1BlockInChunk
);

require(_l1BlockRangeHash == chunkResult._l1BlockRangeHashInChunk, "incorrect l1 block range hash");
_l1BlockRangeHashes[i] = chunkResult._l1BlockRangeHashInChunk;
_prevLastAppliedL1Block = chunkResult._lastAppliedL1BlockInChunk;
}

// if it is the last chunk, update the last applied L1 block
if (i == _chunksLength - 1) {
_lastAppliedL1Block = chunkResult._lastAppliedL1BlockInChunk;
}

unchecked {
_totalL1MessagesPoppedInBatch += chunkResult._totalNumL1MessagesInChunk;
_totalL1MessagesPoppedOverall += chunkResult._totalNumL1MessagesInChunk;
dataPtr += 32;
}
}

// check the length of bitmap
unchecked {
require(
((_totalL1MessagesPoppedInBatch + 255) / 256) * 32 == _skippedL1MessageBitmap.length,
"wrong bitmap length"
);
}

// compute the data hash for current batch
bytes32 _dataHash;
assembly {
let dataLen := mul(_chunksLength, 0x20)
_dataHash := keccak256(sub(dataPtr, dataLen), dataLen)
}

bytes32 _l1BlockRangeHashInBatch = keccak256(abi.encodePacked(_l1BlockRangeHashes));

return
CommitChunksResult({
dataHash: _dataHash,
totalL1MessagesPoppedOverall: _totalL1MessagesPoppedOverall,
totalL1MessagesPoppedInBatch: _totalL1MessagesPoppedInBatch,
lastAppliedL1Block: _lastAppliedL1Block,
l1BlockRangeHashInBatch: _l1BlockRangeHashInBatch
});
}

/// @dev Internal function to commit a chunk.
/// @param memPtr The start memory offset to store list of `dataHash`.
/// @param _chunk The encoded chunk to commit.
/// @param _totalL1MessagesPoppedInBatch The total number of L1 messages popped in current batch.
/// @param _totalL1MessagesPoppedOverall The total number of L1 messages popped in all batches including current batch.
/// @param _skippedL1MessageBitmap The bitmap indicates whether each L1 message is skipped or not.
/// @return _totalNumL1MessagesInChunk The total number of L1 message popped in current chunk
function _commitChunk(
uint256 memPtr,
bytes memory _chunk,
uint256 _totalL1MessagesPoppedInBatch,
uint256 _totalL1MessagesPoppedOverall,
bytes calldata _skippedL1MessageBitmap
)
internal
view
returns (
uint256 _totalNumL1MessagesInChunk,
uint64 _lastAppliedL1BlockInChunk,
bytes32 _l1BlockRangeHashInChunk
)
{
) internal view returns (ChunkResult memory chunkResult) {
uint256 chunkPtr;
uint256 startDataPtr;
uint256 dataPtr;
Expand Down Expand Up @@ -545,7 +583,6 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
blockPtr := add(chunkPtr, 1) // reset block ptr
}

uint256 _lastAppliedL1Block;
// concatenate tx hashes
uint256 l2TxPtr = ChunkCodec.l2TxPtr(chunkPtr, _numBlocks);
while (_numBlocks > 0) {
Expand Down Expand Up @@ -573,11 +610,11 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {

if (_numBlocks == 1) {
// check last block
_lastAppliedL1Block = ChunkCodec.lastAppliedL1BlockInBlock(blockPtr);
chunkResult._lastAppliedL1BlockInChunk = ChunkCodec.lastAppliedL1BlockInBlock(blockPtr);
}

unchecked {
_totalNumL1MessagesInChunk += _numL1MessagesInBlock;
chunkResult._totalNumL1MessagesInChunk += _numL1MessagesInBlock;
_totalL1MessagesPoppedInBatch += _numL1MessagesInBlock;
_totalL1MessagesPoppedOverall += _numL1MessagesInBlock;

Expand All @@ -586,31 +623,42 @@ contract ScrollChain is OwnableUpgradeable, PausableUpgradeable, IScrollChain {
}
}

_lastAppliedL1BlockInChunk = ChunkCodec.lastAppliedL1BlockInChunk(l2TxPtr);
_l1BlockRangeHashInChunk = ChunkCodec.l1BlockRangeHashInChunk(l2TxPtr);
// stack too deep
{
uint64 lastAppliedL1BlockInChunk = ChunkCodec.lastAppliedL1BlockInChunk(l2TxPtr);
chunkResult._l1BlockRangeHashInChunk = ChunkCodec.l1BlockRangeHashInChunk(l2TxPtr);

require(_lastAppliedL1Block == _lastAppliedL1BlockInChunk, "incorrect lastAppliedL1Block in chunk");
require(
lastAppliedL1BlockInChunk == chunkResult._lastAppliedL1BlockInChunk,
"incorrect lastAppliedL1Block in chunk"
);
}

// check the actual number of transactions in the chunk
require((dataPtr - txHashStartDataPtr) / 32 <= maxNumTxInChunk, "too many txs in one chunk");

assembly {
mstore(dataPtr, _lastAppliedL1BlockInChunk)
mstore(dataPtr, _l1BlockRangeHashInChunk)
dataPtr := add(dataPtr, 0x28)
}

// check chunk has correct length.
// 40 is the size of lastAppliedL1Block and l1BlockRangeHash.
require(l2TxPtr - chunkPtr + 40 == _chunk.length, "incomplete l2 transaction data");

// stack too deep
{
uint256 _lastAppliedL1BlockInChunk = chunkResult._lastAppliedL1BlockInChunk;
bytes32 _l1BlockRangeHashInChunk = chunkResult._l1BlockRangeHashInChunk;
assembly {
mstore(dataPtr, _lastAppliedL1BlockInChunk)
mstore(dataPtr, _l1BlockRangeHashInChunk)
dataPtr := add(dataPtr, 0x28)
}
}

// compute data hash and store to memory
assembly {
let dataHash := keccak256(startDataPtr, sub(dataPtr, startDataPtr))
mstore(memPtr, dataHash)
}

return (_totalNumL1MessagesInChunk, _lastAppliedL1BlockInChunk, _l1BlockRangeHashInChunk);
return chunkResult;
}

/// @dev Internal function to load L1 message hashes from the message queue.
Expand Down
2 changes: 1 addition & 1 deletion contracts/src/libraries/codec/ChunkCodec.sol
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ library ChunkCodec {
/// @notice Return the number of last applied L1 block.
/// @param blockPtr The start memory offset of the block context in memory.
/// @return _lastAppliedL1Block The number of last applied L1 block.
function lastAppliedL1BlockInBlock(uint256 blockPtr) internal pure returns (uint256 _lastAppliedL1Block) {
function lastAppliedL1BlockInBlock(uint256 blockPtr) internal pure returns (uint64 _lastAppliedL1Block) {
assembly {
_lastAppliedL1Block := shr(240, mload(add(blockPtr, 60)))
}
Expand Down

0 comments on commit 6507e1d

Please sign in to comment.