Store compressed vectors in dense ByteSequence for PQVectors #370
Changes from 1 commit
PQVectors.java:

@@ -28,34 +28,34 @@
 import java.io.DataOutput;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Arrays;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicReference;
 
 public class PQVectors implements CompressedVectors {
     private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
+    private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
     final ProductQuantization pq;
-    private final List<ByteSequence<?>> compressedVectors;
+    private final ByteSequence<?>[] compressedDataChunks;
+    private final int vectorCount;
+    private final int vectorsPerChunk;
 
-    /**
-     * Initialize the PQVectors with an initial List of vectors. This list may be
-     * mutated, but caller is responsible for thread safety issues when doing so.
-     */
-    public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
+    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks)
     {
-        this.pq = pq;
-        this.compressedVectors = compressedVectors;
+        this(pq, compressedDataChunks, compressedDataChunks.length, 1);
     }
 
-    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
+    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk)
     {
-        this(pq, List.of(compressedVectors));
+        this.pq = pq;
+        this.compressedDataChunks = compressedDataChunks;
+        this.vectorCount = vectorCount;
+        this.vectorsPerChunk = vectorsPerChunk;
     }
 
     @Override
     public int count() {
-        return compressedVectors.size();
+        return vectorCount;
     }
 
     @Override
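The one-argument constructor delegates with vectorsPerChunk = 1, so an existing array of per-vector sequences keeps the old behavior: ordinal i lives in chunk i at offset 0. A minimal standalone sketch of that degenerate mapping, using plain byte arrays as an illustrative stand-in for ByteSequence (not the library API):

public class OneVectorPerChunkSketch {
    public static void main(String[] args) {
        int vectorsPerChunk = 1; // legacy layout: one compressed vector per chunk
        byte[][] chunks = { {1, 2}, {3, 4}, {5, 6} }; // three 2-byte "vectors"
        int ordinal = 2;
        int chunkIndex = ordinal / vectorsPerChunk;   // = 2: each ordinal gets its own chunk
        int indexInChunk = ordinal % vectorsPerChunk; // = 0: always at the start of its chunk
        System.out.println(chunks[chunkIndex][indexInChunk]); // prints 5
    }
}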
@@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException
         pq.write(out, version);
 
         // compressed vectors
-        out.writeInt(compressedVectors.size());
+        out.writeInt(vectorCount);
         out.writeInt(pq.getSubspaceCount());
-        for (var v : compressedVectors) {
-            vectorTypeSupport.writeByteSequence(out, v);
+        for (ByteSequence<?> chunk : compressedDataChunks) {
+            vectorTypeSupport.writeByteSequence(out, chunk);
         }
     }
 
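Because the chunks hold the same bytes back to back, looping over chunks emits a stream identical to the old per-vector loop, so the file format is unchanged. A sketch of the layout that follows the serialized ProductQuantization (the literal values are hypothetical):

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class PQVectorsFormatSketch {
    public static void main(String[] args) throws IOException {
        var bytes = new ByteArrayOutputStream();
        var out = new DataOutputStream(bytes);
        int vectorCount = 3, subspaceCount = 2;
        out.writeInt(vectorCount);   // header: number of compressed vectors
        out.writeInt(subspaceCount); // header: bytes per compressed vector
        out.write(new byte[] {10, 11, 20, 21, 30, 31}); // raw codes, no per-vector framing
        System.out.println(bytes.size()); // 4 + 4 + 3*2 = 14 bytes
    }
}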
@@ -77,44 +77,76 @@ public static PQVectors load(RandomAccessReader in) throws IOException {
         var pq = ProductQuantization.load(in);
 
         // read the vectors
-        int size = in.readInt();
-        if (size < 0) {
-            throw new IOException("Invalid compressed vector count " + size);
+        int vectorCount = in.readInt();
+        if (vectorCount < 0) {
+            throw new IOException("Invalid compressed vector count " + vectorCount);
         }
-        List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);
 
         int compressedDimension = in.readInt();
         if (compressedDimension < 0) {
             throw new IOException("Invalid compressed vector dimension " + compressedDimension);
         }
 
-        for (int i = 0; i < size; i++)
-        {
-            ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
-            compressedVectors.add(vector);
+        // Calculate if we need to split into multiple chunks
+        long totalSize = (long) vectorCount * compressedDimension;
+        int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;
+
+        int numChunks = vectorCount / vectorsPerChunk;
+        ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];
+
+        for (int i = 0; i < numChunks - 1; i++) {
+            int chunkSize = vectorsPerChunk * compressedDimension;
+            chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize);
         }
 
-        return new PQVectors(pq, compressedVectors);
+        // Last chunk might be smaller
+        int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
+        chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
+
+        return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk);
     }
 
     public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
         in.seek(offset);
         return load(in);
     }
 
+    /**
+     * We consider two PQVectors equal when their PQs are equal and their compressed data is equal. We ignore the
+     * chunking strategy in the comparison since this is an implementation detail.
+     * @param o the object to check for equality
+     * @return true if the objects are equal, false otherwise
+     */
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
 
         PQVectors that = (PQVectors) o;
         if (!Objects.equals(pq, that.pq)) return false;
-        return Objects.equals(compressedVectors, that.compressedVectors);
+        if (this.count() != that.count()) return false;
+        // TODO how do we want to determine equality? With the current change, we are willing to write one
+        // thing and materialize another. It seems like the real concern should be whether the compressedVectors have
+        // the same data, not whether they are in a MemorySegment or a byte[] and not whether there is one chunk of many
+        // vectors or many chunks of one vector. This technically goes against the implementation of each of the
+        // ByteSequence#equals methods, which raises the question of whether this is valid. I primarily updated this
+        // code to get testSaveLoadPQ to pass.
+        for (int i = 0; i < this.count(); i++) {
+            var thisNode = this.get(i);
+            var thatNode = that.get(i);
+            if (thisNode.length() != thatNode.length()) return false;
+            for (int j = 0; j < thisNode.length(); j++) {
+                if (thisNode.get(j) != thatNode.get(j)) return false;
+            }
+        }
+        return true;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(pq, compressedVectors);
+        // We don't use the array structure in the hash code calculation because we allow for different chunking
+        // strategies. Instead, we use the first entry in the first chunk to provide a stable hash code.
+        return Objects.hash(pq, count(), count() > 0 ? get(0).get(0) : 0);
     }
 
     @Override

Review thread on the new hashCode():

Comment: If you're worried about the performance of hashing millions of vectors, I'd prefer adding the first codepoint from each vector to using all of the first vector's codepoints.

Reply: I ended up just hashing based on every value, since that is the norm. If it becomes a performance bottleneck, we can address it then; I doubt we compute the hash code frequently.
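The chunk sizing in load() keeps every chunk under the JVM's array length limit. One subtlety: numChunks is computed with floor division, which lines up with get()'s ordinal arithmetic only when vectorCount is a multiple of vectorsPerChunk (or everything fits in one chunk); ceiling division gives a trailing remainder its own smaller chunk. A standalone sketch of the arithmetic using that variant (the corpus numbers are hypothetical):

public class ChunkSizingSketch {
    static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16;

    public static void main(String[] args) {
        int vectorCount = 100_000_000; // hypothetical corpus size
        int compressedDimension = 64;  // bytes per PQ-encoded vector
        long totalSize = (long) vectorCount * compressedDimension; // 6.4e9, over the array limit
        int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE
                ? vectorCount
                : MAX_CHUNK_SIZE / compressedDimension; // 33_554_431 vectors per full chunk
        // Ceiling division so a remainder still gets a (smaller) final chunk:
        int numChunks = (int) ((vectorCount + (long) vectorsPerChunk - 1) / vectorsPerChunk); // 3
        int lastChunkVectors = vectorCount - vectorsPerChunk * (numChunks - 1); // 32_891_138
        System.out.printf("chunks=%d, perChunk=%d, last=%d%n", numChunks, vectorsPerChunk, lastChunkVectors);
    }
}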
@@ -188,7 +220,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
     }
 
     public ByteSequence<?> get(int ordinal) {
-        return compressedVectors.get(ordinal);
+        int chunkIndex = ordinal / vectorsPerChunk;
+        int vectorIndexInChunk = ordinal % vectorsPerChunk;
+        int start = vectorIndexInChunk * pq.getSubspaceCount();
+        return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount());
    }
 
     public ProductQuantization getProductQuantization() {
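With many vectors packed into a dense chunk, get() locates a vector with one division and returns a zero-copy slice into the chunk. A runnable toy version of the same mapping over plain byte arrays (the copy at the end stands in for the slice view; names are illustrative):

import java.util.Arrays;

public class DenseGetSketch {
    public static void main(String[] args) {
        int subspaceCount = 3;   // bytes per encoded vector
        int vectorsPerChunk = 2; // vectors packed per chunk
        byte[][] chunks = {
                {0, 1, 2, 10, 11, 12},    // ordinals 0 and 1
                {20, 21, 22, 30, 31, 32}  // ordinals 2 and 3
        };
        int ordinal = 2;
        int chunkIndex = ordinal / vectorsPerChunk;              // 1
        int start = (ordinal % vectorsPerChunk) * subspaceCount; // 0
        byte[] vector = Arrays.copyOfRange(chunks[chunkIndex], start, start + subspaceCount);
        System.out.println(Arrays.toString(vector)); // [20, 21, 22]
    }
}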
@@ -225,16 +260,19 @@ public long ramBytesUsed() {
         int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
 
         long codebooksSize = pq.ramBytesUsed();
-        long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
-        long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
-        return codebooksSize + listSize + dataSize;
+        long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES;
+        long dataSize = 0;
+        for (ByteSequence<?> chunk : compressedDataChunks) {
+            dataSize += chunk.ramBytesUsed();
+        }
+        return codebooksSize + chunksArraySize + dataSize;
     }
 
     @Override
     public String toString() {
         return "PQVectors{" +
                 "pq=" + pq +
-                ", count=" + compressedVectors.size() +
+                ", count=" + vectorCount +
                 '}';
     }
 }
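This is where the dense layout pays off: the old representation charged an object header, an array header, and a list reference per vector, while chunks amortize those costs across millions of vectors. A rough back-of-envelope comparison, assuming 16-byte object headers, 16-byte array headers, and 8-byte references (typical values, but JVM-dependent):

public class OverheadSketch {
    public static void main(String[] args) {
        long n = 10_000_000, vectorSize = 64; // hypothetical corpus
        long OH = 16, AH = 16, REF = 8;       // assumed JVM constants
        long perVector = n * (OH + AH + vectorSize) + REF * (n + 1); // old: one object per vector, plus list refs
        long dense = OH + AH + REF + (OH + AH + n * vectorSize);     // new: one chunk reference, one big buffer
        System.out.printf("per-vector ~%d MB, dense ~%d MB%n", perVector >> 20, dense >> 20);
    }
}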
ArraySliceByteSequence.java (new file):

@@ -0,0 +1,148 @@
/*
 * Copyright DataStax, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.types.ByteSequence;

import java.util.Arrays;

/**
 * A read only {@link ByteSequence} implementation that wraps an array and provides a view into a slice of it.
 */
public class ArraySliceByteSequence implements ByteSequence<byte[]> {
    private final byte[] data;
    private final int offset;
    private final int length;

    public ArraySliceByteSequence(byte[] data, int offset, int length) {
        if (offset < 0 || length < 0 || offset + length > data.length) {
            throw new IllegalArgumentException("Invalid offset or length");
        }
        this.data = data;
        this.offset = offset;
        this.length = length;
    }

    @Override
    public byte[] get() {
        return data;
    }

    @Override
    public int offset() {
        return offset;
    }

    @Override
    public byte get(int n) {
        if (n < 0 || n >= length) {
            throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
        }
        return data[offset + n];
    }

    @Override
    public void set(int n, byte value) {
        if (n < 0 || n >= length) {
            throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
        }
        data[offset + n] = value;
    }

    @Override
    public void setLittleEndianShort(int shortIndex, short value) {
        throw new UnsupportedOperationException("Not supported on slices");
    }

    @Override
    public void zero() {
        throw new UnsupportedOperationException("Not supported on slices");
    }

    @Override
    public int length() {
        return length;
    }

    @Override
    public ByteSequence<byte[]> copy() {
        byte[] newData = Arrays.copyOfRange(data, offset, offset + length);
        return new ArrayByteSequence(newData);
    }

    @Override
    public ByteSequence<byte[]> slice(int sliceOffset, int sliceLength) {
        if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) {
            throw new IllegalArgumentException("Invalid slice parameters");
        }
        if (sliceOffset == 0 && sliceLength == length) {
            return this;
        }
        return new ArraySliceByteSequence(data, offset + sliceOffset, sliceLength);
    }

    @Override
    public long ramBytesUsed() {
        // Only count the overhead of this slice object, not the underlying array
        // since that's shared and counted elsewhere
        return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
                (3 * Integer.BYTES); // offset, length, and reference to data
    }

    @Override
    public void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int copyLength) {
        throw new UnsupportedOperationException("Not supported on slices");
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < Math.min(length, 25); i++) {
            sb.append(get(i));
            if (i < length - 1) {
                sb.append(", ");
            }
        }
        if (length > 25) {
            sb.append("...");
        }
        sb.append("]");
        return sb.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ArraySliceByteSequence that = (ArraySliceByteSequence) o;
        if (this.length != that.length) return false;
        for (int i = 0; i < length; i++) {
            if (this.get(i) != that.get(i)) return false;
        }
        return true;
    }

    @Override
    public int hashCode() {
        int result = 1;
        for (int i = 0; i < length; i++) {
            result = 31 * result + get(i);
        }
        return result;
    }
}
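A short usage sketch of the semantics above: a slice is a live view over the shared array, copy() materializes an independent ArrayByteSequence, and equals() compares content but only within the same class, which is what the review thread below is about (all values are hypothetical):

import io.github.jbellis.jvector.vector.ArrayByteSequence;
import io.github.jbellis.jvector.vector.ArraySliceByteSequence;
import io.github.jbellis.jvector.vector.types.ByteSequence;

public class SliceUsageSketch {
    public static void main(String[] args) {
        byte[] backing = {1, 2, 3, 4, 5, 6};
        var slice = new ArraySliceByteSequence(backing, 2, 2); // view of {3, 4}
        ByteSequence<byte[]> copy = slice.copy();              // detached ArrayByteSequence {3, 4}

        slice.set(0, (byte) 9);                         // writes through to backing[2]
        System.out.println(backing[2]);                 // 9
        System.out.println(copy.get(0));                // still 3: the copy is independent
        System.out.println(slice.slice(0, 2) == slice); // true: full-range slice returns this

        var sameContent = new ArraySliceByteSequence(new byte[] {9, 4}, 0, 2);
        System.out.println(slice.equals(sameContent)); // true: same class, same content
        System.out.println(slice.equals(copy));        // false: equals is class-based, so an ArrayByteSequence never matches
    }
}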
Review thread on the equals/hashCode implementations:

Comment: I'm fine with changing the ByteSequence equals implementations to match. You get the class comparison by default when IntelliJ generates it for you, and I think we just ran with that.

Reply: I implemented this, but it's not as elegant as I wanted. I could have made the ByteSequence interface an abstract class and overridden equals/hashCode, but that seemed like it went too far, so I added default methods that implementations of the interface can call. In practice, I doubt these objects are compared for strict equality frequently.
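The default-method approach described in that reply might look roughly like the sketch below; Java does not allow an interface's default methods to override Object#equals or Object#hashCode, which is why implementations must call the helpers explicitly from their own overrides. The interface and method names here are hypothetical, not necessarily what the PR merged:

public interface ContentComparableSketch {
    int length();
    byte get(int i);

    // Content-based equality that implementations can delegate to from equals(),
    // comparing by interface rather than by concrete class.
    default boolean contentEquals(Object o) {
        if (this == o) return true;
        if (!(o instanceof ContentComparableSketch)) return false;
        ContentComparableSketch that = (ContentComparableSketch) o;
        if (this.length() != that.length()) return false;
        for (int i = 0; i < length(); i++) {
            if (this.get(i) != that.get(i)) return false;
        }
        return true;
    }

    // Matching content-based hash over every value, as discussed in the hashCode thread above.
    default int contentHashCode() {
        int result = 1;
        for (int i = 0; i < length(); i++) {
            result = 31 * result + get(i);
        }
        return result;
    }
}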