Skip to content

Commit

Permalink
#sdy add missing != operators
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657717519
  • Loading branch information
tomnatan30 authored and copybara-github committed Jul 30, 2024
1 parent 2eeee20 commit 7bb75c2
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions shardy/dialect/sdy/transforms/propagation/sharding_projection.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ struct FactorSharding {
isMinorMost == other.isMinorMost &&
overflowAxes == other.overflowAxes;
}

bool operator!=(const FactorSharding& other) const {
return !(*this == other);
}
};

using FactorIndexToSharding = llvm::DenseMap<int64_t, FactorSharding>;
Expand All @@ -65,6 +69,10 @@ struct TensorFactorShardings {
replicatedAxes == other.replicatedAxes;
}

bool operator!=(const TensorFactorShardings& other) const {
return !(*this == other);
}

// Updates the sharding axes of the given `factorIndex` to `newAxes` if
// 1. this tensor is associated with that factor, and
// 2. `newAxes` strictly contains existing axes. For example, ["a", "b"]
Expand Down Expand Up @@ -177,6 +185,10 @@ class ShardingProjection {
return operands == other.operands && results == other.results;
}

bool operator!=(const ShardingProjection& other) const {
return !(*this == other);
}

private:
SmallVector<TensorFactorShardings> operands;
SmallVector<TensorFactorShardings> results;
Expand Down

0 comments on commit 7bb75c2

Please sign in to comment.