Store compressed vectors in dense ByteSequence for PQVectors #370

Merged: 9 commits, Dec 2, 2024
102 changes: 70 additions & 32 deletions jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java
@@ -28,34 +28,34 @@

import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

public class PQVectors implements CompressedVectors {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
final ProductQuantization pq;
private final List<ByteSequence<?>> compressedVectors;
private final ByteSequence<?>[] compressedDataChunks;
private final int vectorCount;
private final int vectorsPerChunk;

/**
 * Initialize the PQVectors with the compressed data chunks. The chunks may be
 * mutated, but the caller is responsible for thread safety when doing so.
 */
public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks)
{
this.pq = pq;
this.compressedVectors = compressedVectors;
this(pq, compressedDataChunks, compressedDataChunks.length, 1);
}

public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk)
{
this(pq, List.of(compressedVectors));
this.pq = pq;
this.compressedDataChunks = compressedDataChunks;
this.vectorCount = vectorCount;
this.vectorsPerChunk = vectorsPerChunk;
}

@Override
public int count() {
return compressedVectors.size();
return vectorCount;
}

@Override
@@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException
pq.write(out, version);

// compressed vectors
out.writeInt(compressedVectors.size());
out.writeInt(vectorCount);
out.writeInt(pq.getSubspaceCount());
for (var v : compressedVectors) {
vectorTypeSupport.writeByteSequence(out, v);
for (ByteSequence<?> chunk : compressedDataChunks) {
vectorTypeSupport.writeByteSequence(out, chunk);
}
}

@@ -77,44 +77,76 @@ public static PQVectors load(RandomAccessReader in) throws IOException {
var pq = ProductQuantization.load(in);

// read the vectors
int size = in.readInt();
if (size < 0) {
throw new IOException("Invalid compressed vector count " + size);
int vectorCount = in.readInt();
if (vectorCount < 0) {
throw new IOException("Invalid compressed vector count " + vectorCount);
}
List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);

int compressedDimension = in.readInt();
if (compressedDimension < 0) {
throw new IOException("Invalid compressed vector dimension " + compressedDimension);
}

for (int i = 0; i < size; i++)
{
ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
compressedVectors.add(vector);
// Calculate if we need to split into multiple chunks
long totalSize = (long) vectorCount * compressedDimension;
int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;

// Round up so that a final, partial chunk still gets a slot
int numChunks = (vectorCount + vectorsPerChunk - 1) / vectorsPerChunk;
ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];

for (int i = 0; i < numChunks - 1; i++) {
int chunkSize = vectorsPerChunk * compressedDimension;
chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize);
}

return new PQVectors(pq, compressedVectors);
// Last chunk might be smaller
int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);

return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk);
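
To make the chunk arithmetic concrete, here is a standalone sketch of the layout computation. The constants and math mirror the code above (with the ceiling division applied); the corpus size and dimension are invented:

```java
/** Standalone sketch of the chunk-layout math in PQVectors.load; all sizes are hypothetical. */
public class ChunkMathSketch {
    public static void main(String[] args) {
        long maxChunkSize = Integer.MAX_VALUE - 16; // same cap as PQVectors.MAX_CHUNK_SIZE
        int compressedDimension = 64;               // bytes per compressed vector (assumed)
        int vectorCount = 50_000_000;               // assumed corpus size

        long totalSize = (long) vectorCount * compressedDimension;  // 3.2 GB, exceeds the cap
        int vectorsPerChunk = totalSize <= maxChunkSize
                ? vectorCount
                : (int) (maxChunkSize / compressedDimension);       // 33,554,431
        int numChunks = (vectorCount + vectorsPerChunk - 1) / vectorsPerChunk;  // ceil -> 2
        int lastChunkVectors = vectorCount - vectorsPerChunk * (numChunks - 1); // 16,445,569

        System.out.printf("vectorsPerChunk=%d numChunks=%d lastChunkVectors=%d%n",
                vectorsPerChunk, numChunks, lastChunkVectors);
    }
}
```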
}

public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
in.seek(offset);
return load(in);
}

/**
* We consider two PQVectors equal when their PQs are equal and their compressed data is equal. We ignore the
* chunking strategy in the comparison since this is an implementation detail.
* @param o the object to check for equality
* @return true if the objects are equal, false otherwise
*/
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

PQVectors that = (PQVectors) o;
if (!Objects.equals(pq, that.pq)) return false;
return Objects.equals(compressedVectors, that.compressedVectors);
if (this.count() != that.count()) return false;
// TODO how do we want to determine equality? With the current change, we are willing to write one
// thing and materialize another. It seems like the real concern should be whether the compressedVectors have
// the same data, not whether they are in a MemorySegment or a byte[] and not whether there is one chunk of many
// vectors or many chunks of one vector. This technically goes against the implementation of each of the
> Owner: I'm fine with changing the ByteSequence equals implementations to match. You get the class comparison by default when IntelliJ generates it for you, and I think we just ran with that.

> Collaborator (author): I implemented this, but it's not as elegant as I wanted. I could have made the ByteSequence interface an abstract class and overridden equals/hashCode, but that seemed like it went too far, so I added default methods that implementations of the interface can call. In practice, I doubt these objects are compared for strict equality frequently.

// ByteSequence#equals methods, which raises the question of whether this is valid. I primarily updated this
// code to get testSaveLoadPQ to pass.
for (int i = 0; i < this.count(); i++) {
var thisNode = this.get(i);
var thatNode = that.get(i);
if (thisNode.length() != thatNode.length()) return false;
for (int j = 0; j < thisNode.length(); j++) {
if (thisNode.get(j) != thatNode.get(j)) return false;
}
}
return true;
}
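
A sketch of the "default methods on the interface" idea from the thread above; `SequenceLike` and `contentEquals` are hypothetical names, not necessarily the PR's API:

```java
// Implementations can delegate to a default helper from equals(Object), comparing
// logical contents regardless of backing storage (byte[] vs MemorySegment) or chunking.
interface SequenceLike {
    int length();
    byte get(int i);

    default boolean contentEquals(SequenceLike that) {
        if (this == that) return true;
        if (that == null || this.length() != that.length()) return false;
        for (int i = 0; i < this.length(); i++) {
            if (this.get(i) != that.get(i)) return false;
        }
        return true;
    }
}
```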

@Override
public int hashCode() {
return Objects.hash(pq, compressedVectors);
// We don't use the array structure in the hash code calculation because we allow for different chunking
// strategies. Instead, we use the first entry in the first chunk to provide a stable hash code.
return Objects.hash(pq, count(), count() > 0 ? get(0).get(0) : 0);
> Owner: If you're worried about the performance of hashing millions of vectors, I'd prefer adding the first codepoint from each vector over using all of the first vector's codepoints.

> Collaborator (author): I ended up just hashing based on every value, since that is the norm. If it becomes a performance bottleneck, we can address it then; I doubt we compute the hash code frequently.

}
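
For reference, the "hash every value" variant the author describes might look like this. This is a sketch assuming the surrounding PQVectors members (pq, count(), get(i)), not necessarily the exact code that landed:

```java
// Chunking-independent because it walks logical vectors via get(i) rather than
// the chunk array, matching the contract of equals above.
@Override
public int hashCode() {
    int result = Objects.hash(pq, count());
    for (int i = 0; i < count(); i++) {
        ByteSequence<?> v = get(i);
        for (int j = 0; j < v.length(); j++) {
            result = 31 * result + v.get(j);
        }
    }
    return result;
}
```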

@Override
@@ -188,7 +220,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}

public ByteSequence<?> get(int ordinal) {
return compressedVectors.get(ordinal);
int chunkIndex = ordinal / vectorsPerChunk;
int vectorIndexInChunk = ordinal % vectorsPerChunk;
int start = vectorIndexInChunk * pq.getSubspaceCount();
return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount());
}
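
The addressing here is plain row-major arithmetic over fixed-width codes. A quick worked example with invented sizes (runnable as statements in jshell):

```java
// Worked example of the ordinal-to-slice mapping (hypothetical sizes).
int vectorsPerChunk = 1_000_000;
int subspaceCount = 16;                              // pq.getSubspaceCount()
int ordinal = 2_500_123;

int chunkIndex = ordinal / vectorsPerChunk;          // 2
int vectorIndexInChunk = ordinal % vectorsPerChunk;  // 500_123
int start = vectorIndexInChunk * subspaceCount;      // 8_001_968
// -> a 16-byte view into chunk 2 starting at byte 8_001_968; no bytes are copied.
```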

public ProductQuantization getProductQuantization() {
@@ -225,16 +260,19 @@ public long ramBytesUsed() {
int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;

long codebooksSize = pq.ramBytesUsed();
long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
return codebooksSize + listSize + dataSize;
long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES;
long dataSize = 0;
for (ByteSequence<?> chunk : compressedDataChunks) {
dataSize += chunk.ramBytesUsed();
}
return codebooksSize + chunksArraySize + dataSize;
}
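
A back-of-envelope comparison shows why the dense layout helps. The per-object overheads below are typical 64-bit-JVM values with compressed oops (matching the RamUsageEstimator constants used above), and the corpus size is invented:

```java
/** Hypothetical memory comparison: per-vector objects vs dense chunks. */
public class LayoutSavingsSketch {
    public static void main(String[] args) {
        long vectors = 10_000_000, codeBytes = 16;       // assumed corpus and code size
        long objHeader = 16, arrayHeader = 16, ref = 4;  // typical 64-bit JVM, compressed oops

        // Old layout: one ByteSequence object (plus a list slot) per vector.
        long oldBytes = vectors * (objHeader + arrayHeader + ref + codeBytes); // ~520 MB
        // New layout: the same codes packed into a handful of dense chunks.
        long newBytes = vectors * codeBytes;                                   // ~160 MB
        System.out.printf("old=%dMB new=%dMB%n", oldBytes / 1_000_000, newBytes / 1_000_000);
    }
}
```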

@Override
public String toString() {
return "PQVectors{" +
"pq=" + pq +
", count=" + compressedVectors.size() +
", count=" + vectorCount +
'}';
}
}
jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
@@ -221,7 +221,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,

@Override
public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
return new PQVectors(this, (ByteSequence<?>[]) compressedVectors);
return new PQVectors(this, (ByteSequence<?>[]) compressedVectors, compressedVectors.length, 1);
}

/**
jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java
@@ -41,6 +41,11 @@ public byte[] get() {
return data;
}

@Override
public int offset() {
return 0;
}

@Override
public byte get(int n) {
return data[n];
@@ -72,6 +77,14 @@ public ArrayByteSequence copy() {
return new ArrayByteSequence(Arrays.copyOf(data, data.length));
}

@Override
public ByteSequence<byte[]> slice(int offset, int length) {
if (offset == 0 && length == data.length) {
return this;
}
return new ArraySliceByteSequence(data, offset, length);
}

@Override
public long ramBytesUsed() {
int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java (new file)
@@ -0,0 +1,148 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import java.util.Arrays;

/**
 * A {@link ByteSequence} implementation that provides a view into a slice of a backing array.
 * Single-byte writes are supported, but bulk mutation operations throw {@link UnsupportedOperationException}.
 */
public class ArraySliceByteSequence implements ByteSequence<byte[]> {
private final byte[] data;
private final int offset;
private final int length;

public ArraySliceByteSequence(byte[] data, int offset, int length) {
if (offset < 0 || length < 0 || offset + length > data.length) {
throw new IllegalArgumentException("Invalid offset or length");
}
this.data = data;
this.offset = offset;
this.length = length;
}

@Override
public byte[] get() {
return data;
}

@Override
public int offset() {
return offset;
}

@Override
public byte get(int n) {
if (n < 0 || n >= length) {
throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
}
return data[offset + n];
}

@Override
public void set(int n, byte value) {
if (n < 0 || n >= length) {
throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
}
data[offset + n] = value;
}

@Override
public void setLittleEndianShort(int shortIndex, short value) {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public void zero() {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public int length() {
return length;
}

@Override
public ByteSequence<byte[]> copy() {
byte[] newData = Arrays.copyOfRange(data, offset, offset + length);
return new ArrayByteSequence(newData);
}

@Override
public ByteSequence<byte[]> slice(int sliceOffset, int sliceLength) {
if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) {
throw new IllegalArgumentException("Invalid slice parameters");
}
if (sliceOffset == 0 && sliceLength == length) {
return this;
}
return new ArraySliceByteSequence(data, offset + sliceOffset, sliceLength);
}
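
Slicing a slice composes offsets rather than stacking views on views, so reads always go straight to the original backing array. A small sketch (assumes the jvector imports for ByteSequence and ArraySliceByteSequence; values are invented):

```java
byte[] backing = new byte[100];
backing[15] = 42;

ByteSequence<byte[]> outer = new ArraySliceByteSequence(backing, 10, 50); // view of bytes 10..59
ByteSequence<byte[]> inner = outer.slice(5, 20);                          // view of bytes 15..34

assert inner.offset() == 15;  // 10 + 5: offsets accumulate
assert inner.get(0) == 42;    // reads hit the original array directly
```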

@Override
public long ramBytesUsed() {
// Only count the overhead of this slice object, not the underlying array
// since that's shared and counted elsewhere
return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
(3 * Integer.BYTES); // offset, length, and reference to data
}

@Override
public void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int copyLength) {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < Math.min(length, 25); i++) {
sb.append(get(i));
if (i < length - 1) {
sb.append(", ");
}
}
if (length > 25) {
sb.append("...");
}
sb.append("]");
return sb.toString();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ArraySliceByteSequence that = (ArraySliceByteSequence) o;
if (this.length != that.length) return false;
for (int i = 0; i < length; i++) {
if (this.get(i) != that.get(i)) return false;
}
return true;
}

@Override
public int hashCode() {
int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + get(i);
}
return result;
}
}
jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java
@@ -25,6 +25,8 @@ public interface ByteSequence<T> extends Accountable
*/
T get();

int offset();

int length();

byte get(int i);
@@ -42,4 +44,6 @@ public interface ByteSequence<T> extends Accountable
void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int length);

ByteSequence<T> copy();

ByteSequence<T> slice(int offset, int length);
}
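
One reason offset() joins the interface: a caller that pulls the backing store out via get() needs to know where the logical view starts (0 for ArrayByteSequence, the slice start for ArraySliceByteSequence). A hypothetical helper, not part of the PR:

```java
import io.github.jbellis.jvector.vector.types.ByteSequence;
import java.util.Arrays;

final class ByteSequences {
    /** Copies a sequence's logical contents out of its backing array (hypothetical utility). */
    static byte[] toArray(ByteSequence<byte[]> seq) {
        return Arrays.copyOfRange(seq.get(), seq.offset(), seq.offset() + seq.length());
    }
}
```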