automatic code reformat
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-29 19:21:37 +02:00
parent a41f694cf0
commit 2d5b03c08f
26 changed files with 183 additions and 168 deletions
+4 -8
View File
@@ -10,15 +10,11 @@ namespace onnx_mlir {
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp); mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateConstant(mlir::OperationFolder& folder, mlir::Value
mlir::Operation* anchorOp, getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Attribute value,
mlir::Type type);
mlir::Value getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Value
mlir::Operation* anchorOp, getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Attribute value,
mlir::Type type);
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp); mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
+1 -1
View File
@@ -1,5 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
+1 -1
View File
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file); llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags; mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs().enableDebugInfo(true,false); flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
moduleOp.print(os, flags); moduleOp.print(os, flags);
os.flush(); os.flush();
file.close(); file.close();
@@ -1,7 +1,7 @@
#include "AttributeUtils.hpp"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "AttributeUtils.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -1,5 +1,6 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
@@ -11,8 +12,6 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -209,8 +208,7 @@ auto createSpatComputeBatch(RewriterT& rewriter,
block->getArgument(0), block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()), mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()), mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size()) mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>; using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
@@ -252,8 +250,8 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
if (isCompileTimeComputable(input)) if (isCompileTimeComputable(input))
return buildFn(input); return buildFn(input);
auto computeOp = auto computeOp = createSpatCompute<1>(
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) { rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput); mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result); spatial::SpatYieldOp::create(rewriter, loc, result);
}); });
@@ -1,11 +1,10 @@
#include "IndexingUtils.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h" #include "llvm/ADT/APInt.h"
#include <algorithm> #include <algorithm>
#include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
using namespace mlir; using namespace mlir;
@@ -85,7 +84,8 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) { Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value)) if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt()); return getOrCreateIndexConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value); return cast<Value>(value);
} }
@@ -26,8 +26,10 @@ mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
mlir::AffineExpr expr, mlir::AffineExpr expr,
mlir::ValueRange operands); mlir::ValueRange operands);
mlir::Value mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier); mlir::Operation* anchorOp,
mlir::Value value,
int64_t multiplier);
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
@@ -8,8 +8,8 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include "ShapeTilingUtils.hpp"
#include "IndexingUtils.hpp" #include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -53,7 +53,9 @@ bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
} }
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); } bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) { int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {}); return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
@@ -98,11 +100,8 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
return permutation; return permutation;
} }
Value transposeMaybeInCompute(Value value, Value transposeMaybeInCompute(
RankedTensorType resultType, Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
ArrayRef<int64_t> permutation,
PatternRewriter& rewriter,
Location loc) {
auto buildTranspose = [&](Value input) -> Value { auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult(); return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
}; };
@@ -127,7 +126,8 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) { static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
auto sourceType = dyn_cast<RankedTensorType>(source.getType()); auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank()) if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|| sourceType.getRank() != resultType.getRank())
return false; return false;
for (OpFoldResult stride : strides) { for (OpFoldResult stride : strides) {
@@ -290,7 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
} }
Value lower = zeroIndices[dim]; Value lower = zeroIndices[dim];
Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim)); Value upper =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator}); auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
@@ -316,7 +317,8 @@ Value extractAxisSlice(
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape()); SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset); offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size); sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank())) return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult(); .getResult();
} }
@@ -115,12 +115,8 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter, mlir::Value extractAxisSlice(
mlir::Location loc, mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value source,
int64_t axis,
int64_t offset,
int64_t size);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter, mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc, mlir::Location loc,
@@ -15,8 +15,8 @@
#include "Common/Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -69,11 +69,11 @@ LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
continue; continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) { diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< (allowChannelReceiveInputs << kind << " input #" << currentInputIndex
? " must come from the host or an explicit " << (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive" "spat.channel_receive"
: " must come from the host"); : " must come from the host");
if (definingOp) if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName(); diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
}); });
@@ -135,17 +135,17 @@ LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics; pim::CappedDiagnosticReporter diagnostics;
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void)verifyComputeLikeInputs( (void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics); computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics); verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
} }
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) { for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void)verifyComputeLikeInputs(computeBatchOp.getOperation(), (void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(), computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false, /*allowChannelReceiveInputs=*/false,
"spat.compute_batch", "spat.compute_batch",
diagnostics); diagnostics);
verifyNoExternalTensorCaptures( verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics); computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
} }
@@ -5,9 +5,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedPrePatterns(patterns, ctx); }
populateGeneratedPrePatterns(patterns, ctx);
}
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx); populateGeneratedConversionPatterns(patterns, ctx);
@@ -51,8 +51,8 @@ static Value createPaddedRows(Value tensorValue,
if (tensorType.getDimSize(0) == paddedRows) if (tensorType.getDimSize(0) == paddedRows)
return tensorValue; return tensorValue;
auto paddedType = auto paddedType = RankedTensorType::get(
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding()); {paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)), SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)}; rewriter.getIndexAttr(0)};
@@ -62,20 +62,15 @@ static Value createPaddedRows(Value tensorValue,
padBlock->addArgument(rewriter.getIndexType(), loc); padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock); padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock); rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(rewriter, auto zero = getOrCreateConstant(
padOp.getOperation(), rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
rewriter.getZeroAttr(tensorType.getElementType()),
tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero); tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp); rewriter.setInsertionPointAfter(padOp);
return padOp.getResult(); return padOp.getResult();
} }
static Value packRowsForParallelGemm(Value rows, static Value packRowsForParallelGemm(
RankedTensorType rowsType, Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1) if (packFactor == 1)
return rows; return rows;
@@ -118,10 +113,8 @@ static Value unpackRowsFromParallelGemm(Value packedRows,
const int64_t packedNumRows = packedRowsType.getDimSize(0); const int64_t packedNumRows = packedRowsType.getDimSize(0);
const int64_t paddedNumRows = packedNumRows * packFactor; const int64_t paddedNumRows = packedNumRows * packFactor;
auto expandedType = auto expandedType = RankedTensorType::get(
RankedTensorType::get({packedNumRows, packFactor, rowWidth}, {packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
packedRowsType.getElementType(),
packedRowsType.getEncoding());
auto paddedType = auto paddedType =
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding()); RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto unpackedType = auto unpackedType =
@@ -193,11 +186,8 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
} }
static Value createConvWeightMatrix(Value w, static Value createConvWeightMatrix(
RankedTensorType wFlatType, Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
RankedTensorType wTransType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value { auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter, Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
@@ -360,9 +350,8 @@ static Value createIm2colRowComputes(Value x,
Value im2col = im2colLoop.getResult(0); Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col; Value gemmInputRows = im2col;
if (packFactor != 1) { if (packFactor != 1)
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc); gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
}); });
@@ -387,8 +376,13 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
} }
else { else {
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
gemmOut = unpackRowsFromParallelGemm( gemmOut = unpackRowsFromParallelGemm(packedOutput,
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); cast<RankedTensorType>(packedOutput.getType()),
numPatches,
numChannelsOut,
packFactor,
rewriter,
loc);
} }
// Restore to NCHW layout: // Restore to NCHW layout:
@@ -252,7 +252,13 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
Location loc) { Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0); const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = createSpatComputeBatch( auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { rewriter,
loc,
TypeRange {partialPiecesType},
laneCount,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows); Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -284,8 +290,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
return *batchOp; return *batchOp;
} }
static Value createDynamicGemmBatchRow( static Value
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) { createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
if (numOutCols == 1) if (numOutCols == 1)
return lane; return lane;
@@ -294,17 +300,21 @@ static Value createDynamicGemmBatchRow(
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
} }
static Value static Value extractDynamicGemmBColumn(
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column}; SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType()); auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc); Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}}; SmallVector<ReassociationIndices> collapseReassociation {
ReassociationIndices {0, 1}
};
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType()); auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
Value collapsed = Value collapsed =
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult(); tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}}; SmallVector<ReassociationIndices> expandReassociation {
ReassociationIndices {0, 1}
};
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult(); return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
} }
@@ -371,13 +381,15 @@ static Value createBroadcastedBiasScalar(Value bias,
Location loc) { Location loc) {
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
if (biasType.getRank() == 1) { if (biasType.getRank() == 1) {
SmallVector<OpFoldResult> offsets { SmallVector<OpFoldResult> offsets {biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)}; : OpFoldResult(column)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType()); auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides) Value vector =
.getResult(); tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides).getResult();
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}}; SmallVector<ReassociationIndices> reassociation {
ReassociationIndices {0, 1}
};
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult(); return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
} }
@@ -407,16 +419,21 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
const int64_t reductionSize = aType.getDimSize(1); const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols; const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = createSpatComputeBatch( auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { rewriter,
loc,
TypeRange {scalarPiecesType},
laneCount,
ValueRange {},
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols); Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc); Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc) : extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -578,9 +595,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced = Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = Value hOffset = onnx_mlir::multiplyIndexByConstant(
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
crossbarSize.getValue());
if (biasArg) { if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset}; SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice = Value biasSlice =
@@ -721,8 +737,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType()); auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp = createVvdmulBatch( auto batchOp =
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc); createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto outputCompute = createDynamicGemmOutputCompute( auto outputCompute = createDynamicGemmOutputCompute(
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc); batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
rewriter.replaceOp(gemmOp, outputCompute.getResults()); rewriter.replaceOp(gemmOp, outputCompute.getResults());
@@ -70,11 +70,8 @@ static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
return keptAxes; return keptAxes;
} }
static Value computeLaneIndex(Value lane, static Value
int64_t stride, computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
int64_t dimSize,
ConversionPatternRewriter& rewriter,
Location loc) {
if (dimSize == 1) if (dimSize == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
@@ -119,35 +116,41 @@ static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
sliceSizes.reserve(inputType.getRank()); sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank()); insertOffsets.reserve(inputType.getRank());
auto batchOp = createSpatComputeBatch( auto batchOp =
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) { createSpatComputeBatch(rewriter,
size_t keptAxisIndex = 0; loc,
sliceOffsets.clear(); TypeRange {batchType},
sliceSizes.clear(); laneCount,
insertOffsets.clear(); {},
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) { ValueRange {input},
if (isReduced) { [&](detail::SpatComputeBatchBodyArgs args) {
sliceOffsets.push_back(rewriter.getIndexAttr(0)); size_t keptAxisIndex = 0;
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis))); sliceOffsets.clear();
continue; sliceSizes.clear();
} insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex = Value axisIndex = computeLaneIndex(
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc); args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex; ++keptAxisIndex;
sliceOffsets.push_back(axisIndex); sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1)); sliceSizes.push_back(rewriter.getIndexAttr(1));
} }
insertOffsets.push_back(args.lane); insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0)); insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice = Value slice = tensor::ExtractSliceOp::create(
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides); rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult(); Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput( createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides); rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
}); });
if (failed(batchOp)) if (failed(batchOp))
return failure(); return failure();
return (*batchOp).getResult(0); return (*batchOp).getResult(0);
@@ -193,15 +196,15 @@ static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
auto reshapeCompute = auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) { createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding()); auto flatType =
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat); Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat; Value compact = flat;
if (compactKeptType != flatType) if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact); compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact; Value keepdims = compact;
if (keepdimsType != compactKeptType) if (keepdimsType != compactKeptType)
keepdims = keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims); spatial::SpatYieldOp::create(rewriter, loc, keepdims);
}); });
return reshapeCompute.getResult(0); return reshapeCompute.getResult(0);
@@ -121,11 +121,9 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
Value transposedInput = Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
result = transposeMaybeInCompute( result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -77,7 +77,7 @@ static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute)
needsRewrite = true; needsRewrite = true;
continue; continue;
keep_input: keep_input:
promoted.newInputs.push_back(input); promoted.newInputs.push_back(input);
promoted.newInputTypes.push_back(input.getType()); promoted.newInputTypes.push_back(input.getType());
promoted.newInputLocs.push_back(input.getLoc()); promoted.newInputLocs.push_back(input.getLoc());
@@ -127,8 +127,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
Block& oldBlock = compute.getBody().front(); Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute); rewriter.setInsertionPointAfter(compute);
auto newCompute = auto newCompute = spatial::SpatCompute::create(
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs); rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
for (Value weight : promoted->newWeights) { for (Value weight : promoted->newWeights) {
@@ -155,7 +155,12 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
mapper.map(*oldWeightArg, *newWeightArg); mapper.map(*oldWeightArg, *newWeightArg);
} }
if (failed(mapPromotedInputArguments( if (failed(mapPromotedInputArguments(
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter))) compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure(); return failure();
for (Operation& op : oldBlock.without_terminator()) for (Operation& op : oldBlock.without_terminator())
@@ -199,7 +204,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults()); newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
+ compute.getNumResults());
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults()); newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType()); newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc()); newBlockArgLocs.push_back(laneArg->getLoc());
@@ -239,7 +245,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
mapper.map(*oldWeightArg, *newWeightArg); mapper.map(*oldWeightArg, *newWeightArg);
} }
if (failed(mapPromotedInputArguments( if (failed(mapPromotedInputArguments(
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter))) compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure(); return failure();
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) { for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex); auto outputArg = compute.getOutputArgument(resultIndex);
@@ -111,7 +111,8 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} }
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
Value reshaped = materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape); Value reshaped =
materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
rewriter.replaceOp(reshapeOp, reshaped); rewriter.replaceOp(reshapeOp, reshaped);
return success(); return success();
}; };
@@ -44,8 +44,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (isCompileTimeComputable(adaptor.getInput())) { if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) { for (int64_t sliceSize : sliceSizes) {
outputs.push_back( outputs.push_back(extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
offset += sliceSize; offset += sliceSize;
} }
rewriter.replaceOp(splitOp, outputs); rewriter.replaceOp(splitOp, outputs);
@@ -1,5 +1,5 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@@ -104,8 +104,7 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
} }
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
Value transposed = Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation) linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
.getResult()[0];
rewriter.replaceOp(transposeOp, transposed); rewriter.replaceOp(transposeOp, transposed);
return success(); return success();
} }
@@ -7,8 +7,8 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
+1 -1
View File
@@ -1,5 +1,5 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -334,8 +334,8 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
loc, loc,
tensorType, tensorType,
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0), getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
getOrCreateIndexConstant(constantFolder, getOrCreateIndexConstant(
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ), constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize)),
deviceTensor, deviceTensor,
inputTensor, inputTensor,
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize))); rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
@@ -1482,7 +1482,8 @@ void appendScalarSendLoop(MaterializerState& state,
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size())); Value upperBound =
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
@@ -1577,7 +1578,8 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
} }
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size())); Value upperBound =
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
@@ -2342,7 +2344,8 @@ FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state,
} }
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.slots.size())); Value upperBound =
getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.slots.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
@@ -99,17 +99,16 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
.getOutput(); .getOutput();
} }
else { else {
copiedValue = copiedValue = pim::PimMemCopyHostToDevOp::create(
pim::PimMemCopyHostToDevOp::create( rewriter,
rewriter, op->getLoc(),
op->getLoc(), originalType,
originalType, getOrCreateIndexConstant(constantFolder, op, 0),
getOrCreateIndexConstant(constantFolder, op, 0), getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset)),
getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ), deviceDst,
deviceDst, getGlobalOp.getResult(),
getGlobalOp.getResult(), rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) .getOutput();
.getOutput();
} }
cachedByType[originalType] = copiedValue; cachedByType[originalType] = copiedValue;