standardize spatial and pim dialects

remove old unused stuff
This commit is contained in:
NiccoloN
2026-03-23 21:21:31 +01:00
parent 0478d979ff
commit 93e20c1dfc
18 changed files with 693 additions and 1519 deletions

View File

@@ -6,13 +6,11 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {

View File

@@ -194,45 +194,45 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
emitMemCopyOp("ld",
memory.getValueAddress(loadOp.getDeviceDst()),
loadOp.getDeviceDstOffset(),
memory.getValueAddress(loadOp.getHostSrc()),
loadOp.getHostSrcOffset(),
memory.getValueAddress(loadOp.getDeviceTarget()),
loadOp.getDeviceTargetOffset(),
memory.getValueAddress(loadOp.getHostSource()),
loadOp.getHostSourceOffset(),
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
emitMemCopyOp("st",
memory.getValueAddress(storeOp.getHostDst()),
storeOp.getHostDstOffset(),
memory.getValueAddress(storeOp.getDeviceSrc()),
storeOp.getDeviceSrcOffset(),
memory.getValueAddress(storeOp.getHostTarget()),
storeOp.getHostTargetOffset(),
memory.getValueAddress(storeOp.getDeviceSource()),
storeOp.getDeviceSourceOffset(),
storeOp.getSize());
}
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
emitMemCopyOp("lmv",
memory.getValueAddress(lmvOp.getDst()),
lmvOp.getDstOffset(),
memory.getValueAddress(lmvOp.getSrc()),
lmvOp.getSrcOffset(),
memory.getValueAddress(lmvOp.getTarget()),
lmvOp.getTargetOffset(),
memory.getValueAddress(lmvOp.getSource()),
lmvOp.getSourceOffset(),
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
emitCommunicationOp(
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
emitCommunicationOp("send", memory.getValueAddress(sendOp.getSrc()), sendOp.getTargetCoreId(), sendOp.getSize());
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
emitMvmOp(
mvmId, memory.getValueAddress(mvmLikeOp.getOutBuf()), 0, memory.getValueAddress(mvmLikeOp.getVectorInput()), 0);
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
@@ -243,10 +243,10 @@ static size_t getValueSizeInBytes(mlir::Value value) {
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvaddOp.getA());
auto bAddr = memory.getValueAddress(vvaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
json["op"] = "vvadd";
@@ -254,15 +254,15 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvaddOp.getA());
json["len"] = getValueSizeInBytes(vvaddOp.getLhs());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvsubOp.getA());
auto bAddr = memory.getValueAddress(vvsubOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
json["op"] = "vvsub";
@@ -270,15 +270,15 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvsubOp.getA());
json["len"] = getValueSizeInBytes(vvsubOp.getLhs());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmulOp.getA());
auto bAddr = memory.getValueAddress(vvmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
json["op"] = "vvmul";
@@ -286,15 +286,15 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmulOp.getA());
json["len"] = getValueSizeInBytes(vvmulOp.getLhs());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
json["op"] = "vvmax";
@@ -302,15 +302,15 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
json["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
json["op"] = "vvdmul";
@@ -318,132 +318,71 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
json["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
auto aAddr = memory.getValueAddress(vavgOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
json["op"] = "vavg";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vavgOp.getA());
json["len"] = getValueSizeInBytes(vavgOp.getInput());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
auto outBufAddr = memory.getValueAddress(vreluOp.getOutBuf());
auto aAddr = memory.getValueAddress(vreluOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
json["op"] = "vrelu";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vreluOp.getA());
json["len"] = getValueSizeInBytes(vreluOp.getInput());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
auto aAddr = memory.getValueAddress(vtanhOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
json["op"] = "vtanh";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vtanhOp.getA());
json["len"] = getValueSizeInBytes(vtanhOp.getInput());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
auto aAddr = memory.getValueAddress(vsigmOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
json["op"] = "vsigm";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vsigmOp.getA());
json["len"] = getValueSizeInBytes(vsigmOp.getInput());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const {
auto outBufAddr = memory.getValueAddress(applyFiltersOp.getOutBuf());
auto inBufAddr = memory.getValueAddress(applyFiltersOp.getInput());
auto accumBufAddr = memory.getValueAddress(applyFiltersOp.getAccumBuf());
auto weightIndices = applyFiltersOp.getWeightIndices();
auto inputType = cast<MemRefType>(applyFiltersOp.getInput().getType());
auto outputType = cast<MemRefType>(applyFiltersOp.getOutBuf().getType());
auto inShape = inputType.getShape();
auto outShape = outputType.getShape();
size_t inChannels = inShape[1];
size_t outChannels = outShape[1];
size_t dimX = inShape.size() > 2 ? inShape[2] : 1;
size_t dimY = inShape.size() > 3 ? inShape[3] : 1;
for (size_t outY = 0; outY < dimY; outY++) {
for (size_t outX = 0; outX < dimX; outX++) {
size_t weightIndex = 0;
for (Attribute weight : weightIndices) {
// --- STEP 1: Perform MVMUL operation ---
auto weightId = cast<IntegerAttr>(weight).getInt();
size_t xKer = cast<IntegerAttr>(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt();
size_t yKer = cast<IntegerAttr>(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt();
weightIndex++;
if (outX + xKer >= dimX || outY + yKer >= dimY)
continue;
size_t outputOffset = (outY * dimX + outX) * 32 * outChannels;
size_t inputOffset = ((outY + yKer) * dimX + (outX + xKer)) * 32 * inChannels;
bool isFirstWeight = (weightIndices[0] == weight);
// For the first weight, store directly in output buffer; otherwise use accumulator.
size_t rdAddr = isFirstWeight ? outBufAddr : accumBufAddr;
size_t rdOffset = isFirstWeight ? outputOffset : 0;
emitMvmOp(weightId, rdAddr, rdOffset, inBufAddr, inputOffset);
// --- STEP 2: Perform VADD operation (skip for first weight) ---
if (isFirstWeight)
continue;
// Sum accumulator with output buffer, store result in output buffer.
setupRdRs1Rs2(outBufAddr, outputOffset, accumBufAddr, 0, outBufAddr, outputOffset);
json::Object vaddJson;
vaddJson["op"] = "vvadd";
vaddJson["rd"] = 0;
vaddJson["rs1"] = 1;
vaddJson["rs2"] = 2;
vaddJson["offset"] = createEmptyOffset();
vaddJson["len"] = 32 * outChannels;
emitInstruction(std::move(vaddJson));
}
}
}
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getData());
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape();
size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
@@ -451,7 +390,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
@@ -570,8 +509,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -592,11 +529,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (isa<pim::PimSumOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation");
continue;
}
else {
op.emitError("Unsupported codegen for this operation");
op.dump();

View File

@@ -99,7 +99,6 @@ public:
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
};

View File

@@ -134,366 +134,4 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
return (*currTensors)[0];
}
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
switch (mapOp) {
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
}
}
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2) {
if (auto unpackedStrides = valuesArray) {
value1 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[0]).getInt();
value2 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[1]).getInt();
}
else {
value1 = 1;
value2 = 1;
}
}
std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y) {
if (valuesArray.has_value()) {
auto pads = mlir::ArrayAttr(*valuesArray);
if (pads.size() != 4)
return "pads must have 4 elements.";
pad_x = cast<IntegerAttr>(pads[2]).getInt();
pad_y = cast<IntegerAttr>(pads[3]).getInt();
}
else {
// Default padding is 0 unless specified otherwise.
// https://onnx.ai/onnx/operators/onnx__Conv.html
pad_x = pad_y = 0;
}
return std::nullopt;
}
void tileImageTensorByChannel(Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
size_t tileSize,
ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = getImageHeight(imageShape);
size_t input_w = getImageWidth(imageShape);
size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
size_t tileRest = getImageChannel(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Location loc = imageTensor.getLoc();
for (size_t i = 0; i < tileCount; i++) {
if (i == tileCount - 1 && tileRest != 0)
sizes[1] = rewriter.getIndexAttr(tileRest);
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
offsets[1] = rewriter.getIndexAttr(i * tileSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
}
}
}
}
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
ConversionPatternRewriter& rewriter,
Location& loc,
Type outputType) {
// Populate the outputTiles for the concat in the given order:
// 1. Start top left pixel
// 2. Continue on its right pixel till the end of the row
// 3. Restart on the next row
size_t outputTileCount = outputTiles.size();
size_t output_w = outputTiles[0].size();
size_t output_h = outputTiles[0][0].size();
SmallVector<Value> tilesToConcat;
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
for (size_t outX = 0; outX < output_h; outX++)
for (size_t outY = 0; outY < output_w; outY++)
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
}
LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) {
if (inX < 0) {
assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding");
return failure();
}
if (inY < 0) {
assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding");
return failure();
}
if ((size_t) inX >= input_w || (size_t) inY >= input_h) {
assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds");
assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds");
return failure();
}
return success();
}
Value createExtractSliceImg(Value valToSlice,
size_t x,
size_t y,
size_t t,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
if (t == channelTileCount - 1 && channelTileRest != 0)
sizes[1] = rewriter.getIndexAttr(channelTileRest);
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
}
Value indexImgValue(Value v,
size_t x,
size_t y,
size_t t,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
ConversionPatternRewriter& rewriter) {
auto newV = rewriter.getRemappedValue(v);
if (newV)
v = newV;
if (!v.getDefiningOp())
return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (auto computeOp = v.getDefiningOp<spatial::SpatWeightedCompute>()) {
// We found the computeOp that produces the tile we want, just return this
// value.
// TODO: Should we assert that x,y,t are zero?
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero");
return v;
}
if (auto receiveOp = v.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
// This is a receiveOp, just return its value which will be resolved later
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero");
return v;
}
if (auto imgConcatOp = v.getDefiningOp<spatial::SpatImgConcatOp>()) {
auto imgConcatInput = imgConcatOp.getInputTile(x, y, t);
// TODO: Is this correct?
// Above we already index exactly the tile we want, so `x=y=t=0` in
// recursive call
return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter);
}
if (auto tensorConcatOp = v.getDefiningOp<tensor::ConcatOp>()) {
// This can be recursive.
// First, get the input tensors of the tensor.concatOp
// Then, find the input tensor that contains the tile we want
// Finally, recursive call asking for the tile
auto concatAxis = tensorConcatOp.getDim();
assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis");
assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis.");
SmallVector<size_t, 4> indexDims = {1, t * crossbarSize, x, y};
// Find the input tensor that contains the tile we want
size_t currentTile = 0;
for (auto concatInput : tensorConcatOp.getInputs()) {
auto concatInputShape = cast<ShapedType>(concatInput.getType());
assert(concatInputShape.getRank() == 4 && "Expecting an image tensor");
auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis);
if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) {
// This input tensor contains the tile we want
indexDims[concatAxis] -= currentTile;
if (indexDims[1] % crossbarSize != 0) {
assert(ignoreConcatError
&& "TODO: Handle non-tile aligned tensor, or set "
"--ignore-concat-error=true");
}
return indexImgValue(concatInput,
indexDims[2],
indexDims[3],
indexDims[1] / crossbarSize,
channelTileCount,
channelTileRest,
input_w,
input_h,
rewriter);
}
currentTile += concatInputSizeOnAxis;
}
assert(false
&& "Could not find the input tensor that contains the tile "
"within tensor.ConcatOp");
}
v.dump();
assert(false && "indexImgValue: unsupported operation");
}
void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Location loc = wholeInputTensor.getLoc();
for (size_t t = 0; t < channelTileCount; t++) {
if (t == channelTileCount - 1 && channelTileRest != 0)
sizes[1] = rewriter.getIndexAttr(channelTileRest);
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
}
}
}
}
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
ConversionPatternRewriter& rewriter) {
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
inputTiles[t][x][y] =
indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
}
}
}
return std::nullopt;
}
LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
const size_t inputTilesCount,
const size_t lastInputTileDimension,
TensorType inputShape,
TensorType outputShape,
Value reshapeInput,
ConversionPatternRewriter& rewriter) {
// Only support reshape between an image and a vector (i.e. flatten)
if (inputShape.getRank() != 4 || outputShape.getRank() != 2) {
return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(),
"resolveVecInputTiles only supports reshapes from 4D to 2D tensors");
}
/*
* From a 4D tensor <N, C, W, H> to a 2D tensor <N, C*H*W>
*/
auto N = inputShape.getDimSize(0);
auto C = inputShape.getDimSize(1);
auto H = inputShape.getDimSize(2);
auto W = inputShape.getDimSize(3);
assert(N == 1 && "Only support N = 1 for image tensors");
for (size_t i = 0; i < inputTilesCount; i++) {
auto c = (i / (H * W)) % C;
// TODO: Is this correct? Or should I invert h and w?
auto w = (i / H) % W;
auto h = i % H;
Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter);
// Assert the shape of the tile, and reshape it
auto curTileShape = cast<TensorType>(curTile.getType());
assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?");
assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?");
assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?");
assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?");
// Reshape this pixel tensor into a vector, for compatibility with the
// rest
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
Value shapeTensor =
arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
size_t coreIndex = i / crossbarCountInCore;
inputTiles[coreIndex].push_back(reshapedCurTile);
}
return success();
}
std::pair<size_t, size_t> kernel_get_start_and_end(
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) {
int64_t firstValid = std::ceil(static_cast<float>(pad) / dilation) * dilation - pad;
int64_t start = std::max(firstValid, out_pos * stride - pad);
int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad);
assert(start >= 0 && "Start position must be non-negative.");
assert(end >= 0 && "End position must be non-negative.");
return std::make_pair(start, end);
}
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) {
auto oldSegmentSizes = wcomputeOp->getAttrOfType<DenseI32ArrayAttr>(wcomputeOp.getOperandSegmentSizesAttrName());
auto newSegmentSizes =
DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment});
wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes);
}
int getResultIndex(Operation* op, Value v) {
int resultNumber = -1;
for (auto result : op->getResults()) {
if (result == v) {
resultNumber = result.getResultNumber();
break;
}
}
assert(resultNumber >= 0 && "Value not found in given operation's results.");
return resultNumber;
}
}; // namespace onnx_mlir

View File

@@ -2,7 +2,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -10,7 +9,6 @@
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
@@ -144,164 +142,4 @@ mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
/**
* Unpacks an optional pair vector into two size_t values.
*
* @param valuesArray The optional `mlir::ArrayAttr` containing the pair of
* values.
* @param value1 The reference to the first `size_t` variable to store the
* unpacked value.
* @param value2 The reference to the second `size_t` variable to store the
* unpacked value.
*/
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2);
/**
* Unpacks the optional pads vector.
*
* @param valuesArray The optional array attribute containing the values.
* @param pad_x The output variable to store the value of pad_x.
* @param pad_y The output variable to store the value of pad_y.
* @param rewriter The rewriter to notify failure
*
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/
std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/**
* Tiles the image tensor by channel.
*
* This function takes an image tensor and tiles it into smaller tiles based on
* the channel dimension. The size of each tile is specified by the tileSize
* parameter.
*
* @param imageTensor The input image tensor (NxCxWxH) to be tiled.
* @param tiles The output tiles vector to store the tiled image tensors.
* @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*/
void tileImageTensorByChannel(mlir::Value imageTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
size_t tileSize,
mlir::ConversionPatternRewriter& rewriter);
/**
* Creates an ImgConcatOp based on the given tiles.
*
* This function takes a 3-dimensional vector `outputTiles` representing the
* tiles to concatenate. The tiles are indexed by [tile][x][y].
*
* @param outputTiles The tiles to concatenate.
* @param rewriter The ConversionPatternRewriter used for creating the
* ImgConcatOp.
* @param loc The location of the operation.
* @param outputType The type of the output tensor.
*
* @return The created ImgConcatOp.
*/
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc,
mlir::Type outputType);
/**
* @brief Verifies if the given input coordinates and padding values are within
* the bounds of the input tensor.
*
* @param input_w The width of the input tensor.
* @param input_h The height of the input tensor.
* @param inX The X-coordinate of the input.
* @param inY The Y-coordinate of the input.
* @param pad_x The padding value in the X-direction.
* @param pad_y The padding value in the Y-direction.
* @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise.
*/
mlir::LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/**
* Resolves the tiling of the input tensor into smaller tiles.
*
* This function takes a whole input tensor and tiles it into smaller tiles
* using the provided parameters. The resulting tiles are stored in the
* `inputTiles` vector.
* Input tiles need to be indexed by:
* a. Channel Tile
* b. Pixel `x` position
* c. Pixel `y` position
* For example: inputTiles[channelTile][x][y]
*
* @param wholeInputTensor The whole input tensor to be tiled.
* @param inputTiles A vector of vectors of vectors of Values representing the
* tiles of the input tensor. The outermost vector represents
* the channels, the middle vector represents the rows, and
* the innermost vector represents the columns of the tiles.
* @param channelTileCount The number of tiles for the `channel` axis.
* @param channelTileRest The size of the last channelTile. Set as 0 if tiles
* fit exactly
* @param input_w The width of the input tensor.
* @param input_h The height of the input tensor.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*
* @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles.
*/
std::optional<llvm::Twine>
resolveImgInputTiles(mlir::Value wholeInputTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
mlir::ConversionPatternRewriter& rewriter);
/**
* Computes the boundaries of an image kernel application.
*
* @param out_pos The position of the output element.
* @param input_width The width of the input image.
* @param krn_width The width of the kernel.
* @param stride The stride value.
* @param dilation The dilation value.
* @param pad The padding value.
* @return A pair of size_t values representing the start and end positions of
* the kernel application.
*/
std::pair<size_t, size_t> kernel_get_start_and_end(
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad);
/**
* @brief Increment the `operandSegmentSizes` in the WeightedCompute operation
* for the `inputs` operand.
*
* This function increments the size of the `inputs` operand segment in the
* `operandSegmentSizes` of the given WeightedCompute operation by the specified
* increment. This is necessary when new operands are programmatically added to
* the WeightedCompute operation.
*
* @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes`
* is to be incremented.
* @param increment The value by which to increment the `inputs` operand segment
* size.
*/
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment);
/**
* @brief Finds the result index of the given operation that produces the
* specified value.
*
* This function takes an operation and a value, and returns the index of the
* result of the operation that corresponds to the given value.
*
* @param op Operation whose result index is to be found.
* @param v The value for which the result index is to be determined.
* @return The index of the result of the operation that produces the specified
* value.
*/
int getResultIndex(mlir::Operation* op, mlir::Value v);
}; // namespace onnx_mlir

View File

@@ -223,9 +223,6 @@ void SpatialToGraphvizPass::runOnOperation() {
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
drawConcatOpSubgraph(concatOp, concatNum++);
}
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
drawConcatOpSubgraph(imgConcatOp, concatNum++);
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
if (producerOp) {

View File

@@ -45,8 +45,4 @@ createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(mlir::Operation* op) {
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir

View File

@@ -45,4 +45,5 @@ def spatToPimVVMaxOp : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -278,7 +278,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
continue;
}
if (isa<tensor::ConcatOp>(resultUser) || isa<spatial::SpatImgConcatOp>(resultUser)) {
if (isa<tensor::ConcatOp>(resultUser)) {
auto concatOp = resultUser;
auto concatValue = concatOp->getResult(0);
auto concatUses = concatValue.getUses();
@@ -368,8 +368,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
};
funcOp.walk([&](PimVMMOp vmmOp) {
auto outTensorOperand = vmmOp.getOutBuf();
auto resultTensor = vmmOp.getOutRes();
auto outTensorOperand = vmmOp.getOutputBuffer();
auto resultTensor = vmmOp.getOutput();
auto outShape = getTensorShape(outTensorOperand);
assert(isHVectorShape(outShape));
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
@@ -602,9 +602,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
// If the operand is a concatenation operation and the returnOp was the only
// user of the returnOperand, we can safely remove it
if (isAConcatOp(returnOperand)) {
if (isa<tensor::ConcatOp>(returnOperand)) {
auto returnOperandUses = it.value().getUses();
if (rangeLength(returnOperandUses) == 0)
rewriter.eraseOp(returnOperand);
@@ -632,7 +630,7 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
// user. This means that we need to get the replace the original SendOp with
// a BroadcastSendOp
rewriter.setInsertionPoint(sendOp);
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getInput());
}
}

View File

@@ -20,45 +20,12 @@ class PimOp<string mnemonic, list<Trait> traits = []> :
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
// Communication
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
PimTensor: $src,
I32Attr: $size,
I32Attr: $targetCoreId
);
let assemblyFormat = [{
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
}];
}
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor: $dst,
I32Attr: $size,
I32Attr: $srcCoreId
);
let results = (outs
PimTensor: $out
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
}];
}
// Core
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock]> {
let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body);
@@ -72,412 +39,443 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
}];
}
// Memory
def PimConstantOp: PimOp<"constant", []> {
let description = [{
Allocate a constant value in global memory
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
PimTensor: $out
);
}
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from host memory into device memory
}];
let arguments = (ins
PimTensor: $deviceDst,
PimTensor: $hostSrc,
I32Attr: $deviceDstOffset,
I32Attr: $hostSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $deviceDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceDstMutable();
}
}];
let assemblyFormat = [{
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
}];
}
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from device memory into host memory
}];
let arguments = (ins
PimTensor: $hostDst,
PimTensor: $deviceSrc,
I32Attr: $hostDstOffset,
I32Attr: $deviceSrcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $hostDstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostDstMutable();
}
}];
let assemblyFormat = [{
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
}];
}
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from and to the same memory
}];
let arguments = (ins
PimTensor: $dst,
PimTensor: $src,
I32Attr: $dstOffset,
I32Attr: $srcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $dstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
}];
}
// Algebra
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
Vector-matrix multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
let description = [{
Matrix-vector multiplication: c = a * b
}];
let arguments = (ins
I32Attr: $weightIndex,
PimTensor: $vectorInput,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
}
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{
Element-wise max: c = max(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Apply filters to a tensor
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
PimTensor: $input,
PimTensor: $outBuf,
PimTensor: $accumBuf
);
let results = (outs
PimTensor: $outRes
);
let assemblyFormat = [{
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
}];
}
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Sum all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Average all elements into a single one
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise ReLU: c = max(a, 0)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimHaltOp : PimOp<"halt", [Terminator]> {
let description = [{
Halts the execution of the core
}];
let summary = "Halt execution of the core";
let assemblyFormat = [{
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
def PimSendOp : PimOp<"send", []> {
let summary = "Send a tensor to another core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
I32Attr:$targetCoreId
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
}];
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
I32Attr:$sourceCoreId
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
}];
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
let arguments = (ins
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceTargetMutable();
}
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
}];
}
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory";
let arguments = (ins
PimTensor:$hostTarget,
PimTensor:$deviceSource,
I32Attr:$hostTargetOffset,
I32Attr:$deviceSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getHostTargetMutable();
}
}];
let assemblyFormat = [{
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
}];
}
def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region within the same memory space";
let arguments = (ins
PimTensor:$target,
PimTensor:$source,
I32Attr:$targetOffset,
I32Attr:$sourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getTargetMutable();
}
}];
let assemblyFormat = [{
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> {
let summary = "Transpose a matrix";
let arguments = (ins
PimTensor:$input,
I64ArrayAttr:$permutation,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let summary = "Vector-matrix multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
let summary = "Matrix-vector multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
let summary = "Element-wise addition: c = a + b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> {
let summary = "Element-wise subtraction: c = a - b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> {
let summary = "Element-wise multiplication: c = a * b";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> {
let summary = "Element-wise max: c = max(a, b)";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
let summary = "Dot product: c = dot(a, b)";
let arguments = (ins
PimTensor:$lhs,
PimTensor:$rhs,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
let summary = "Reduce all elements to a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
let summary = "Average all elements into a single value";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> {
let summary = "Element-wise ReLU: c = max(a, 0)";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> {
let summary = "Element-wise tanh activation";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
let summary = "Element-wise sigmoid activation";
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
#endif // PIM_DIALECT_H

View File

@@ -25,19 +25,6 @@ void PimDialect::initialize() {
>();
}
#define POPULATE_DEPENDENCIES(OP_NAME) \
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim
} // namespace onnx_mlir

View File

@@ -30,7 +30,7 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
.getOutput();
}
struct MemCopyHostToDevOpInterface
@@ -40,26 +40,26 @@ struct MemCopyHostToDevOpInterface
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
auto hostSrc = memCopyHostToDevOp.getHostSrc();
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
auto hostSource = memCopyHostToDevOp.getHostSource();
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
if (failed(deviceDstOpt))
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
if (failed(deviceTargetOpt))
return failure();
auto deviceDstMemRef = *deviceDstOpt;
auto deviceTargetMemRef = *deviceTargetOpt;
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
if (failed(hostSrcOpt))
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
if (failed(hostSourceOpt))
return failure();
auto hostSrcMemRef = *hostSrcOpt;
auto hostSourceMemRef = *hostSourceOpt;
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp,
deviceDstMemRef.getType(),
deviceDstMemRef,
hostSrcMemRef,
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
memCopyHostToDevOp.getHostSrcOffsetAttr(),
deviceTargetMemRef.getType(),
deviceTargetMemRef,
hostSourceMemRef,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
@@ -73,25 +73,25 @@ struct MemCopyDevToHostOpInterface
BufferizationState& state) const {
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto globalDst = memCopyDevToHostOp.getHostDst();
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
if (failed(globalDstOpt))
auto hostTarget = memCopyDevToHostOp.getHostTarget();
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
if (failed(hostTargetOpt))
return failure();
auto globalDstMemRef = *globalDstOpt;
auto hostTargetMemRef = *hostTargetOpt;
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
if (failed(localSrcOpt))
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
if (failed(deviceSourceOpt))
return failure();
auto localSrcMemRef = *localSrcOpt;
auto deviceSourceMemRef = *deviceSourceOpt;
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp,
globalDstMemRef.getType(),
globalDstMemRef,
localSrcMemRef,
memCopyDevToHostOp.getHostDstOffsetAttr(),
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
hostTargetMemRef.getType(),
hostTargetMemRef,
deviceSourceMemRef,
memCopyDevToHostOp.getHostTargetOffsetAttr(),
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
memCopyDevToHostOp.getSizeAttr());
return success();
}
@@ -109,16 +109,16 @@ struct TransposeOpBufferizeInterface
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
return success();
}
};
@@ -132,9 +132,9 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
auto vmmOp = cast<PimVMMOp>(op);
Value readVal = uRead->get();
Value writeVal = uWrite->get();
if (writeVal != vmmOp.getOutBuf())
if (writeVal != vmmOp.getOutputBuffer())
return false;
if (readVal == vmmOp.getVectorInput())
if (readVal == vmmOp.getInput())
if (state.areEquivalentBufferizedValues(readVal, writeVal))
return true;
return false;
@@ -146,16 +146,16 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};
@@ -171,16 +171,16 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
if (failed(vectorInputOpt))
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};
@@ -203,22 +203,23 @@ struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<B
BufferizationState& state) const {
auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt))
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
if (failed(lhsOpt))
return failure();
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt))
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
if (failed(rhsOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt))
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
return success();
}
};

View File

@@ -16,4 +16,5 @@ def memrefCopyToPimMemCopyOp : Pat<
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute operation, with constant weights already attached";
let summary = "Compute region with attached constant weights";
let arguments = (ins
Variadic<SpatTensor>:$weights,
@@ -50,6 +54,8 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
}
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region";
let arguments = (ins
Variadic<SpatTensor>:$outputs
);
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
}
//===----------------------------------------------------------------------===//
// Data movement operations
// Communication
//===----------------------------------------------------------------------===//
def SpatChannelNewOp : SpatOp<"channel_new", []> {
let summary = "Create a new virtual channel";
let results = (outs
SpatChannelType:$new_channel
SpatChannelType:$channel
);
let builders = [
@@ -80,107 +88,73 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
}
def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a channel";
let arguments = (ins
SpatChannelType:$channel,
SpatTensor: $data
SpatTensor:$input
);
let assemblyFormat = [{
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a channel";
let arguments = (ins
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
}
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
let summary = "Broadcast a tensor through a shared channel buffer";
let arguments = (ins
SpatChannelType:$channel,
SpatTensor: $data
SpatTensor:$input
);
let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
}];
}
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
let summary = "Receive a tensor from a shared channel buffer";
let arguments = (ins
SpatChannelType:$channel
);
let results = (outs
SpatTensor: $data
SpatTensor:$output
);
}
//===----------------------------------------------------------------------===//
// Math operations
//===----------------------------------------------------------------------===//
def SpatConstantOp: SpatOp<"constant", []> {
let description = [{
"Constant value, should be used for weights and biases"
let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
}];
let arguments = (ins
AnyAttr: $value,
BoolAttr: $shouldAllocate
);
let results = (outs
SpatTensor: $out
);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
let arguments = (ins
I32Attr: $weightIndex,
SpatTensor:$vector
);
let results = (outs
SpatTensor:$output
);
// TODO: Verifier that checks it is within a WeightedCompute operation,
// that the weightIndex is valid, and that the matrix is of the right size.
let hasVerifier = 1;
}
def SpatVAddOp: SpatOp<"vadd", []> {
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
SpatTensor:$input
);
let results = (outs
@@ -190,65 +164,67 @@ def SpatVAddOp: SpatOp<"vadd", []> {
let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatVMulOp : SpatOp<"vmul", []> {
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatVDivOp: SpatOp<"vdiv", []> {
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
let arguments = (ins
SpatTensor:$a,
SpatTensor:$b
);
let results = (outs
SpatTensor:$output
);
//let hasVerifier = 1;
let assemblyFormat = [{
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
}];
}
//TODO: remove
def SpatVSDivOp: SpatOp<"vsdiv", []> {
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
let arguments = (ins
SpatTensor:$dividend,
SpatTensor:$divisor
);
let results = (outs
SpatTensor:$output
);
}
def SpatSumOp : SpatOp<"sum", []> {
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
@@ -257,9 +233,15 @@ def SpatSumOp: SpatOp<"sum", []> {
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
let summary = "Element-wise sigmoid activation";
let arguments = (ins
SpatTensor:$input
);
@@ -267,9 +249,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatReluOp : SpatOp<"relu", []> {
let summary = "Element-wise ReLU activation";
let arguments = (ins
SpatTensor:$input
);
@@ -277,15 +265,18 @@ def SpatReluOp: SpatOp<"relu", []> {
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVMaxOp : SpatOp<"vmax", []> {
let summary = "Element-wise max function";
let summary = "Element-wise max between two tensors";
let arguments = (ins
SpatTensor: $a,
SpatTensor: $b
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
@@ -293,62 +284,9 @@ def SpatVMaxOp: SpatOp<"vmax", []> {
);
let hasVerifier = 1;
}
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
let description = [{
Applies a variable number of crossbar weights to a single large image tensor tile,
producing a corresponding output tile. This essentially encapsulates a big for loop
over all pixels in the input tile, where each pixel is multiplied by all the weights
in the operation.
}];
let arguments = (ins
I64ArrayAttr: $weightIndices,
I64ArrayAttr: $xKernelPositions,
I64ArrayAttr: $yKernelPositions,
SpatTensor: $input
);
let results = (outs SpatTensor);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type(results)
}];
}
//===----------------------------------------------------------------------===//
// Other operations
//===----------------------------------------------------------------------===//
def SpatImgConcatOp: SpatOp<"img_concat", []> {
let summary = "Concatenate pixel tiles into a single image";
let description = [{
Concatenate pixel tiles into a single image:
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
The input tiles should be provided in a specific order:
start from the top left pixel,
then continue with the pixel on its right,
and once you finish the first row of pixels, go to the next row.
}];
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let extraClassDeclaration = [{
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}

View File

@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Accepted shape:
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
auto operands = getOperands();
// Check number of operands
if (img_w * img_h * channelTiles != operands.size())
return emitError("Number of operands does not match output image size");
// For each output pixel, check that the inputTiles have a correct shape
for (size_t x = 0; x < img_w; x++) {
for (size_t y = 0; y < img_h; y++) {
size_t channel_counts = 0;
for (size_t t = 0; t < channelTiles; t++) {
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
if (!inputShape)
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
// - CASE2: common case, the channel count is exactly the crossbarSize
if (t == channelTiles - 1 && channelTileRest != 0) {
if (inputChannels != channelTileRest)
return emitError("Invalid channel count for last tile of pixel");
}
else {
if (inputChannels != crossbarSize)
return emitError("Invalid channel count for some pixel tile");
}
channel_counts += inputChannels;
}
if (channel_counts != img_c)
emitError("Invalid number of channels for some pixel");
}
}
return success();
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
return success();
}
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
assert(tile < channelTiles);
assert(x < img_w);
assert(y < img_h);
return operands[tile + x * channelTiles + y * img_w * channelTiles];
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
.getOutput();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
@@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
memrefOperands.push_back(outputTensor);
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutRes();
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOut();
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
return success();
}
@@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, S
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
// One operand ($input) is read from. All other inputs are only written to.
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 0;
}
// One input ($accumBuf) is written to. All other inputs are only read.
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// Operand 0: $input
// Operand 1: $outBuf
// Operand 2: $accumBuf
return opOperand.getOperandNumber() == 2;
}
// No operands are aliased with any other operands.
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Bufferize the operation.
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Get the input tensor buffer.
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(inputBuffer))
return failure();
// Create a new buffer for the output tensor.
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
// Create a new buffer for the accumulation buffer.
// To do this, create a new allocation operation. Size must be axbx1x1,
// where axbxcxd is the size of the output tensor. Since the shape is
// different, we can't immediately use createEmptyFromType, we first need to
// create the shape of the accumulation buffer.
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
// Set the last two dimensions to 1.
accumShape[accumShape.size() - 1] = 1;
accumShape[accumShape.size() - 2] = 1;
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
// Bufferize the operation.
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
weightIndices,
xKernelPositions,
yKernelPositions,
*inputBuffer,
outputTensor,
accumBuffer);
// Replace the operation with the bufferized value.
replaceOpWithBufferizedValues(rewriter, op, bufferized);
return success();
}
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
@@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
});
}

View File

@@ -247,11 +247,11 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
@@ -268,8 +268,8 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
return failure();
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPerms().size());
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
perms.reserve(transposeOp.getPermutation().size());
for (IntegerAttr attr : transposeOp.getPermutation().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
@@ -389,18 +389,18 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
if (!allocOp)
return failure();
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)

View File

@@ -89,10 +89,10 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDst(),
copyOp.getSrc(),
copyOp.getDstOffset(),
copyOp.getSrcOffset(),
copyOp.getTarget(),
copyOp.getSource(),
copyOp.getTargetOffset(),
copyOp.getSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
@@ -114,7 +114,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDst());
rewriter.replaceOp(copyOp, copyOp.getTarget());
return success();
}
};
@@ -125,10 +125,10 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceDst(),
copyOp.getHostSrc(),
copyOp.getDeviceDstOffset(),
copyOp.getHostSrcOffset(),
copyOp.getDeviceTarget(),
copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(),
copyOp.getHostSourceOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
@@ -150,7 +150,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
return success();
}
};