diff --git a/src/Neo/Persistence/DataCache.cs b/src/Neo/Persistence/DataCache.cs index 016bed2ee2..92aa83a605 100644 --- a/src/Neo/Persistence/DataCache.cs +++ b/src/Neo/Persistence/DataCache.cs @@ -202,50 +202,58 @@ public void Delete(StorageKey key) /// /// Finds the entries starting with the specified prefix. /// - /// The prefix of the key. + /// The prefix of the key. /// The search direction. /// The entries found with the desired prefix. - public IEnumerable<(StorageKey Key, StorageItem Value)> Find(byte[]? key_prefix = null, SeekDirection direction = SeekDirection.Forward) + public IEnumerable<(StorageKey Key, StorageItem Value)> Find(byte[]? keyPrefix = null, SeekDirection direction = SeekDirection.Forward) { - var seek_prefix = key_prefix; + var seekPrefix = keyPrefix; if (direction == SeekDirection.Backward) { - if (key_prefix == null) + if (keyPrefix == null) { // Backwards seek for null prefix is not supported for now. - throw new ArgumentNullException(nameof(key_prefix)); + throw new ArgumentNullException(nameof(keyPrefix)); } - if (key_prefix.Length == 0) + if (keyPrefix.Length == 0) { // Backwards seek for zero prefix is not supported for now. - throw new ArgumentOutOfRangeException(nameof(key_prefix)); + throw new ArgumentOutOfRangeException(nameof(keyPrefix)); } - seek_prefix = null; - for (var i = key_prefix.Length - 1; i >= 0; i--) + + seekPrefix = null; + for (var i = keyPrefix.Length - 1; i >= 0; i--) { - if (key_prefix[i] < 0xff) + if (keyPrefix[i] < 0xff) { - seek_prefix = key_prefix.Take(i + 1).ToArray(); - // The next key after the key_prefix. - seek_prefix[i]++; + seekPrefix = keyPrefix.Take(i + 1).ToArray(); + seekPrefix[i]++; // The next key after the key_prefix. break; } } - if (seek_prefix == null) + + if (seekPrefix == null) { - throw new ArgumentException($"{nameof(key_prefix)} with all bytes being 0xff is not supported now"); + // This case is rare + seekPrefix = new byte[ApplicationEngine.MaxStorageKeySize + 1]; + Array.Fill(seekPrefix, (byte)0xff); } + } - return FindInternal(key_prefix, seek_prefix, direction); + + return FindInternal(keyPrefix, seekPrefix, direction); } - private IEnumerable<(StorageKey Key, StorageItem Value)> FindInternal(byte[]? key_prefix, byte[]? seek_prefix, SeekDirection direction) + private IEnumerable<(StorageKey Key, StorageItem Value)> FindInternal(byte[]? keyPrefix, byte[]? seekPrefix, SeekDirection direction) { - foreach (var (key, value) in Seek(seek_prefix, direction)) - if (key_prefix == null || key.ToArray().AsSpan().StartsWith(key_prefix)) + var prefixIsEmpty = keyPrefix == null || keyPrefix.Length == 0; + foreach (var (key, value) in Seek(seekPrefix, direction)) + { + if (prefixIsEmpty || key.ToArray().AsSpan().StartsWith(keyPrefix)) yield return (key, value); - else if (direction == SeekDirection.Forward || (seek_prefix == null || !key.ToArray().SequenceEqual(seek_prefix))) + else if (direction == SeekDirection.Forward || (seekPrefix == null || !key.ToArray().SequenceEqual(seekPrefix))) yield break; + } } /// @@ -261,10 +269,12 @@ public void Delete(StorageKey key) ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; foreach (var (key, value) in Seek(start, direction)) + { if (comparer.Compare(key.ToArray(), end) < 0) yield return (key, value); else yield break; + } } /// @@ -415,31 +425,24 @@ public StorageItem GetOrAdd(StorageKey key, Func factory) { IEnumerable<(byte[], StorageKey, StorageItem)> cached; HashSet cachedKeySet; - ByteArrayComparer comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; + var comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; lock (_dictionary) { cached = _dictionary .Where(p => p.Value.State != TrackState.Deleted && p.Value.State != TrackState.NotFound && (keyOrPrefix == null || comparer.Compare(p.Key.ToArray(), keyOrPrefix) >= 0)) - .Select(p => - ( - KeyBytes: p.Key.ToArray(), - p.Key, - p.Value.Item - )) + .Select(p => (KeyBytes: p.Key.ToArray(), p.Key, p.Value.Item)) .OrderBy(p => p.KeyBytes, comparer) .ToArray(); cachedKeySet = new HashSet(_dictionary.Keys); } + var uncached = SeekInternal(keyOrPrefix ?? Array.Empty(), direction) .Where(p => !cachedKeySet.Contains(p.Key)) - .Select(p => - ( - KeyBytes: p.Key.ToArray(), - p.Key, - p.Value - )); + .Select(p => (KeyBytes: p.Key.ToArray(), p.Key, p.Value)); + using var e1 = cached.GetEnumerator(); using var e2 = uncached.GetEnumerator(); + (byte[] KeyBytes, StorageKey Key, StorageItem Item) i1, i2; bool c1 = e1.MoveNext(); bool c2 = e2.MoveNext(); diff --git a/tests/Neo.UnitTests/Persistence/UT_DataCache.cs b/tests/Neo.UnitTests/Persistence/UT_DataCache.cs index 3c5e8e6e7d..1f5bd59e77 100644 --- a/tests/Neo.UnitTests/Persistence/UT_DataCache.cs +++ b/tests/Neo.UnitTests/Persistence/UT_DataCache.cs @@ -185,6 +185,7 @@ public void TestFind() // null and empty with the backwards direction -> miserably fails. Action action = () => myDataCache.Find(null, SeekDirection.Backward); action.Should().Throw(); + action = () => myDataCache.Find(new byte[] { }, SeekDirection.Backward); action.Should().Throw(); @@ -404,5 +405,48 @@ public void TestFindInvalid() items.Current.Key.Should().Be(key4); items.MoveNext().Should().Be(false); } + + [TestMethod] + public void TestFindEmptyPrefix() + { + using var store = new MemoryStore(); + using var dataCache = new SnapshotCache(store); + + var k1 = StorageKey.CreateSearchPrefix(-1, []); + var k2 = StorageKey.CreateSearchPrefix(-1, [0x01]); + var k3 = StorageKey.CreateSearchPrefix(-1, [0xff, 0x02]); + + dataCache.Add(k1, value1); + dataCache.Add(k2, value2); + dataCache.Add(k3, value3); + + var items = dataCache.Find().ToArray(); + items.Length.Should().Be(3); + items[0].Key.ToArray().Should().BeEquivalentTo(k1.ToArray()); + items[1].Key.ToArray().Should().BeEquivalentTo(k2.ToArray()); + items[2].Key.ToArray().Should().BeEquivalentTo(k3.ToArray()); + + items = dataCache.Find([0xff, 0xff, 0xff, 0xff, 0xff]).ToArray(); + items.Length.Should().Be(1); + items[0].Key.ToArray().Should().BeEquivalentTo(k3.ToArray()); + + // null and empty are not supported for backwards direction now. + Action action = () => myDataCache.Find(null, SeekDirection.Backward); + action.Should().Throw(); + + action = () => myDataCache.Find(new byte[] { }, SeekDirection.Backward); + action.Should().Throw(); + + items = dataCache.Find([0xff, 0xff, 0xff, 0xff, 0xff], SeekDirection.Backward).ToArray(); + items.Length.Should().Be(1); + items[0].Key.ToArray().Should().BeEquivalentTo(k3.ToArray()); + + items = dataCache.Find([0xff], SeekDirection.Backward).ToArray(); + items.Length.Should().Be(3); + items[0].Key.ToArray().Should().BeEquivalentTo(k3.ToArray()); + items[1].Key.ToArray().Should().BeEquivalentTo(k2.ToArray()); + items[2].Key.ToArray().Should().BeEquivalentTo(k1.ToArray()); + } } } +