cose
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-25 18:57:12 +02:00
parent be0bcc9dcc
commit 568fd90542
6 changed files with 647 additions and 179 deletions
@@ -1,6 +1,10 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -11,6 +15,107 @@ namespace raptor {
} // namespace raptor
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
SmallVector<OpFoldResult, 4> attrs;
attrs.reserve(values.size());
for (int64_t value : values)
attrs.push_back(builder.getIndexAttr(value));
return attrs;
}
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
SmallVector<OpFoldResult, 4> strides;
strides.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
strides.push_back(builder.getIndexAttr(1));
return strides;
}
struct LowerFragmentAssemblyReconciliatorPattern
: OpConversionPattern<spatial::SpatReconciliatorOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
std::optional<StringRef> modeAttr = op.getMode();
if (!modeAttr || *modeAttr != "fragment_assembly")
return failure();
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {adaptor.getInput()};
llvm::append_range(fragmentOperands, adaptor.getFragments());
Value currentOutput =
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return op.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return op.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = fragmentOperands[operandIndex];
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
fragment = tensor::ExtractSliceOp::create(rewriter,
op.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
currentOutput = tensor::InsertSliceOp::create(rewriter,
op.getLoc(),
fragment,
currentOutput,
getStaticIndexAttrs(rewriter, fragmentOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank))
.getResult();
}
rewriter.replaceOp(op, currentOutput);
return success();
}
};
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
}
} // namespace onnx_mlir