diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h index bac0091b..17748b17 100644 --- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h +++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h @@ -49,6 +49,10 @@ struct FactorSharding { isMinorMost == other.isMinorMost && overflowAxes == other.overflowAxes; } + + bool operator!=(const FactorSharding& other) const { + return !(*this == other); + } }; using FactorIndexToSharding = llvm::DenseMap; @@ -65,6 +69,10 @@ struct TensorFactorShardings { replicatedAxes == other.replicatedAxes; } + bool operator!=(const TensorFactorShardings& other) const { + return !(*this == other); + } + // Updates the sharding axes of the given `factorIndex` to `newAxes` if // 1. this tensor is associated with that factor, and // 2. `newAxes` strictly contains existing axes. For example, ["a", "b"] @@ -177,6 +185,10 @@ class ShardingProjection { return operands == other.operands && results == other.results; } + bool operator!=(const ShardingProjection& other) const { + return !(*this == other); + } + private: SmallVector operands; SmallVector results;