standardize spatial and pim dialects
remove old unused stuff
This commit is contained in:
@@ -6,13 +6,11 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
|
|
||||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -194,45 +194,45 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
|||||||
|
|
||||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
|
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
|
||||||
emitMemCopyOp("ld",
|
emitMemCopyOp("ld",
|
||||||
memory.getValueAddress(loadOp.getDeviceDst()),
|
memory.getValueAddress(loadOp.getDeviceTarget()),
|
||||||
loadOp.getDeviceDstOffset(),
|
loadOp.getDeviceTargetOffset(),
|
||||||
memory.getValueAddress(loadOp.getHostSrc()),
|
memory.getValueAddress(loadOp.getHostSource()),
|
||||||
loadOp.getHostSrcOffset(),
|
loadOp.getHostSourceOffset(),
|
||||||
loadOp.getSize());
|
loadOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
||||||
emitMemCopyOp("st",
|
emitMemCopyOp("st",
|
||||||
memory.getValueAddress(storeOp.getHostDst()),
|
memory.getValueAddress(storeOp.getHostTarget()),
|
||||||
storeOp.getHostDstOffset(),
|
storeOp.getHostTargetOffset(),
|
||||||
memory.getValueAddress(storeOp.getDeviceSrc()),
|
memory.getValueAddress(storeOp.getDeviceSource()),
|
||||||
storeOp.getDeviceSrcOffset(),
|
storeOp.getDeviceSourceOffset(),
|
||||||
storeOp.getSize());
|
storeOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
||||||
emitMemCopyOp("lmv",
|
emitMemCopyOp("lmv",
|
||||||
memory.getValueAddress(lmvOp.getDst()),
|
memory.getValueAddress(lmvOp.getTarget()),
|
||||||
lmvOp.getDstOffset(),
|
lmvOp.getTargetOffset(),
|
||||||
memory.getValueAddress(lmvOp.getSrc()),
|
memory.getValueAddress(lmvOp.getSource()),
|
||||||
lmvOp.getSrcOffset(),
|
lmvOp.getSourceOffset(),
|
||||||
lmvOp.getSize(),
|
lmvOp.getSize(),
|
||||||
"len");
|
"len");
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
||||||
emitCommunicationOp(
|
emitCommunicationOp(
|
||||||
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
|
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
|
||||||
emitCommunicationOp("send", memory.getValueAddress(sendOp.getSrc()), sendOp.getTargetCoreId(), sendOp.getSize());
|
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
|
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
|
||||||
emitMvmOp(
|
emitMvmOp(
|
||||||
mvmId, memory.getValueAddress(mvmLikeOp.getOutBuf()), 0, memory.getValueAddress(mvmLikeOp.getVectorInput()), 0);
|
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
|
||||||
|
|
||||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||||
}
|
}
|
||||||
@@ -243,10 +243,10 @@ static size_t getValueSizeInBytes(mlir::Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vvaddOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vvaddOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvadd";
|
json["op"] = "vvadd";
|
||||||
@@ -254,15 +254,15 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vvaddOp.getA());
|
json["len"] = getValueSizeInBytes(vvaddOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vvsubOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vvsubOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvsub";
|
json["op"] = "vvsub";
|
||||||
@@ -270,15 +270,15 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vvsubOp.getA());
|
json["len"] = getValueSizeInBytes(vvsubOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vvmulOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vvmulOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvmul";
|
json["op"] = "vvmul";
|
||||||
@@ -286,15 +286,15 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vvmulOp.getA());
|
json["len"] = getValueSizeInBytes(vvmulOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvmax";
|
json["op"] = "vvmax";
|
||||||
@@ -302,15 +302,15 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
|
json["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvdmul";
|
json["op"] = "vvdmul";
|
||||||
@@ -318,132 +318,71 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
|
json["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vavgOp.getA());
|
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
|
||||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vavg";
|
json["op"] = "vavg";
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vavgOp.getA());
|
json["len"] = getValueSizeInBytes(vavgOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vreluOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vreluOp.getA());
|
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
|
||||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vrelu";
|
json["op"] = "vrelu";
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vreluOp.getA());
|
json["len"] = getValueSizeInBytes(vreluOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vtanhOp.getA());
|
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
|
||||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vtanh";
|
json["op"] = "vtanh";
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vtanhOp.getA());
|
json["len"] = getValueSizeInBytes(vtanhOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vsigmOp.getA());
|
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
|
||||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vsigm";
|
json["op"] = "vsigm";
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = getValueSizeInBytes(vsigmOp.getA());
|
json["len"] = getValueSizeInBytes(vsigmOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const {
|
|
||||||
auto outBufAddr = memory.getValueAddress(applyFiltersOp.getOutBuf());
|
|
||||||
auto inBufAddr = memory.getValueAddress(applyFiltersOp.getInput());
|
|
||||||
auto accumBufAddr = memory.getValueAddress(applyFiltersOp.getAccumBuf());
|
|
||||||
|
|
||||||
auto weightIndices = applyFiltersOp.getWeightIndices();
|
|
||||||
|
|
||||||
auto inputType = cast<MemRefType>(applyFiltersOp.getInput().getType());
|
|
||||||
auto outputType = cast<MemRefType>(applyFiltersOp.getOutBuf().getType());
|
|
||||||
auto inShape = inputType.getShape();
|
|
||||||
auto outShape = outputType.getShape();
|
|
||||||
|
|
||||||
size_t inChannels = inShape[1];
|
|
||||||
size_t outChannels = outShape[1];
|
|
||||||
size_t dimX = inShape.size() > 2 ? inShape[2] : 1;
|
|
||||||
size_t dimY = inShape.size() > 3 ? inShape[3] : 1;
|
|
||||||
|
|
||||||
for (size_t outY = 0; outY < dimY; outY++) {
|
|
||||||
for (size_t outX = 0; outX < dimX; outX++) {
|
|
||||||
|
|
||||||
size_t weightIndex = 0;
|
|
||||||
for (Attribute weight : weightIndices) {
|
|
||||||
// --- STEP 1: Perform MVMUL operation ---
|
|
||||||
auto weightId = cast<IntegerAttr>(weight).getInt();
|
|
||||||
size_t xKer = cast<IntegerAttr>(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt();
|
|
||||||
size_t yKer = cast<IntegerAttr>(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt();
|
|
||||||
weightIndex++;
|
|
||||||
|
|
||||||
if (outX + xKer >= dimX || outY + yKer >= dimY)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
size_t outputOffset = (outY * dimX + outX) * 32 * outChannels;
|
|
||||||
size_t inputOffset = ((outY + yKer) * dimX + (outX + xKer)) * 32 * inChannels;
|
|
||||||
|
|
||||||
bool isFirstWeight = (weightIndices[0] == weight);
|
|
||||||
|
|
||||||
// For the first weight, store directly in output buffer; otherwise use accumulator.
|
|
||||||
size_t rdAddr = isFirstWeight ? outBufAddr : accumBufAddr;
|
|
||||||
size_t rdOffset = isFirstWeight ? outputOffset : 0;
|
|
||||||
emitMvmOp(weightId, rdAddr, rdOffset, inBufAddr, inputOffset);
|
|
||||||
|
|
||||||
// --- STEP 2: Perform VADD operation (skip for first weight) ---
|
|
||||||
if (isFirstWeight)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Sum accumulator with output buffer, store result in output buffer.
|
|
||||||
setupRdRs1Rs2(outBufAddr, outputOffset, accumBufAddr, 0, outBufAddr, outputOffset);
|
|
||||||
|
|
||||||
json::Object vaddJson;
|
|
||||||
vaddJson["op"] = "vvadd";
|
|
||||||
vaddJson["rd"] = 0;
|
|
||||||
vaddJson["rs1"] = 1;
|
|
||||||
vaddJson["rs2"] = 2;
|
|
||||||
vaddJson["offset"] = createEmptyOffset();
|
|
||||||
vaddJson["len"] = 32 * outChannels;
|
|
||||||
emitInstruction(std::move(vaddJson));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||||
auto srcAddr = memory.getValueAddress(transposeOp.getData());
|
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
|
||||||
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
|
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
|
||||||
|
|
||||||
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
|
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||||
auto srcShape = srcType.getShape();
|
auto srcShape = srcType.getShape();
|
||||||
size_t rank = srcShape.size();
|
size_t rank = srcShape.size();
|
||||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||||
@@ -451,7 +390,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
|||||||
|
|
||||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||||
SmallVector<int64_t> perm =
|
SmallVector<int64_t> perm =
|
||||||
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||||
|
|
||||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||||
SmallVector<int64_t> dstShape(rank);
|
SmallVector<int64_t> dstShape(rank);
|
||||||
@@ -570,8 +509,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
||||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
||||||
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
|
|
||||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
|
||||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
@@ -592,11 +529,6 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||||
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||||
else if (isa<pim::PimSumOp>(op)) {
|
|
||||||
// TODO: Implement somehow?
|
|
||||||
op.emitWarning("Operation is not yet supported in code generation");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
op.dump();
|
||||||
|
|||||||
@@ -99,7 +99,6 @@ public:
|
|||||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -134,366 +134,4 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
|||||||
return (*currTensors)[0];
|
return (*currTensors)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
|
|
||||||
switch (mapOp) {
|
|
||||||
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
|
|
||||||
case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2) {
|
|
||||||
if (auto unpackedStrides = valuesArray) {
|
|
||||||
value1 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[0]).getInt();
|
|
||||||
value2 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[1]).getInt();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
value1 = 1;
|
|
||||||
value2 = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<llvm::Twine>
|
|
||||||
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y) {
|
|
||||||
if (valuesArray.has_value()) {
|
|
||||||
auto pads = mlir::ArrayAttr(*valuesArray);
|
|
||||||
if (pads.size() != 4)
|
|
||||||
return "pads must have 4 elements.";
|
|
||||||
|
|
||||||
pad_x = cast<IntegerAttr>(pads[2]).getInt();
|
|
||||||
pad_y = cast<IntegerAttr>(pads[3]).getInt();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Default padding is 0 unless specified otherwise.
|
|
||||||
// https://onnx.ai/onnx/operators/onnx__Conv.html
|
|
||||||
pad_x = pad_y = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
void tileImageTensorByChannel(Value imageTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
|
|
||||||
size_t tileSize,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
|
|
||||||
|
|
||||||
size_t input_h = getImageHeight(imageShape);
|
|
||||||
size_t input_w = getImageWidth(imageShape);
|
|
||||||
size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
|
|
||||||
size_t tileRest = getImageChannel(imageShape) % tileSize;
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
Location loc = imageTensor.getLoc();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tileCount; i++) {
|
|
||||||
if (i == tileCount - 1 && tileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(tileRest);
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
offsets[1] = rewriter.getIndexAttr(i * tileSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location& loc,
|
|
||||||
Type outputType) {
|
|
||||||
// Populate the outputTiles for the concat in the given order:
|
|
||||||
// 1. Start top left pixel
|
|
||||||
// 2. Continue on its right pixel till the end of the row
|
|
||||||
// 3. Restart on the next row
|
|
||||||
size_t outputTileCount = outputTiles.size();
|
|
||||||
size_t output_w = outputTiles[0].size();
|
|
||||||
size_t output_h = outputTiles[0][0].size();
|
|
||||||
SmallVector<Value> tilesToConcat;
|
|
||||||
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
|
|
||||||
for (size_t outX = 0; outX < output_h; outX++)
|
|
||||||
for (size_t outY = 0; outY < output_w; outY++)
|
|
||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
|
||||||
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
|
|
||||||
|
|
||||||
return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) {
|
|
||||||
|
|
||||||
if (inX < 0) {
|
|
||||||
assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inY < 0) {
|
|
||||||
assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((size_t) inX >= input_w || (size_t) inY >= input_h) {
|
|
||||||
assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds");
|
|
||||||
assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createExtractSliceImg(Value valToSlice,
|
|
||||||
size_t x,
|
|
||||||
size_t y,
|
|
||||||
size_t t,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
if (t == channelTileCount - 1 && channelTileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(channelTileRest);
|
|
||||||
|
|
||||||
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value indexImgValue(Value v,
|
|
||||||
size_t x,
|
|
||||||
size_t y,
|
|
||||||
size_t t,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
|
|
||||||
auto newV = rewriter.getRemappedValue(v);
|
|
||||||
if (newV)
|
|
||||||
v = newV;
|
|
||||||
|
|
||||||
if (!v.getDefiningOp())
|
|
||||||
return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
|
|
||||||
if (auto computeOp = v.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
|
||||||
// We found the computeOp that produces the tile we want, just return this
|
|
||||||
// value.
|
|
||||||
// TODO: Should we assert that x,y,t are zero?
|
|
||||||
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero");
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveOp = v.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
|
||||||
// This is a receiveOp, just return its value which will be resolved later
|
|
||||||
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero");
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto imgConcatOp = v.getDefiningOp<spatial::SpatImgConcatOp>()) {
|
|
||||||
auto imgConcatInput = imgConcatOp.getInputTile(x, y, t);
|
|
||||||
// TODO: Is this correct?
|
|
||||||
// Above we already index exactly the tile we want, so `x=y=t=0` in
|
|
||||||
// recursive call
|
|
||||||
|
|
||||||
return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto tensorConcatOp = v.getDefiningOp<tensor::ConcatOp>()) {
|
|
||||||
// This can be recursive.
|
|
||||||
// First, get the input tensors of the tensor.concatOp
|
|
||||||
// Then, find the input tensor that contains the tile we want
|
|
||||||
// Finally, recursive call asking for the tile
|
|
||||||
auto concatAxis = tensorConcatOp.getDim();
|
|
||||||
assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis");
|
|
||||||
assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis.");
|
|
||||||
SmallVector<size_t, 4> indexDims = {1, t * crossbarSize, x, y};
|
|
||||||
|
|
||||||
// Find the input tensor that contains the tile we want
|
|
||||||
size_t currentTile = 0;
|
|
||||||
for (auto concatInput : tensorConcatOp.getInputs()) {
|
|
||||||
auto concatInputShape = cast<ShapedType>(concatInput.getType());
|
|
||||||
assert(concatInputShape.getRank() == 4 && "Expecting an image tensor");
|
|
||||||
auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis);
|
|
||||||
|
|
||||||
if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) {
|
|
||||||
// This input tensor contains the tile we want
|
|
||||||
indexDims[concatAxis] -= currentTile;
|
|
||||||
if (indexDims[1] % crossbarSize != 0) {
|
|
||||||
assert(ignoreConcatError
|
|
||||||
&& "TODO: Handle non-tile aligned tensor, or set "
|
|
||||||
"--ignore-concat-error=true");
|
|
||||||
}
|
|
||||||
return indexImgValue(concatInput,
|
|
||||||
indexDims[2],
|
|
||||||
indexDims[3],
|
|
||||||
indexDims[1] / crossbarSize,
|
|
||||||
channelTileCount,
|
|
||||||
channelTileRest,
|
|
||||||
input_w,
|
|
||||||
input_h,
|
|
||||||
rewriter);
|
|
||||||
}
|
|
||||||
currentTile += concatInputSizeOnAxis;
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(false
|
|
||||||
&& "Could not find the input tensor that contains the tile "
|
|
||||||
"within tensor.ConcatOp");
|
|
||||||
}
|
|
||||||
|
|
||||||
v.dump();
|
|
||||||
|
|
||||||
assert(false && "indexImgValue: unsupported operation");
|
|
||||||
}
|
|
||||||
|
|
||||||
void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Location loc = wholeInputTensor.getLoc();
|
|
||||||
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
if (t == channelTileCount - 1 && channelTileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(channelTileRest);
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
inputTiles[t][x][y] =
|
|
||||||
indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
|
|
||||||
const size_t inputTilesCount,
|
|
||||||
const size_t lastInputTileDimension,
|
|
||||||
TensorType inputShape,
|
|
||||||
TensorType outputShape,
|
|
||||||
Value reshapeInput,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
// Only support reshape between an image and a vector (i.e. flatten)
|
|
||||||
if (inputShape.getRank() != 4 || outputShape.getRank() != 2) {
|
|
||||||
return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(),
|
|
||||||
"resolveVecInputTiles only supports reshapes from 4D to 2D tensors");
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* From a 4D tensor <N, C, W, H> to a 2D tensor <N, C*H*W>
|
|
||||||
*/
|
|
||||||
auto N = inputShape.getDimSize(0);
|
|
||||||
auto C = inputShape.getDimSize(1);
|
|
||||||
auto H = inputShape.getDimSize(2);
|
|
||||||
auto W = inputShape.getDimSize(3);
|
|
||||||
assert(N == 1 && "Only support N = 1 for image tensors");
|
|
||||||
|
|
||||||
for (size_t i = 0; i < inputTilesCount; i++) {
|
|
||||||
auto c = (i / (H * W)) % C;
|
|
||||||
// TODO: Is this correct? Or should I invert h and w?
|
|
||||||
auto w = (i / H) % W;
|
|
||||||
auto h = i % H;
|
|
||||||
|
|
||||||
Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter);
|
|
||||||
|
|
||||||
// Assert the shape of the tile, and reshape it
|
|
||||||
auto curTileShape = cast<TensorType>(curTile.getType());
|
|
||||||
assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?");
|
|
||||||
assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?");
|
|
||||||
assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?");
|
|
||||||
assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?");
|
|
||||||
|
|
||||||
// Reshape this pixel tensor into a vector, for compatibility with the
|
|
||||||
// rest
|
|
||||||
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
|
|
||||||
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
|
|
||||||
Value shapeTensor =
|
|
||||||
arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
|
||||||
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
|
|
||||||
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
|
|
||||||
|
|
||||||
size_t coreIndex = i / crossbarCountInCore;
|
|
||||||
inputTiles[coreIndex].push_back(reshapedCurTile);
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<size_t, size_t> kernel_get_start_and_end(
|
|
||||||
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) {
|
|
||||||
int64_t firstValid = std::ceil(static_cast<float>(pad) / dilation) * dilation - pad;
|
|
||||||
int64_t start = std::max(firstValid, out_pos * stride - pad);
|
|
||||||
int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad);
|
|
||||||
|
|
||||||
assert(start >= 0 && "Start position must be non-negative.");
|
|
||||||
assert(end >= 0 && "End position must be non-negative.");
|
|
||||||
return std::make_pair(start, end);
|
|
||||||
}
|
|
||||||
|
|
||||||
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) {
|
|
||||||
auto oldSegmentSizes = wcomputeOp->getAttrOfType<DenseI32ArrayAttr>(wcomputeOp.getOperandSegmentSizesAttrName());
|
|
||||||
|
|
||||||
auto newSegmentSizes =
|
|
||||||
DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment});
|
|
||||||
|
|
||||||
wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes);
|
|
||||||
}
|
|
||||||
|
|
||||||
int getResultIndex(Operation* op, Value v) {
|
|
||||||
int resultNumber = -1;
|
|
||||||
for (auto result : op->getResults()) {
|
|
||||||
if (result == v) {
|
|
||||||
resultNumber = result.getResultNumber();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert(resultNumber >= 0 && "Value not found in given operation's results.");
|
|
||||||
|
|
||||||
return resultNumber;
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
}; // namespace onnx_mlir
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
@@ -10,7 +9,6 @@
|
|||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -144,164 +142,4 @@ mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
|||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unpacks an optional pair vector into two size_t values.
|
|
||||||
*
|
|
||||||
* @param valuesArray The optional `mlir::ArrayAttr` containing the pair of
|
|
||||||
* values.
|
|
||||||
* @param value1 The reference to the first `size_t` variable to store the
|
|
||||||
* unpacked value.
|
|
||||||
* @param value2 The reference to the second `size_t` variable to store the
|
|
||||||
* unpacked value.
|
|
||||||
*/
|
|
||||||
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unpacks the optional pads vector.
|
|
||||||
*
|
|
||||||
* @param valuesArray The optional array attribute containing the values.
|
|
||||||
* @param pad_x The output variable to store the value of pad_x.
|
|
||||||
* @param pad_y The output variable to store the value of pad_y.
|
|
||||||
* @param rewriter The rewriter to notify failure
|
|
||||||
*
|
|
||||||
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
|
|
||||||
*/
|
|
||||||
std::optional<llvm::Twine>
|
|
||||||
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tiles the image tensor by channel.
|
|
||||||
*
|
|
||||||
* This function takes an image tensor and tiles it into smaller tiles based on
|
|
||||||
* the channel dimension. The size of each tile is specified by the tileSize
|
|
||||||
* parameter.
|
|
||||||
*
|
|
||||||
* @param imageTensor The input image tensor (NxCxWxH) to be tiled.
|
|
||||||
* @param tiles The output tiles vector to store the tiled image tensors.
|
|
||||||
* @param tileSize The size of each tile.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating operations.
|
|
||||||
*/
|
|
||||||
void tileImageTensorByChannel(mlir::Value imageTensor,
|
|
||||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
|
|
||||||
size_t tileSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an ImgConcatOp based on the given tiles.
|
|
||||||
*
|
|
||||||
* This function takes a 3-dimensional vector `outputTiles` representing the
|
|
||||||
* tiles to concatenate. The tiles are indexed by [tile][x][y].
|
|
||||||
*
|
|
||||||
* @param outputTiles The tiles to concatenate.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating the
|
|
||||||
* ImgConcatOp.
|
|
||||||
* @param loc The location of the operation.
|
|
||||||
* @param outputType The type of the output tensor.
|
|
||||||
*
|
|
||||||
* @return The created ImgConcatOp.
|
|
||||||
*/
|
|
||||||
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location& loc,
|
|
||||||
mlir::Type outputType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Verifies if the given input coordinates and padding values are within
|
|
||||||
* the bounds of the input tensor.
|
|
||||||
*
|
|
||||||
* @param input_w The width of the input tensor.
|
|
||||||
* @param input_h The height of the input tensor.
|
|
||||||
* @param inX The X-coordinate of the input.
|
|
||||||
* @param inY The Y-coordinate of the input.
|
|
||||||
* @param pad_x The padding value in the X-direction.
|
|
||||||
* @param pad_y The padding value in the Y-direction.
|
|
||||||
* @return LogicalResult Returns success if the coordinates and padding are
|
|
||||||
* within bounds, failure otherwise.
|
|
||||||
*/
|
|
||||||
mlir::LogicalResult
|
|
||||||
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolves the tiling of the input tensor into smaller tiles.
|
|
||||||
*
|
|
||||||
* This function takes a whole input tensor and tiles it into smaller tiles
|
|
||||||
* using the provided parameters. The resulting tiles are stored in the
|
|
||||||
* `inputTiles` vector.
|
|
||||||
* Input tiles need to be indexed by:
|
|
||||||
* a. Channel Tile
|
|
||||||
* b. Pixel `x` position
|
|
||||||
* c. Pixel `y` position
|
|
||||||
* For example: inputTiles[channelTile][x][y]
|
|
||||||
*
|
|
||||||
* @param wholeInputTensor The whole input tensor to be tiled.
|
|
||||||
* @param inputTiles A vector of vectors of vectors of Values representing the
|
|
||||||
* tiles of the input tensor. The outermost vector represents
|
|
||||||
* the channels, the middle vector represents the rows, and
|
|
||||||
* the innermost vector represents the columns of the tiles.
|
|
||||||
* @param channelTileCount The number of tiles for the `channel` axis.
|
|
||||||
* @param channelTileRest The size of the last channelTile. Set as 0 if tiles
|
|
||||||
* fit exactly
|
|
||||||
* @param input_w The width of the input tensor.
|
|
||||||
* @param input_h The height of the input tensor.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating operations.
|
|
||||||
*
|
|
||||||
* @return std::optional<llvm::Twine> An error message if the input tensor could
|
|
||||||
* not be resolved into tiles.
|
|
||||||
*/
|
|
||||||
std::optional<llvm::Twine>
|
|
||||||
resolveImgInputTiles(mlir::Value wholeInputTensor,
|
|
||||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the boundaries of an image kernel application.
|
|
||||||
*
|
|
||||||
* @param out_pos The position of the output element.
|
|
||||||
* @param input_width The width of the input image.
|
|
||||||
* @param krn_width The width of the kernel.
|
|
||||||
* @param stride The stride value.
|
|
||||||
* @param dilation The dilation value.
|
|
||||||
* @param pad The padding value.
|
|
||||||
* @return A pair of size_t values representing the start and end positions of
|
|
||||||
* the kernel application.
|
|
||||||
*/
|
|
||||||
std::pair<size_t, size_t> kernel_get_start_and_end(
|
|
||||||
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Increment the `operandSegmentSizes` in the WeightedCompute operation
|
|
||||||
* for the `inputs` operand.
|
|
||||||
*
|
|
||||||
* This function increments the size of the `inputs` operand segment in the
|
|
||||||
* `operandSegmentSizes` of the given WeightedCompute operation by the specified
|
|
||||||
* increment. This is necessary when new operands are programmatically added to
|
|
||||||
* the WeightedCompute operation.
|
|
||||||
*
|
|
||||||
* @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes`
|
|
||||||
* is to be incremented.
|
|
||||||
* @param increment The value by which to increment the `inputs` operand segment
|
|
||||||
* size.
|
|
||||||
*/
|
|
||||||
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Finds the result index of the given operation that produces the
|
|
||||||
* specified value.
|
|
||||||
*
|
|
||||||
* This function takes an operation and a value, and returns the index of the
|
|
||||||
* result of the operation that corresponds to the given value.
|
|
||||||
*
|
|
||||||
* @param op Operation whose result index is to be found.
|
|
||||||
* @param v The value for which the result index is to be determined.
|
|
||||||
* @return The index of the result of the operation that produces the specified
|
|
||||||
* value.
|
|
||||||
*/
|
|
||||||
int getResultIndex(mlir::Operation* op, mlir::Value v);
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
}; // namespace onnx_mlir
|
||||||
|
|||||||
@@ -223,9 +223,6 @@ void SpatialToGraphvizPass::runOnOperation() {
|
|||||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
drawConcatOpSubgraph(concatOp, concatNum++);
|
drawConcatOpSubgraph(concatOp, concatNum++);
|
||||||
}
|
}
|
||||||
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
|
|
||||||
drawConcatOpSubgraph(imgConcatOp, concatNum++);
|
|
||||||
}
|
|
||||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
||||||
if (producerOp) {
|
if (producerOp) {
|
||||||
|
|||||||
@@ -45,8 +45,4 @@ createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir
|
|||||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool isAConcatOp(mlir::Operation* op) {
|
|
||||||
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -45,4 +45,5 @@ def spatToPimVVMaxOp : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<tensor::ConcatOp>(resultUser) || isa<spatial::SpatImgConcatOp>(resultUser)) {
|
if (isa<tensor::ConcatOp>(resultUser)) {
|
||||||
auto concatOp = resultUser;
|
auto concatOp = resultUser;
|
||||||
auto concatValue = concatOp->getResult(0);
|
auto concatValue = concatOp->getResult(0);
|
||||||
auto concatUses = concatValue.getUses();
|
auto concatUses = concatValue.getUses();
|
||||||
@@ -368,8 +368,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
};
|
};
|
||||||
|
|
||||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||||
auto outTensorOperand = vmmOp.getOutBuf();
|
auto outTensorOperand = vmmOp.getOutputBuffer();
|
||||||
auto resultTensor = vmmOp.getOutRes();
|
auto resultTensor = vmmOp.getOutput();
|
||||||
auto outShape = getTensorShape(outTensorOperand);
|
auto outShape = getTensorShape(outTensorOperand);
|
||||||
assert(isHVectorShape(outShape));
|
assert(isHVectorShape(outShape));
|
||||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||||
@@ -602,9 +602,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
rewriter.modifyOpInPlace(returnOp,
|
rewriter.modifyOpInPlace(returnOp,
|
||||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
||||||
|
|
||||||
// If the operand is a concatenation operation and the returnOp was the only
|
if (isa<tensor::ConcatOp>(returnOperand)) {
|
||||||
// user of the returnOperand, we can safely remove it
|
|
||||||
if (isAConcatOp(returnOperand)) {
|
|
||||||
auto returnOperandUses = it.value().getUses();
|
auto returnOperandUses = it.value().getUses();
|
||||||
if (rangeLength(returnOperandUses) == 0)
|
if (rangeLength(returnOperandUses) == 0)
|
||||||
rewriter.eraseOp(returnOperand);
|
rewriter.eraseOp(returnOperand);
|
||||||
@@ -632,7 +630,7 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
// user. This means that we need to get the replace the original SendOp with
|
// user. This means that we need to get the replace the original SendOp with
|
||||||
// a BroadcastSendOp
|
// a BroadcastSendOp
|
||||||
rewriter.setInsertionPoint(sendOp);
|
rewriter.setInsertionPoint(sendOp);
|
||||||
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
|
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getInput());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,45 +20,12 @@ class PimOp<string mnemonic, list<Trait> traits = []> :
|
|||||||
def PimTensor :
|
def PimTensor :
|
||||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||||
|
|
||||||
// Communication
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Execution
|
||||||
def PimSendOp: PimOp<"send", []> {
|
//===----------------------------------------------------------------------===//
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $src,
|
|
||||||
I32Attr: $size,
|
|
||||||
I32Attr: $targetCoreId
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $dst,
|
|
||||||
I32Attr: $size,
|
|
||||||
I32Attr: $srcCoreId
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $out
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Core
|
|
||||||
|
|
||||||
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
||||||
|
let summary = "Execute a block on a PIM core";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
@@ -72,412 +39,443 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Memory
|
|
||||||
|
|
||||||
def PimConstantOp: PimOp<"constant", []> {
|
|
||||||
let description = [{
|
|
||||||
Allocate a constant value in global memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
AnyAttr: $value,
|
|
||||||
BoolAttr: $shouldAllocate
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $out
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Copy a memory region from host memory into device memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $deviceDst,
|
|
||||||
PimTensor: $hostSrc,
|
|
||||||
I32Attr: $deviceDstOffset,
|
|
||||||
I32Attr: $hostSrcOffset,
|
|
||||||
I32Attr: $size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $deviceDstOut
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDeviceDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Copy a memory region from device memory into host memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $hostDst,
|
|
||||||
PimTensor: $deviceSrc,
|
|
||||||
I32Attr: $hostDstOffset,
|
|
||||||
I32Attr: $deviceSrcOffset,
|
|
||||||
I32Attr: $size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $hostDstOut
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getHostDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Copy a memory region from and to the same memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $dst,
|
|
||||||
PimTensor: $src,
|
|
||||||
I32Attr: $dstOffset,
|
|
||||||
I32Attr: $srcOffset,
|
|
||||||
I32Attr: $size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $dstOut
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Algebra
|
|
||||||
|
|
||||||
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Matrix transpose
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $data,
|
|
||||||
I64ArrayAttr: $perms,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Vector-matrix multiplication: c = a * b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
PimTensor: $vectorInput,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Matrix-vector multiplication: c = a * b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
PimTensor: $vectorInput,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise addition: c = a + b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise subtraction: c = a - b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise multiplication: c = a * b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise max: c = max(a, b)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Dot product: c = dot(a, b)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Apply filters to a tensor
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I64ArrayAttr: $weightIndices,
|
|
||||||
I64ArrayAttr: $xKernelPositions,
|
|
||||||
I64ArrayAttr: $yKernelPositions,
|
|
||||||
PimTensor: $input,
|
|
||||||
PimTensor: $outBuf,
|
|
||||||
PimTensor: $accumBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
|
|
||||||
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Sum all elements into a single one
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Average all elements into a single one
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise ReLU: c = max(a, 0)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise tanh activation
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise sigmoid activation
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||||
let description = [{
|
let summary = "Halt execution of the core";
|
||||||
Halts the execution of the core
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
attr-dict
|
attr-dict
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Communication
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PimSendOp : PimOp<"send", []> {
|
||||||
|
let summary = "Send a tensor to another core";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
I32Attr:$size,
|
||||||
|
I32Attr:$targetCoreId
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Receive a tensor from another core";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$outputBuffer,
|
||||||
|
I32Attr:$size,
|
||||||
|
I32Attr:$sourceCoreId
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region from host memory into device memory";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$deviceTarget,
|
||||||
|
PimTensor:$hostSource,
|
||||||
|
I32Attr:$deviceTargetOffset,
|
||||||
|
I32Attr:$hostSourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getDeviceTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region from device memory into host memory";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$hostTarget,
|
||||||
|
PimTensor:$deviceSource,
|
||||||
|
I32Attr:$hostTargetOffset,
|
||||||
|
I32Attr:$deviceSourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getHostTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region within the same memory space";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$target,
|
||||||
|
PimTensor:$source,
|
||||||
|
I32Attr:$targetOffset,
|
||||||
|
I32Attr:$sourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Math
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Transpose a matrix";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
I64ArrayAttr:$permutation,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Vector-matrix multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Matrix-vector multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise addition: c = a + b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise subtraction: c = a - b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise max: c = max(a, b)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Dot product: c = dot(a, b)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimSumOp : PimOp<"sum", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Reduce all elements to a single value";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Average all elements into a single value";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise ReLU: c = max(a, 0)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise tanh activation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise sigmoid activation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // PIM_DIALECT_H
|
#endif // PIM_DIALECT_H
|
||||||
|
|||||||
@@ -25,19 +25,6 @@ void PimDialect::initialize() {
|
|||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#define POPULATE_DEPENDENCIES(OP_NAME) \
|
|
||||||
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
|
|
||||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
|
||||||
}
|
|
||||||
|
|
||||||
POPULATE_DEPENDENCIES(PimVVDMulOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimSumOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVAvgOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVTanhOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVSigmOp)
|
|
||||||
|
|
||||||
} // namespace pim
|
} // namespace pim
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
|
|||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
.getDstOut();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
struct MemCopyHostToDevOpInterface
|
||||||
@@ -40,26 +40,26 @@ struct MemCopyHostToDevOpInterface
|
|||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
|
||||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
auto hostSource = memCopyHostToDevOp.getHostSource();
|
||||||
|
|
||||||
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
|
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
|
||||||
if (failed(deviceDstOpt))
|
if (failed(deviceTargetOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto deviceDstMemRef = *deviceDstOpt;
|
auto deviceTargetMemRef = *deviceTargetOpt;
|
||||||
|
|
||||||
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
|
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
|
||||||
if (failed(hostSrcOpt))
|
if (failed(hostSourceOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto hostSrcMemRef = *hostSrcOpt;
|
auto hostSourceMemRef = *hostSourceOpt;
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||||
memCopyHostToDevOp,
|
memCopyHostToDevOp,
|
||||||
deviceDstMemRef.getType(),
|
deviceTargetMemRef.getType(),
|
||||||
deviceDstMemRef,
|
deviceTargetMemRef,
|
||||||
hostSrcMemRef,
|
hostSourceMemRef,
|
||||||
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
|
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||||
memCopyHostToDevOp.getHostSrcOffsetAttr(),
|
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||||
memCopyHostToDevOp.getSizeAttr());
|
memCopyHostToDevOp.getSizeAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -73,25 +73,25 @@ struct MemCopyDevToHostOpInterface
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||||
|
|
||||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
auto hostTarget = memCopyDevToHostOp.getHostTarget();
|
||||||
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
|
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
|
||||||
if (failed(globalDstOpt))
|
if (failed(hostTargetOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto globalDstMemRef = *globalDstOpt;
|
auto hostTargetMemRef = *hostTargetOpt;
|
||||||
|
|
||||||
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
|
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
|
||||||
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
|
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
|
||||||
if (failed(localSrcOpt))
|
if (failed(deviceSourceOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto localSrcMemRef = *localSrcOpt;
|
auto deviceSourceMemRef = *deviceSourceOpt;
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||||
memCopyDevToHostOp,
|
memCopyDevToHostOp,
|
||||||
globalDstMemRef.getType(),
|
hostTargetMemRef.getType(),
|
||||||
globalDstMemRef,
|
hostTargetMemRef,
|
||||||
localSrcMemRef,
|
deviceSourceMemRef,
|
||||||
memCopyDevToHostOp.getHostDstOffsetAttr(),
|
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||||
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
|
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||||
memCopyDevToHostOp.getSizeAttr());
|
memCopyDevToHostOp.getSizeAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -109,16 +109,16 @@ struct TransposeOpBufferizeInterface
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto transposeOp = cast<PimTransposeOp>(op);
|
auto transposeOp = cast<PimTransposeOp>(op);
|
||||||
|
|
||||||
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
|
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
|
||||||
if (failed(dataOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
|
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outBufOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||||
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
|
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -132,9 +132,9 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
|||||||
auto vmmOp = cast<PimVMMOp>(op);
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
Value readVal = uRead->get();
|
Value readVal = uRead->get();
|
||||||
Value writeVal = uWrite->get();
|
Value writeVal = uWrite->get();
|
||||||
if (writeVal != vmmOp.getOutBuf())
|
if (writeVal != vmmOp.getOutputBuffer())
|
||||||
return false;
|
return false;
|
||||||
if (readVal == vmmOp.getVectorInput())
|
if (readVal == vmmOp.getInput())
|
||||||
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
||||||
return true;
|
return true;
|
||||||
return false;
|
return false;
|
||||||
@@ -146,16 +146,16 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto vmmOp = cast<PimVMMOp>(op);
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
|
|
||||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
|
||||||
if (failed(vectorInputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
|
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outBufOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -171,16 +171,16 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto mvmOp = cast<PimMVMOp>(op);
|
auto mvmOp = cast<PimMVMOp>(op);
|
||||||
|
|
||||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
|
||||||
if (failed(vectorInputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
|
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outBufOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||||
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -203,22 +203,23 @@ struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<B
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto binaryOp = cast<OpTy>(op);
|
auto binaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
|
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
|
||||||
if (failed(aOpt))
|
if (failed(lhsOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
|
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
|
||||||
if (failed(bOpt))
|
if (failed(rhsOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
|
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outBufOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
|
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
|
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
|
replaceOpWithNewBufferizedOp<OpTy>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -16,4 +16,5 @@ def memrefCopyToPimMemCopyOp : Pat<
|
|||||||
(returnType $dst))
|
(returnType $dst))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|
||||||
#endif // PIM_BUFFERIZATION
|
#endif // PIM_BUFFERIZATION
|
||||||
|
|||||||
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
|||||||
let summary = "Virtual channel type";
|
let summary = "Virtual channel type";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Execution
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||||
let summary = "Compute operation, with constant weights already attached";
|
let summary = "Compute region with attached constant weights";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<SpatTensor>:$weights,
|
Variadic<SpatTensor>:$weights,
|
||||||
@@ -50,6 +54,8 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
|
|||||||
}
|
}
|
||||||
|
|
||||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||||
|
let summary = "Yield results from a compute region";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<SpatTensor>:$outputs
|
Variadic<SpatTensor>:$outputs
|
||||||
);
|
);
|
||||||
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Data movement operations
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
||||||
|
let summary = "Create a new virtual channel";
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatChannelType:$new_channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
@@ -80,107 +88,73 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||||
|
let summary = "Send a tensor through a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel,
|
SpatChannelType:$channel,
|
||||||
SpatTensor: $data
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||||
|
let summary = "Receive a tensor from a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor: $data
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
||||||
|
let summary = "Broadcast a tensor through a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel,
|
SpatChannelType:$channel,
|
||||||
SpatTensor: $data
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
||||||
|
let summary = "Receive a tensor from a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor: $data
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
let assemblyFormat = [{
|
||||||
// Math operations
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SpatConstantOp: SpatOp<"constant", []> {
|
|
||||||
let description = [{
|
|
||||||
"Constant value, should be used for weights and biases"
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
AnyAttr: $value,
|
|
||||||
BoolAttr: $shouldAllocate
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor: $out
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Math
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
||||||
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$weightIndex,
|
I32Attr:$weightIndex,
|
||||||
SpatTensor:$vector
|
SpatTensor:$input
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
|
||||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
|
|
||||||
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
SpatTensor:$vector
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
|
||||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def SpatVAddOp: SpatOp<"vadd", []> {
|
|
||||||
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor: $a,
|
|
||||||
SpatTensor: $b
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -190,65 +164,67 @@ def SpatVAddOp: SpatOp<"vadd", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
||||||
|
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
SpatTensor:$input
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||||
|
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$lhs,
|
||||||
|
SpatTensor:$rhs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||||
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor: $a,
|
SpatTensor:$lhs,
|
||||||
SpatTensor: $b
|
SpatTensor:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
//let hasVerifier = 1;
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVDivOp: SpatOp<"vdiv", []> {
|
|
||||||
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor:$a,
|
|
||||||
SpatTensor:$b
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
//let hasVerifier = 1;
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO: remove
|
|
||||||
def SpatVSDivOp: SpatOp<"vsdiv", []> {
|
|
||||||
|
|
||||||
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor:$dividend,
|
|
||||||
SpatTensor:$divisor
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatSumOp : SpatOp<"sum", []> {
|
def SpatSumOp : SpatOp<"sum", []> {
|
||||||
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
|
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
@@ -257,9 +233,15 @@ def SpatSumOp: SpatOp<"sum", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||||
|
let summary = "Element-wise sigmoid activation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
@@ -267,9 +249,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatReluOp : SpatOp<"relu", []> {
|
def SpatReluOp : SpatOp<"relu", []> {
|
||||||
|
let summary = "Element-wise ReLU activation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
@@ -277,15 +265,18 @@ def SpatReluOp: SpatOp<"relu", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVMaxOp : SpatOp<"vmax", []> {
|
def SpatVMaxOp : SpatOp<"vmax", []> {
|
||||||
|
let summary = "Element-wise max between two tensors";
|
||||||
let summary = "Element-wise max function";
|
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor: $a,
|
SpatTensor:$lhs,
|
||||||
SpatTensor: $b
|
SpatTensor:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -293,62 +284,9 @@ def SpatVMaxOp: SpatOp<"vmax", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
|
||||||
|
|
||||||
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
|
|
||||||
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
|
|
||||||
let description = [{
|
|
||||||
Applies a variable number of crossbar weights to a single large image tensor tile,
|
|
||||||
producing a corresponding output tile. This essentially encapsulates a big for loop
|
|
||||||
over all pixels in the input tile, where each pixel is multiplied by all the weights
|
|
||||||
in the operation.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I64ArrayAttr: $weightIndices,
|
|
||||||
I64ArrayAttr: $xKernelPositions,
|
|
||||||
I64ArrayAttr: $yKernelPositions,
|
|
||||||
SpatTensor: $input
|
|
||||||
);
|
|
||||||
let results = (outs SpatTensor);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$input attr-dict `:` type($input) `->` type(results)
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Other operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
|
||||||
|
|
||||||
let summary = "Concatenate pixel tiles into a single image";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
Concatenate pixel tiles into a single image:
|
|
||||||
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
|
|
||||||
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
|
|
||||||
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
|
|
||||||
|
|
||||||
The input tiles should be provided in a specific order:
|
|
||||||
start from the top left pixel,
|
|
||||||
then continue with the pixel on its right,
|
|
||||||
and once you finish the first row of pixels, go to the next row.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<SpatTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
|||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getVector().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
|
|
||||||
/* Two possible accepted shapes:
|
/* Two possible accepted shapes:
|
||||||
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
|
|||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getVector().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
|
|
||||||
/* Accepted shape:
|
/* Accepted shape:
|
||||||
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
|
|||||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatImgConcatOp::verify() {
|
|
||||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
|
||||||
size_t img_w = getImageWidth(imgShape);
|
|
||||||
size_t img_h = getImageHeight(imgShape);
|
|
||||||
size_t img_c = getImageChannel(imgShape);
|
|
||||||
|
|
||||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
|
||||||
size_t channelTileRest = img_c % crossbarSize;
|
|
||||||
|
|
||||||
auto operands = getOperands();
|
|
||||||
|
|
||||||
// Check number of operands
|
|
||||||
if (img_w * img_h * channelTiles != operands.size())
|
|
||||||
return emitError("Number of operands does not match output image size");
|
|
||||||
|
|
||||||
// For each output pixel, check that the inputTiles have a correct shape
|
|
||||||
for (size_t x = 0; x < img_w; x++) {
|
|
||||||
for (size_t y = 0; y < img_h; y++) {
|
|
||||||
size_t channel_counts = 0;
|
|
||||||
for (size_t t = 0; t < channelTiles; t++) {
|
|
||||||
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
|
|
||||||
if (!inputShape)
|
|
||||||
return emitError("Invalid input type, must be ShapedType");
|
|
||||||
|
|
||||||
// N == W == H == 1
|
|
||||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
|
||||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
|
||||||
|
|
||||||
size_t inputChannels = getImageChannel(inputShape);
|
|
||||||
|
|
||||||
// Check the number of channels in this tile are correct:
|
|
||||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
|
||||||
// - CASE2: common case, the channel count is exactly the crossbarSize
|
|
||||||
if (t == channelTiles - 1 && channelTileRest != 0) {
|
|
||||||
if (inputChannels != channelTileRest)
|
|
||||||
return emitError("Invalid channel count for last tile of pixel");
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (inputChannels != crossbarSize)
|
|
||||||
return emitError("Invalid channel count for some pixel tile");
|
|
||||||
}
|
|
||||||
|
|
||||||
channel_counts += inputChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (channel_counts != img_c)
|
|
||||||
emitError("Invalid number of channels for some pixel");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult SpatWeightedCompute::verify() {
|
LogicalResult SpatWeightedCompute::verify() {
|
||||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||||
// operand with the same type as the result
|
// operand with the same type as the result
|
||||||
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
|
||||||
auto operands = getOperands();
|
|
||||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
|
||||||
size_t img_w = getImageWidth(imgShape);
|
|
||||||
size_t img_h = getImageHeight(imgShape);
|
|
||||||
size_t img_c = getImageChannel(imgShape);
|
|
||||||
|
|
||||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
|
||||||
|
|
||||||
assert(tile < channelTiles);
|
|
||||||
assert(x < img_w);
|
|
||||||
assert(y < img_h);
|
|
||||||
|
|
||||||
return operands[tile + x * channelTiles + y * img_w * channelTiles];
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
|||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
.getDstOut();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||||
@@ -202,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
|||||||
|
|
||||||
memrefOperands.push_back(outputTensor);
|
memrefOperands.push_back(outputTensor);
|
||||||
|
|
||||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -243,7 +243,7 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
|||||||
cast<OpTy>(op).getWeightIndexAttr(),
|
cast<OpTy>(op).getWeightIndexAttr(),
|
||||||
memrefOperand,
|
memrefOperand,
|
||||||
outputTensor)
|
outputTensor)
|
||||||
.getOutRes();
|
.getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ struct ChannelReceiveOpInterface
|
|||||||
outputTensor,
|
outputTensor,
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||||
.getOut();
|
.getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -403,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(outputSize));
|
rewriter.getI32IntegerAttr(outputSize));
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
|
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -490,84 +490,6 @@ struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, S
|
|||||||
|
|
||||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||||
|
|
||||||
// Create a new bufferizable op interface for the apply filters operation.
|
|
||||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
|
||||||
|
|
||||||
// One operand ($input) is read from. All other inputs are only written to.
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
|
|
||||||
// Operand 0: $input
|
|
||||||
// Operand 1: $outBuf
|
|
||||||
// Operand 2: $accumBuf
|
|
||||||
return opOperand.getOperandNumber() == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// One input ($accumBuf) is written to. All other inputs are only read.
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
|
|
||||||
// Operand 0: $input
|
|
||||||
// Operand 1: $outBuf
|
|
||||||
// Operand 2: $accumBuf
|
|
||||||
return opOperand.getOperandNumber() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No operands are aliased with any other operands.
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bufferize the operation.
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
|
|
||||||
// Get the input tensor buffer.
|
|
||||||
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
|
|
||||||
|
|
||||||
if (failed(inputBuffer))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Create a new buffer for the output tensor.
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
// Create a new buffer for the accumulation buffer.
|
|
||||||
// To do this, create a new allocation operation. Size must be axbx1x1,
|
|
||||||
// where axbxcxd is the size of the output tensor. Since the shape is
|
|
||||||
// different, we can't immediately use createEmptyFromType, we first need to
|
|
||||||
// create the shape of the accumulation buffer.
|
|
||||||
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
|
|
||||||
|
|
||||||
// Set the last two dimensions to 1.
|
|
||||||
accumShape[accumShape.size() - 1] = 1;
|
|
||||||
accumShape[accumShape.size() - 2] = 1;
|
|
||||||
|
|
||||||
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
|
|
||||||
|
|
||||||
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
// Bufferize the operation.
|
|
||||||
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
|
|
||||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
|
||||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
|
||||||
|
|
||||||
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
weightIndices,
|
|
||||||
xKernelPositions,
|
|
||||||
yKernelPositions,
|
|
||||||
*inputBuffer,
|
|
||||||
outputTensor,
|
|
||||||
accumBuffer);
|
|
||||||
|
|
||||||
// Replace the operation with the bufferized value.
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, bufferized);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
||||||
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
|
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
|
||||||
@@ -580,7 +502,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|||||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||||
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
||||||
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
||||||
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -247,11 +247,11 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||||
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
|
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutput().getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
|
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!sourceGetGlobal)
|
if (!sourceGetGlobal)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -268,8 +268,8 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t> perms;
|
SmallVector<int64_t> perms;
|
||||||
perms.reserve(transposeOp.getPerms().size());
|
perms.reserve(transposeOp.getPermutation().size());
|
||||||
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
|
for (IntegerAttr attr : transposeOp.getPermutation().getAsRange<IntegerAttr>())
|
||||||
perms.push_back(attr.getInt());
|
perms.push_back(attr.getInt());
|
||||||
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
||||||
if (failed(transposedAttr))
|
if (failed(transposedAttr))
|
||||||
@@ -389,18 +389,18 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
|
||||||
if (!allocOp)
|
if (!allocOp)
|
||||||
return failure();
|
return failure();
|
||||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
if (!allocType || !allocType.hasStaticShape())
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||||
|
|
||||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||||
if (!moduleOp)
|
if (!moduleOp)
|
||||||
|
|||||||
@@ -89,10 +89,10 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
|||||||
|
|
||||||
auto status =
|
auto status =
|
||||||
rewriteSubviewCopyLikeOp(copyOp,
|
rewriteSubviewCopyLikeOp(copyOp,
|
||||||
copyOp.getDst(),
|
copyOp.getTarget(),
|
||||||
copyOp.getSrc(),
|
copyOp.getSource(),
|
||||||
copyOp.getDstOffset(),
|
copyOp.getTargetOffset(),
|
||||||
copyOp.getSrcOffset(),
|
copyOp.getSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](MemRefType resultType,
|
[&](MemRefType resultType,
|
||||||
@@ -114,7 +114,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
|||||||
if (failed(status))
|
if (failed(status))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -125,10 +125,10 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
|||||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
auto status =
|
auto status =
|
||||||
rewriteSubviewCopyLikeOp(copyOp,
|
rewriteSubviewCopyLikeOp(copyOp,
|
||||||
copyOp.getDeviceDst(),
|
copyOp.getDeviceTarget(),
|
||||||
copyOp.getHostSrc(),
|
copyOp.getHostSource(),
|
||||||
copyOp.getDeviceDstOffset(),
|
copyOp.getDeviceTargetOffset(),
|
||||||
copyOp.getHostSrcOffset(),
|
copyOp.getHostSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](MemRefType resultType,
|
[&](MemRefType resultType,
|
||||||
@@ -150,7 +150,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
|||||||
if (failed(status))
|
if (failed(status))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user