big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
+1 -1
View File
@@ -228,7 +228,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
continue;
}
if (mlir::isa<onnx_mlir::pim::PimEmptyManyOp, mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
+88 -189
View File
@@ -42,6 +42,79 @@ static size_t getValueSizeInBytes(mlir::Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
/// A strided element-wise view over the dense initializer of a
/// memref.global weight. `denseAttr` holds the full backing data; `shape`,
/// `strides`, and `offset` (all in elements, not bytes) describe the window
/// into that data after composing any memref.subview ops on the way from the
/// global to the weight value.
struct DenseWeightView {
DenseElementsAttr denseAttr; // full initializer of the backing memref.global
SmallVector<int64_t> shape; // shape of the viewed window
SmallVector<int64_t> strides; // element strides into the row-major initializer data
int64_t offset = 0; // element index of the window's first element
};
/// Computes row-major (C-order) element strides for `shape`: the innermost
/// dimension has stride 1 and every outer stride is the product of all inner
/// dimension sizes. An empty shape yields an empty stride vector.
static SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  int64_t runningProduct = 1;
  // Walk dimensions innermost to outermost, accumulating the running product
  // of inner extents as the stride of each dimension.
  for (size_t dim = shape.size(); dim-- > 0;) {
    strides[dim] = runningProduct;
    runningProduct *= shape[dim];
  }
  return strides;
}
/// Returns true iff every offset, size, and stride of `subview` is a
/// compile-time constant, i.e. none of them is the ShapedType dynamic
/// sentinel value.
static bool allStaticSubviewParts(memref::SubViewOp subview) {
  auto fullyStatic = [](auto&& values) {
    return llvm::none_of(values, [](int64_t value) { return ShapedType::isDynamic(value); });
  };
  return fullyStatic(subview.getStaticOffsets()) && fullyStatic(subview.getStaticSizes())
      && fullyStatic(subview.getStaticStrides());
}
/// Walks the defining-op chain of `weight` back to a memref.get_global and,
/// if the referenced global has a dense initializer, returns a
/// DenseWeightView that folds every intervening (fully static)
/// memref.subview into a constant element offset plus strides into the
/// initializer's row-major data. memref.cast ops are looked through.
/// Fails when the chain contains any other producer, when a subview has
/// dynamic offsets/sizes/strides, or when no dense initializer is found.
static FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
// Trace producers of `weight` until the get_global is found, collecting
// subviews along the way.
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure(); // block argument — no traceable producer
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
// Only fully static subviews can be folded into constant offset/strides.
if (!allStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure(); // unsupported producer op in the chain
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
// Start from the full global: its shape with plain row-major strides.
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStridesForShape(view.shape);
// Subviews were collected weight-side first, so iterate in reverse to apply
// them global-side first, composing each one's offsets/strides with the
// strides accumulated so far.
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
@@ -97,11 +170,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult());
});
funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) {
if (!emptyManyOp->getParentOfType<pim::PimCoreOp>() && !emptyManyOp->getParentOfType<pim::PimCoreBatchOp>())
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory();
@@ -111,10 +179,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
op->walk([&](pim::PimEmptyManyOp emptyManyOp) {
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory();
}
@@ -369,13 +433,6 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
/// Emits one "recv" communication instruction per (output buffer, source
/// core id) pair of the receive-many op. Each transfer covers the full byte
/// size of its output buffer, at the buffer's statically resolved address.
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
const StaticValueKnowledge& knowledge) const {
for (auto [outputBuffer, sourceCoreId] :
llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
@@ -388,11 +445,6 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
/// Emits one "send" communication instruction per (input, target core id)
/// pair of the send-many op, each transferring the full byte size of its
/// input value from the input's statically resolved address.
void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const {
for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds()))
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
@@ -400,20 +452,6 @@ void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const St
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
/// Emits one local-move ("lmv") copy per output buffer: row i of the rank-2
/// input is copied into output buffer i (at offset 0 of that buffer).
/// Requires a statically shaped rank-2 input.
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
const StaticValueKnowledge& knowledge) const {
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
// Byte size of one input row: column count times element width in bytes.
size_t elementSize = inputType.getElementTypeBitWidth() / 8;
size_t rowSizeInBytes = static_cast<size_t>(inputType.getDimSize(1)) * elementSize;
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
// Source offset advances by one row per output buffer.
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
emitMemCopyOp(
"lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
@@ -742,23 +780,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto sendManyBatchOp = dyn_cast<pim::PimSendManyBatchOp>(op)) {
SmallVector<int32_t> laneTargetCoreIds;
laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, sendManyBatchOp.getInputs().size()))
laneTargetCoreIds.push_back(
sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedInputs;
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
for (mlir::Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyOp::create(
builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
@@ -771,29 +792,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto receiveManyBatchOp = dyn_cast<pim::PimReceiveManyBatchOp>(op)) {
SmallVector<int32_t> laneSourceCoreIds;
laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, receiveManyBatchOp.getOutputs().size()))
laneSourceCoreIds.push_back(
receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedOutputBuffers;
mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size());
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] :
llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
mapper.map(originalOutput, scalarOutput);
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
@@ -912,18 +910,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
@@ -954,8 +946,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else if (isa<pim::PimEmptyManyOp>(op))
return success();
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
@@ -967,84 +957,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Write crossbar weight matrices as padded binary files for a single core.
/// Each weight of `coreOp` becomes `crossbar_<index>.bin` under
/// `coreWeightsDirPath`, zero-padded to the configured crossbar size in both
/// dimensions; `xbarsPerGroup` records the crossbar index used per weight.
/// Weights whose initializer cannot be resolved are skipped with a warning.
/// Returns InvalidOutputFileAccess on file-open failure, else CompilerSuccess.
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
pim::PimCoreOp coreOp,
StringRef coreWeightsDirPath,
json::Array& xbarsPerGroup) {
int64_t xbarSize = crossbarSize.getValue();
std::error_code errorCode;
size_t weightIndex = 0;
for (auto weight : coreOp.getWeights()) {
xbarsPerGroup.push_back(weightIndex);
// Resolve the weight's dense initializer: it must come straight from a
// memref.get_global whose memref.global has a dense initial value.
// Each failure mode only warns and skips this weight.
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str();
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// Emit a full xbarSize x xbarSize matrix: real data inside the weight's
// numRows x numCols corner, zero padding elsewhere. Elements are written
// as their raw bit patterns, truncated to elementByteWidth bytes
// (NOTE(review): assumes little-endian host — confirm).
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
weightIndex++;
}
return CompilerSuccess;
}
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
@@ -1079,45 +991,31 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(!getGlobalOp && "Weight is not from a memref.get_global");
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
assert(!globalOp && "Could not find memref.global");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
assert(!initialValue && "memref.global has no initial value");
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
assert(!denseAttr && "memref.global initial value is not dense");
}
if (mapGlobalOpToFileName.contains(globalOp)) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreId].insert(weightToFile);
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
@@ -1132,8 +1030,8 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
@@ -1144,6 +1042,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
-5
View File
@@ -24,8 +24,6 @@ class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
@@ -117,12 +115,9 @@ public:
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
@@ -116,7 +116,7 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (int32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
return coreIds;
}
@@ -150,40 +150,33 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
rewriter.setInsertionPoint(sendManyOp);
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
PimSendManyOp::create(
rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs());
rewriter.eraseOp(sendManyOp);
for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) {
PimSendOp::create(rewriter,
sendManyOp.getLoc(),
input,
getTensorSizeInBytesAttr(rewriter, input),
rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId)));
}
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
SmallVector<Type> tensorTypes;
tensorTypes.reserve(outputTypes.size());
for (Type outputType : outputTypes)
tensorTypes.push_back(outputType);
auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, TypeRange(tensorTypes));
return SmallVector<Value>(emptyMany.getOutputs().begin(), emptyMany.getOutputs().end());
rewriter.eraseOp(sendManyOp);
}
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
rewriter.setInsertionPoint(receiveManyOp);
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
SmallVector<Value> outputBuffers =
createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
auto receiveMany = PimReceiveManyOp::create(rewriter,
SmallVector<Value> replacements;
replacements.reserve(receiveManyOp.getNumResults());
for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
auto outputType = cast<ShapedType>(output.getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType).getResult();
replacements.push_back(
PimReceiveOp::create(rewriter,
receiveManyOp.getLoc(),
receiveManyOp.getResultTypes(),
ValueRange(outputBuffers),
rewriter.getDenseI32ArrayAttr(sourceCoreIds));
rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs());
output.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, output),
rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId)))
.getOutput());
}
rewriter.replaceOp(receiveManyOp, replacements);
}
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
@@ -198,8 +191,17 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
for (Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyBatchOp::create(
rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) {
SmallVector<int32_t> laneTargetCoreIds;
laneTargetCoreIds.reserve(laneCount);
for (int32_t lane = 0; lane < laneCount; ++lane)
laneTargetCoreIds.push_back(targetCoreIds[valueIndex * laneCount + lane]);
pim::PimSendBatchOp::create(rewriter,
sendManyBatchOp.getLoc(),
input,
getTensorSizeInBytesAttr(rewriter, input),
rewriter.getDenseI32ArrayAttr(laneTargetCoreIds));
}
}
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
@@ -210,29 +212,44 @@ static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp
sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
SmallVector<Value> outputBuffers =
createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes());
auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter,
for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
auto outputType = cast<ShapedType>(output.getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType).getResult();
SmallVector<int32_t> laneSourceCoreIds;
laneSourceCoreIds.reserve(laneCount);
for (int32_t lane = 0; lane < laneCount; ++lane)
laneSourceCoreIds.push_back(sourceCoreIds[valueIndex * laneCount + lane]);
auto received = pim::PimReceiveBatchOp::create(rewriter,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp.getResultTypes(),
ValueRange(outputBuffers),
rewriter.getDenseI32ArrayAttr(sourceCoreIds));
for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs()))
output.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, output),
rewriter.getDenseI32ArrayAttr(laneSourceCoreIds))
.getOutput();
mapper.map(output, received);
}
}
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
rewriter.setInsertionPoint(extractRowsOp);
SmallVector<Value> outputBuffers =
createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes());
auto extractRows = pim::PimExtractRowsOp::create(rewriter,
extractRowsOp.getLoc(),
extractRowsOp.getResultTypes(),
extractRowsOp.getInput(),
ValueRange(outputBuffers));
rewriter.replaceOp(extractRowsOp, extractRows.getOutputs());
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(extractRowsOp.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(
rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(extractRowsOp, replacements);
}
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
@@ -258,14 +275,26 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
for (auto mapOp : mapOps) {
Block& body = mapOp.getBody().front();
rewriter.setInsertionPoint(mapOp);
auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs());
rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin());
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<pim::PimYieldOp>(yieldOp, yieldOp.getOutputs());
rewriter.replaceOp(mapOp, pimMap.getOutputs());
SmallVector<Value> replacements;
replacements.reserve(mapOp.getInputs().size());
rewriter.setInsertionPoint(mapOp);
for (Value input : mapOp.getInputs()) {
IRMapping mapping;
mapping.map(body.getArgument(0), input);
for (Operation& bodyOp : body.without_terminator()) {
Operation* cloned = rewriter.clone(bodyOp, mapping);
for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
mapping.map(originalResult, clonedResult);
rewriter.setInsertionPointAfter(cloned);
}
replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
}
rewriter.replaceOp(mapOp, replacements);
}
}
@@ -295,7 +324,7 @@ static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigne
}
static Value createPackedExtractRowsSlice(
pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
@@ -332,14 +361,17 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
if (!getContiguousOpResults(values, owner, startIndex))
return {};
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
return {};
}
static Value createPackedReceiveTensor(
pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
static Value createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp,
unsigned startIndex,
unsigned count,
IRRewriter& rewriter,
Location loc) {
auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0)
return {};
@@ -351,15 +383,15 @@ static Value createPackedReceiveTensor(
sourceCoreIds.reserve(count);
ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
for (unsigned index = 0; index < count; ++index)
sourceCoreIds.push_back(allSourceCoreIds[startIndex + index]);
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + index]));
return pim::PimReceiveTensorOp::create(
rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
}
static Value
createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
static Value createPackedMapTensor(
spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
if (!packedInput)
return {};
@@ -416,7 +448,7 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count,
rewriter.setInsertionPointAfter(cloned);
}
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));
int64_t outputRowsPerValue = outputType.getDimSize(0);
@@ -446,9 +478,9 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count,
return loop.getResult(0);
}
static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<pim::PimSendManyOp> sendManyOps;
funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
for (auto sendManyOp : sendManyOps) {
if (sendManyOp.getInputs().empty())
continue;
@@ -458,12 +490,17 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
if (!packedInput)
continue;
pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr());
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(
rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendManyOp);
}
SmallVector<pim::PimConcatOp> concatOps;
funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); });
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
for (auto concatOp : concatOps) {
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
continue;
@@ -494,11 +531,11 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
unsigned count = endIndex - index;
Value packedInput;
if (auto mapOp = dyn_cast<pim::PimMapOp>(owner))
if (auto mapOp = dyn_cast<spatial::SpatMapOp>(owner))
packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(owner))
else if (auto receiveManyOp = dyn_cast<spatial::SpatChannelReceiveManyOp>(owner))
packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
if (packedInput) {
@@ -516,12 +553,14 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
if (!changed)
continue;
auto newConcat = pim::PimConcatOp::create(rewriter,
auto newConcat = pim::PimConcatOp::create(
rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
concatOp.getOutputBuffer());
createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
@@ -533,10 +572,9 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
if (op->use_empty())
rewriter.eraseOp(op);
};
eraseUnusedOps(pim::PimMapOp {});
eraseUnusedOps(pim::PimReceiveManyOp {});
eraseUnusedOps(pim::PimExtractRowsOp {});
eraseUnusedOps(pim::PimEmptyManyOp {});
eraseUnusedOps(spatial::SpatMapOp {});
eraseUnusedOps(spatial::SpatChannelReceiveManyOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
@@ -617,6 +655,7 @@ struct ConcatReturnUseInfo {
size_t returnIndex;
SmallVector<int64_t> sliceOffsets;
SmallVector<int64_t> concatShape;
SmallVector<Operation*> concatChain;
SmallVector<Operation*> helperChain;
};
@@ -669,6 +708,8 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation* op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getOutput();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
@@ -676,6 +717,8 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getAxis();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
@@ -683,11 +726,14 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatOperands = [](Operation* op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getInputs();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
if (rangeLength(uses) != 1
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
@@ -696,10 +742,12 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
SmallVector<Operation*> concatChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isa<tensor::ConcatOp, pim::PimConcatOp>(currentUser)) {
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
concatChain.push_back(currentUser);
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
@@ -749,6 +797,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
currentValue.getUses().begin()->getOperandNumber(),
std::move(sliceOffsets),
std::move(concatShape),
std::move(concatChain),
std::move(helperChain),
};
}
@@ -918,11 +967,6 @@ void SpatialToPimPass::runOnOperation() {
return;
}
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); });
for (auto concatOp : concatOps)
lowerConcat(concatOp, rewriter);
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter);
@@ -933,6 +977,7 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeBatchOp(computeBatchOp, rewriter);
}
compactSpatialTensorGroups(funcOp, rewriter);
lowerMapOps(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
@@ -1036,6 +1081,8 @@ void SpatialToPimPass::runOnOperation() {
assert(false && "tracked op removal reached a cycle or missed dependency");
}
compactSpatialTensorGroups(funcOp, rewriter);
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
for (auto concatOp : remainingConcatOps)
@@ -1066,8 +1113,6 @@ void SpatialToPimPass::runOnOperation() {
for (auto extractRowsOp : remainingExtractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
compactPimTensorGroups(funcOp, rewriter);
// Dump to file for debug
bool hasSpatialOps = false;
moduleOp.walk([&](Operation* op) {
@@ -1170,6 +1215,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
@@ -1481,13 +1528,15 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
Value currentReturnValue = returnValue;
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
outputTensors.push_back(
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
}
else {
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(currentReturnValue.getType());
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputName = "output_" + std::to_string(index);
@@ -1565,7 +1614,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isChannelUseChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
@@ -1593,6 +1642,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
-126
View File
@@ -71,38 +71,6 @@ def PimYieldOp : PimOp<"yield", [Terminator]> {
let hasCustomAssemblyFormat = 1;
}
// Elementwise "map" over a variadic set of tensors: the single-block body
// region is applied once per input, producing one output per input.
// All inputs share one type and all outputs share one type (checked by the
// C++ verifier, hence hasVerifier below).
def PimMapOp : PimOp<"map", [SingleBlock]> {
let summary = "Apply the same lane-local region to many independent tensors";
let arguments = (ins
Variadic<PimTensor>:$inputs
);
let results = (outs
Variadic<PimTensor>:$outputs
);
// Exactly one block: the per-lane computation, terminated by pim.yield.
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Tensor Utilities
//===----------------------------------------------------------------------===//
// Materializes N identical, uninitialized tensors in one op. All results
// must have the same ranked shaped type (checked by the C++ verifier).
def PimEmptyManyOp : PimOp<"empty_many", []> {
let summary = "Create many identical empty tensors";
let results = (outs
Variadic<PimTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -121,18 +89,6 @@ def PimSendOp : PimOp<"send", []> {
}];
}
// Point-to-point send of several tensors: input i goes to targetCoreIds[i],
// so the attribute length must match the number of inputs (C++ verifier).
def PimSendManyOp : PimOp<"send_many", []> {
let summary = "Send multiple tensors to target cores";
let arguments = (ins
DenseI32ArrayAttr:$targetCoreIds,
Variadic<PimTensor>:$inputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorOp : PimOp<"send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor to target cores";
@@ -157,18 +113,6 @@ def PimSendBatchOp : PimOp<"send_batch", []> {
let hasCustomAssemblyFormat = 1;
}
// Batched variant of send_many: must be nested in pim.core_batch, and the
// targetCoreIds length equals inputs * laneCount of the parent (C++ verifier).
def PimSendManyBatchOp : PimOp<"send_many_batch", []> {
let summary = "Send multiple per-lane tensors to target cores from a batched core";
let arguments = (ins
DenseI32ArrayAttr:$targetCoreIds,
Variadic<PimTensor>:$inputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
@@ -193,28 +137,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
}
// Destination-style receive: output i is received from sourceCoreIds[i] into
// outputBuffers[i]; buffer and result types must match pairwise (C++ verifier).
def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
let summary = "Receive multiple tensors from source cores";
let arguments = (ins
Variadic<PimTensor>:$outputBuffers,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
Variadic<PimTensor>:$outputs
);
// DPS hook: the output buffers are the destination operands.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBuffersMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks from source cores into one tensor";
@@ -259,28 +181,6 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let hasCustomAssemblyFormat = 1;
}
// Batched variant of receive_many: must be nested in pim.core_batch, and
// sourceCoreIds length equals outputs * laneCount of the parent (C++ verifier).
def PimReceiveManyBatchOp : PimOp<"receive_many_batch", [DestinationStyleOpInterface]> {
let summary = "Receive multiple per-lane tensors from source cores into a batched core";
let arguments = (ins
Variadic<PimTensor>:$outputBuffers,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
Variadic<PimTensor>:$outputs
);
// DPS hook: the output buffers are the destination operands.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBuffersMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
@@ -385,32 +285,6 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
}];
}
//===----------------------------------------------------------------------===//
// Tensor utilities
//===----------------------------------------------------------------------===//
// Splits a rank-2 tensor into one 1xN tensor per row, written into the
// supplied destination buffers (one per row; shapes checked by C++ verifier).
def PimExtractRowsOp : PimOp<"extract_rows", [DestinationStyleOpInterface]> {
let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
let arguments = (ins
PimTensor:$input,
Variadic<PimTensor>:$outputBuffers
);
let results = (outs
Variadic<PimTensor>:$outputs
);
// DPS hook: the output buffers are the destination operands.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBuffersMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimConcatOp : PimOp<"concat", [DestinationStyleOpInterface]> {
let summary = "Concatenate tensors";
-230
View File
@@ -147,69 +147,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
// Custom assembly printer for pim.map:
//   (arg bindings) attr-dict : inputType -> outputType { body }
// Only the first input/output type is printed because the verifier enforces
// that all inputs (and all outputs) share one type.
void PimMapOp::print(OpAsmPrinter& printer) {
printer << " ";
// Print the body's block arguments bound to the variadic inputs.
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInputs().front().getType());
printer << " -> ";
printer.printType(getOutputs().front().getType());
printer << " ";
// Entry block args were already shown by the bindings above.
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
// Custom assembly parser for pim.map; inverse of PimMapOp::print.
// Parses the argument bindings, the optional attr-dict, one input type and
// one output type, then replicates those types across all operands/results.
ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
Type inputType;
Type outputType;
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
if (inputs.empty())
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
// The single printed type is shared by every input / every output.
SmallVector<Type> inputTypes(inputs.size(), inputType);
SmallVector<Type> outputTypes(inputs.size(), outputType);
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
// Give the region arguments their types before parsing the body.
applyArgumentTypes(inputTypes, regionArgs);
Region* body = result.addRegion();
return parser.parseRegion(*body, regionArgs);
}
// Custom assembly printer for pim.empty_many:
//   attr-dict : type x count
// All results share one type, so only the first result's type plus the
// result count are printed.
// Fix: the previous form printed " x" directly followed by the count (e.g.
// "x3"), which the MLIR lexer reads as a single bare identifier; the paired
// parser calls parseKeyword("x") then parseInteger and therefore could not
// re-parse the printed output. Separate the keyword and the integer with a
// space so the textual form round-trips.
void PimEmptyManyOp::print(OpAsmPrinter& printer) {
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getOutputs().front().getType());
printer << " x " << getOutputs().size();
}
// Custom assembly parser for pim.empty_many; inverse of the printer.
// Reads one type, the 'x' keyword and a positive result count, then
// replicates the type across that many results.
ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
int64_t resultCount = 0;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)
|| parser.parseKeyword("x") || parser.parseInteger(resultCount))
return failure();
// Zero or negative counts would create an op with no/ill-formed results.
if (resultCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive");
SmallVector<Type> resultTypes(resultCount, outputType);
result.addTypes(resultTypes);
return success();
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -237,36 +174,6 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperand(input, inputType, result.operands);
}
// Custom assembly printer for pim.send_many:
//   inputs to [coreIds] attr-dict : types
// targetCoreIds is printed positionally, so it is elided from the attr-dict.
void PimSendManyOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
// Custom assembly parser for pim.send_many; inverse of the printer.
// The core id list may come either positionally ("to [...]") or inside the
// attr-dict, but not both.
ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
// Reject duplicate specification of the same attribute.
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -294,72 +201,6 @@ ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result)
return parser.resolveOperand(input, inputType, result.operands);
}
// Custom assembly printer for pim.send_many_batch; same textual form as
// pim.send_many: inputs to [coreIds] attr-dict : types.
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
// Custom assembly parser for pim.send_many_batch; mirrors
// PimSendManyOp::parse (core ids positionally or via attr-dict, not both).
ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
// Custom assembly printer for pim.receive_many:
//   from [coreIds] into (buffers) attr-dict : resultTypes
// sourceCoreIds is printed positionally, so it is elided from the attr-dict.
void PimReceiveManyOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom assembly parser for pim.receive_many; inverse of the printer.
// The printed types serve double duty: they type both the destination
// buffers (operands) and the results.
ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
SmallVector<Type> outputTypes;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
// Reject duplicate specification of the same attribute.
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
@@ -434,77 +275,6 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result
return success();
}
// Custom assembly printer for pim.receive_many_batch; same textual form as
// pim.receive_many: from [coreIds] into (buffers) attr-dict : types.
void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom assembly parser for pim.receive_many_batch; mirrors
// PimReceiveManyOp::parse (buffer types are reused as result types).
ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
SmallVector<Type> outputTypes;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
// Custom assembly printer for pim.extract_rows:
//   input into (buffers) attr-dict : inputType -> outputTypes
void PimExtractRowsOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInput().getType());
printer << " -> ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom assembly parser for pim.extract_rows; inverse of the printer.
// The output types after '->' type both the destination buffers and results.
ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
Type inputType;
SmallVector<Type> outputTypes;
if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
// Operand order matches the ODS declaration: input first, then buffers.
if (parser.resolveOperand(input, inputType, result.operands)
|| parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());
-180
View File
@@ -13,12 +13,6 @@ namespace pim {
namespace {
// Checks that a *_many communication op carries exactly one core id per
// transferred value.
static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
if (coreIds.size() != valueCount)
return op->emitError("core id metadata length must match the number of values");
return success();
}
// True when both types are tensors or both are memrefs — i.e. they use the
// same shaped container kind (mixing tensor and memref is rejected).
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
}
@@ -33,28 +27,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
return success();
}
// Checks that all values of a *_many op are type-uniform: either exactly the
// same type, or (for shaped types) the same container kind, shape and
// element type. `kind` names the op in diagnostics.
static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) {
if (types.empty())
return op->emitError() << kind << " must carry at least one value";
Type firstType = types.front();
auto firstShapedType = dyn_cast<ShapedType>(firstType);
bool firstIsTensor = isa<RankedTensorType>(firstType);
bool firstIsMemRef = isa<MemRefType>(firstType);
for (Type type : types.drop_front())
if (type != firstType) {
// Types differ exactly; fall back to structural comparison for
// shaped types (e.g. memrefs differing only in layout).
auto shapedType = dyn_cast<ShapedType>(type);
if (!firstShapedType || !shapedType)
return op->emitError() << kind << " values must all have the same type";
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
return op->emitError() << kind << " values must all use the same shaped container kind";
if (firstShapedType.getElementType() != shapedType.getElementType()
|| firstShapedType.getShape() != shapedType.getShape())
return op->emitError() << kind << " values must all have the same shape and element type";
}
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
@@ -74,109 +46,12 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
return success();
}
// Returns the laneCount of the enclosing pim.core_batch, or failure when the
// op is not nested inside one.
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return failure();
return coreBatchOp.getLaneCount();
}
// Batched variant of verifyManyCommunicationSizes: each of the valueCount
// values is transferred once per lane, so the core id list must have
// valueCount * laneCount entries. Requires a pim.core_batch ancestor.
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside pim.core_batch");
if (coreIds.size() != valueCount * static_cast<size_t>(*laneCount))
return op->emitError("core id metadata length must match the number of values times parent laneCount");
return success();
}
} // namespace
// Verifier for pim.empty_many: at least one output, first output is a ranked
// shaped type, and every output has exactly the same type.
LogicalResult PimEmptyManyOp::verify() {
if (getOutputs().empty())
return emitError("must produce at least one output");
Type firstType = getOutputs().front().getType();
auto firstShapedType = dyn_cast<ShapedType>(firstType);
if (!firstShapedType || !firstShapedType.hasRank())
return emitError("outputs must all be ranked shaped types");
// Remaining outputs only need to equal the first (already validated) type.
for (Value output : getOutputs().drop_front())
if (output.getType() != firstType)
return emitError("outputs must all have the same type");
return success();
}
// Verifier for pim.map: one output per input, uniform input type, uniform
// output type, and a body with a single block argument matching the input
// type whose pim.yield produces exactly one value matching the output type.
LogicalResult PimMapOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
if (getOutputs().size() != getInputs().size())
return emitError("number of outputs must match number of inputs");
Type inputType = getInputs().front().getType();
for (Value input : getInputs().drop_front())
if (input.getType() != inputType)
return emitError("all inputs must have the same type");
Type outputType = getOutputs().front().getType();
for (Value output : getOutputs().drop_front())
if (output.getType() != outputType)
return emitError("all outputs must have the same type");
Block& block = getBody().front();
if (block.getNumArguments() != 1)
return emitError("body must have exactly one block argument");
// NOTE(review): verifyCompatibleShapedTypes already takes a diagnostic
// message, and emitError here repeats it — this may double-report; confirm
// against the helper's implementation.
if (failed(verifyCompatibleShapedTypes(
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
return emitError("body block argument type must match input type");
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("body must terminate with pim.yield");
if (yieldOp.getNumOperands() != 1)
return emitError("body yield must produce exactly one value");
if (failed(verifyCompatibleShapedTypes(
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
return emitError("body yield type must match output type");
return success();
}
// Verifier for pim.send_many: one target core per input, and all inputs
// type-uniform.
LogicalResult PimSendManyOp::verify() {
if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
}
// Verifier for pim.send_tensor: delegates the chunk/core-id checks to the
// shared tensor-communication helper.
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
// Verifier for pim.send_many_batch: core ids sized by inputs * laneCount of
// the enclosing pim.core_batch, and all inputs type-uniform.
LogicalResult PimSendManyBatchOp::verify() {
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch");
}
// Verifier for pim.receive_many: buffer/output counts match, one source core
// per output, buffers and results each type-uniform, and each buffer's type
// equals its paired result's type.
LogicalResult PimReceiveManyOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many")))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many")))
return failure();
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
if (outputBuffer.getType() != output.getType())
return emitError("output buffers and outputs must have matching types");
return success();
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
@@ -185,61 +60,6 @@ LogicalResult PimReceiveTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
// Verifier for pim.receive_many_batch: like PimReceiveManyOp::verify but the
// source core id count is outputs * laneCount of the enclosing core_batch.
LogicalResult PimReceiveManyBatchOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch")))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch")))
return failure();
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
if (outputBuffer.getType() != output.getType())
return emitError("output buffers and outputs must have matching types");
return success();
}
// Verifier for pim.extract_rows: the input must be rank-2; when the row count
// is static there must be one output per row; every output is a rank-2,
// single-row value of the same container kind and element type as the input,
// with a matching (static) column count, and each buffer matches its result.
LogicalResult PimExtractRowsOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
auto inputType = dyn_cast<ShapedType>(getInput().getType());
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
return emitError("input must be a rank-2 shaped type");
int64_t numRows = inputType.getShape()[0];
int64_t numCols = inputType.getShape()[1];
Type elementType = inputType.getElementType();
// Dynamic dimensions (negative sentinel) skip the static count checks.
if (numRows >= 0 && static_cast<int64_t>(getOutputs().size()) != numRows)
return emitError("number of outputs must match the number of input rows");
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) {
if (failed(verifyCompatibleShapedTypes(
getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match")))
return failure();
auto outputType = dyn_cast<ShapedType>(output.getType());
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
return emitError("outputs must all be rank-2 shaped types");
if (!haveSameShapedContainerKind(getInput().getType(), output.getType()))
return emitError("outputs must use the same shaped container kind as the input");
if (outputType.getElementType() != elementType)
return emitError("output element types must match input element type");
auto outputShape = outputType.getShape();
if (outputShape[0] != 1)
return emitError("each output must have exactly one row");
if (numCols >= 0 && outputShape[1] != numCols)
return emitError("output column count must match input column count");
}
return success();
}
LogicalResult PimConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
@@ -180,39 +180,6 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<Receive
}
};
// One-shot bufferization external model for pim.receive_many. Destination
// buffers are written (not read); only non-init operands count as reads.
struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveManyOpInterface, PimReceiveManyOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
// Rebuilds the op with every output buffer replaced by its bufferized
// (memref) counterpart; the new result types mirror the buffer types.
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveManyOp>(op);
SmallVector<Value> outputBuffers;
SmallVector<Type> resultTypes;
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
resultTypes.reserve(receiveOp.getOutputBuffers().size());
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
if (failed(outputBufferOpt))
return failure();
outputBuffers.push_back(*outputBufferOpt);
resultTypes.push_back(outputBufferOpt->getType());
}
auto newOp = PimReceiveManyOp::create(rewriter,
receiveOp.getLoc(),
TypeRange(resultTypes),
ValueRange(outputBuffers),
receiveOp.getSourceCoreIdsAttr());
rewriter.replaceOp(receiveOp, newOp.getOutputs());
return success();
}
};
struct ReceiveTensorOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -234,77 +201,6 @@ struct ReceiveTensorOpInterface
}
};
struct ReceiveManyBatchOpInterface
    : DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
  // Only pure-input operands count as memory reads; DPS init operands do not.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  // Bufferization mirrors ReceiveManyOpInterface: rebuild the op on the
  // bufferized output buffers and forward its results.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto receiveOp = cast<PimReceiveManyBatchOp>(op);
    const size_t bufferCount = receiveOp.getOutputBuffers().size();
    SmallVector<Value> bufferizedBuffers;
    SmallVector<Type> bufferizedTypes;
    bufferizedBuffers.reserve(bufferCount);
    bufferizedTypes.reserve(bufferCount);
    for (Value outputBuffer : receiveOp.getOutputBuffers()) {
      auto resolved = getBufferOrValue(rewriter, outputBuffer, options, state);
      if (failed(resolved))
        return failure();
      bufferizedTypes.push_back(resolved->getType());
      bufferizedBuffers.push_back(*resolved);
    }
    auto bufferizedOp = PimReceiveManyBatchOp::create(rewriter,
                                                      receiveOp.getLoc(),
                                                      TypeRange(bufferizedTypes),
                                                      ValueRange(bufferizedBuffers),
                                                      receiveOp.getSourceCoreIdsAttr());
    rewriter.replaceOp(receiveOp, bufferizedOp.getOutputs());
    return success();
  }
};
struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractRowsOpInterface, PimExtractRowsOp> {
  // Only pure-input operands count as memory reads; DPS init operands do not.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  // Bufferization: resolve the input and every output buffer to memrefs,
  // force the input into a contiguous layout, then recreate the op and
  // forward its results.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto extractRowsOp = cast<PimExtractRowsOp>(op);
    auto resolvedInput = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state);
    if (failed(resolvedInput))
      return failure();
    const size_t bufferCount = extractRowsOp.getOutputBuffers().size();
    SmallVector<Value> bufferizedBuffers;
    SmallVector<Type> bufferizedTypes;
    bufferizedBuffers.reserve(bufferCount);
    bufferizedTypes.reserve(bufferCount);
    for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
      auto resolved = getBufferOrValue(rewriter, outputBuffer, options, state);
      if (failed(resolved))
        return failure();
      bufferizedTypes.push_back(resolved->getType());
      bufferizedBuffers.push_back(*resolved);
    }
    auto bufferizedOp = PimExtractRowsOp::create(rewriter,
                                                 extractRowsOp.getLoc(),
                                                 TypeRange(bufferizedTypes),
                                                 materializeContiguousMemRef(*resolvedInput, op->getLoc(), rewriter),
                                                 ValueRange(bufferizedBuffers));
    rewriter.replaceOp(extractRowsOp, bufferizedOp.getOutputs());
    return success();
  }
};
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -334,31 +230,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
}
};
struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOpInterface, PimEmptyManyOp> {
  // Every result is a fresh allocation...
  bool bufferizesToAllocation(Operation* op, Value value) const { return true; }
  // ...whose contents are left uninitialized (no write happens here).
  bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const {
    return false;
  }
  // Bufferization: map each tensor result to an identity-layout memref of the
  // same shape/element type and recreate the op with those result types.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto emptyManyOp = cast<PimEmptyManyOp>(op);
    SmallVector<Type> memrefTypes;
    memrefTypes.reserve(emptyManyOp.getOutputs().size());
    for (Value result : emptyManyOp.getOutputs()) {
      auto shaped = cast<ShapedType>(result.getType());
      memrefTypes.push_back(MemRefType::get(shaped.getShape(), shaped.getElementType()));
    }
    replaceOpWithNewBufferizedOp<PimEmptyManyOp>(rewriter, emptyManyOp, TypeRange(memrefTypes));
    return success();
  }
};
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -383,7 +254,7 @@ struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensor
}
};
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
@@ -392,75 +263,93 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
return {};
}
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto mapOp = cast<PimMapOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|| mapOp.getInputs().empty())
return {};
return {
{&mapOp->getOpOperand(0), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto mapOp = cast<PimMapOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|| mapOp.getInputs().empty())
BufferizationState& state) const {
auto sendOp = cast<PimSendOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
if (inputType)
return inputType;
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdAttr());
return success();
}
};
auto shapedType = cast<ShapedType>(mapOp.getInputs().front().getType());
return BufferLikeType(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto mapOp = cast<PimMapOp>(op);
SmallVector<Value> inputs;
SmallVector<Type> resultTypes;
inputs.reserve(mapOp.getInputs().size());
resultTypes.reserve(mapOp.getOutputs().size());
for (Value input : mapOp.getInputs()) {
if (isa<TensorType>(input.getType())) {
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
auto sendOp = cast<PimSendBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
inputs.push_back(*inputOpt);
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto coreOp = cast<PimCoreOp>(op);
bool alreadyBufferized =
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
if (alreadyBufferized)
return success();
SmallVector<Value> weights;
weights.reserve(coreOp.getWeights().size());
for (Value weight : coreOp.getWeights()) {
if (isa<TensorType>(weight.getType())) {
auto weightOpt = getBufferOrValue(rewriter, weight, options, state);
if (failed(weightOpt))
return failure();
weights.push_back(*weightOpt);
}
else {
inputs.push_back(input);
weights.push_back(weight);
}
}
for (Value output : mapOp.getOutputs()) {
auto shapedType = cast<ShapedType>(output.getType());
resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
}
rewriter.setInsertionPoint(mapOp);
auto newOp = PimMapOp::create(rewriter, mapOp.getLoc(), TypeRange(resultTypes), ValueRange(inputs));
rewriter.inlineRegionBefore(mapOp.getBody(), newOp.getBody(), newOp.getBody().begin());
rewriter.setInsertionPoint(coreOp);
auto newOp = PimCoreOp::create(rewriter, coreOp.getLoc(), ValueRange(weights), coreOp.getCoreIdAttr());
rewriter.inlineRegionBefore(coreOp.getBody(), newOp.getBody(), newOp.getBody().begin());
for (Block& block : newOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
return failure();
rewriter.replaceOp(mapOp, newOp.getOutputs());
rewriter.eraseOp(coreOp);
return success();
}
};
@@ -730,16 +619,14 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimEmptyManyOp::attachInterface<EmptyManyOpInterface>(*ctx);
PimMapOp::attachInterface<MapOpInterface>(*ctx);
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
@@ -1,10 +1,9 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Threading.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -47,53 +46,18 @@ private:
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
// TODO: extract this per-core one-shot bufferization block into a dedicated helper function
{
auto funcOp = *getPimEntryFunc(moduleOp);
SmallVector<Operation*> coreOps;
funcOp->walk<WalkOrder::PreOrder>([&](Operation* op) {
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
coreOps.push_back(op);
});
MLIRContext* ctx = moduleOp.getContext();
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) {
// Again, allocate state LOCALLY per thread/function
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
if (isa<pim::PimCoreBatchOp>(coreOp))
options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; });
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
coreOp->emitError("Failed to bufferize PIM and Spatial ops");
return failure();
}
return success();
});
if (failed(result)) {
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
signalPassFailure();
}
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
if (llvm::isa_and_present<pim::PimCoreOp, pim::PimCoreBatchOp>(toTensorOp->getParentOp()))
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
});
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
options.opFilter.denyOperation([](Operation* op) {
return op->getParentOfType<pim::PimCoreOp>() || op->getParentOfType<pim::PimCoreBatchOp>();
});
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(bufferization::LayoutMapOption::IdentityLayoutMap);
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
return;
}
MLIRContext* ctx = moduleOp.getContext();
@@ -119,30 +83,6 @@ void PimBufferizationPass::runOnOperation() {
return;
}
// Remove toTensor operations: leave memrefs instead
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
toTensorOp.erase();
});
// Change main function return types from tensors to memrefs
func::FuncOp funcOp;
for (Operation& op : moduleOp.getBody()->getOperations())
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
break;
auto oldFuncType = funcOp.getFunctionType();
SmallVector<Type> newResults;
bool changed = false;
for (Type type : oldFuncType.getResults())
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
changed = true;
}
else
newResults.push_back(type);
if (changed)
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
@@ -1,13 +1,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -34,35 +33,6 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
// Unrolls every pim.map in `funcOp`: for each input, the map body is cloned
// in place of the op with the block argument substituted by that input, and
// the map's results are replaced by the per-input clones of the yielded value.
// The ops themselves are left in the IR (only the pim.map wrapper is removed).
static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first, then rewrite: erasing inside walk() would invalidate it.
  SmallVector<pim::PimMapOp> mapOps;
  funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); });
  for (auto mapOp : mapOps) {
    Block& body = mapOp.getBody().front();
    // NOTE(review): assumes the terminator yields exactly one value
    // (getOperand(0) below) — confirm against the PimYieldOp definition.
    auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
    SmallVector<Value> replacements;
    replacements.reserve(mapOp.getInputs().size());
    // All clones are inserted before the map op; setInsertionPointAfter below
    // keeps later clones ordered after earlier ones, still ahead of mapOp.
    rewriter.setInsertionPoint(mapOp);
    for (Value input : mapOp.getInputs()) {
      // Fresh mapping per input: the single block argument becomes this input.
      IRMapping mapping;
      mapping.map(body.getArgument(0), input);
      for (Operation& op : body.without_terminator()) {
        Operation* cloned = rewriter.clone(op, mapping);
        // NOTE(review): rewriter.clone(op, mapping) already records result
        // mappings, so this loop looks redundant — verify and consider removing.
        for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
          mapping.map(originalResult, clonedResult);
        rewriter.setInsertionPointAfter(cloned);
      }
      // The clone of the yielded value stands in for this input's map result;
      // lookupOrDefault also covers yields of values defined outside the body.
      replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
    }
    rewriter.replaceOp(mapOp, replacements);
  }
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
@@ -80,8 +50,6 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
if (funcOp.isExternal())
continue;
expandPimMapOps(funcOp, rewriter);
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
@@ -150,38 +118,11 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimExtractRowsOp, pim::PimConcatOp>(op))
if (isa<pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op)) {
auto inputType = dyn_cast<ShapedType>(extractRowsOp.getInput().getType());
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) {
extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input");
hasFailure = true;
continue;
}
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacementRows;
replacementRows.reserve(extractRowsOp.getOutputs().size());
for (auto rowIndex : llvm::seq<size_t>(0, extractRowsOp.getOutputs().size())) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacementRows.push_back(memref::SubViewOp::create(
rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows));
extractRowsOp->erase();
continue;
}
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;
+65 -11
View File
@@ -18,7 +18,6 @@ namespace {
static bool isAddressOnlyHostOp(Operation* op) {
return isa<arith::ConstantOp,
pim::PimEmptyManyOp,
memref::AllocOp,
memref::GetGlobalOp,
memref::SubViewOp,
@@ -37,12 +36,24 @@ static bool isBaseAddressableValue(Value value) {
Operation* defOp = value.getDefiningOp();
if (!defOp)
return false;
if (isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(defOp))
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
return true;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
value = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
value = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
value = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
value = expand.getSrc();
continue;
}
return false;
}
}
@@ -52,7 +63,38 @@ static bool isCodegenAddressableValue(Value value) {
if (failed(resolvedAddress))
return false;
return isa<BlockArgument>(resolvedAddress->base)
|| isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
// Returns true when `value` is a memref.get_global of a constant,
// dense-initialized global, possibly seen through a chain of fully-static
// memref.subview and memref.cast ops. Any other producer disqualifies it.
static bool isConstantGlobalView(Value value) {
  auto hasOnlyStaticParts = [](memref::SubViewOp subviewOp) {
    auto isStatic = [](int64_t dim) { return !ShapedType::isDynamic(dim); };
    return llvm::all_of(subviewOp.getStaticOffsets(), isStatic)
        && llvm::all_of(subviewOp.getStaticSizes(), isStatic)
        && llvm::all_of(subviewOp.getStaticStrides(), isStatic);
  };
  // Walk the def chain until we hit the get_global (or something we reject).
  Operation* defOp;
  while ((defOp = value.getDefiningOp())) {
    if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)) {
      auto moduleOp = getGlobalOp->getParentOfType<ModuleOp>();
      auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
      return globalOp && globalOp.getConstant() && globalOp.getInitialValue()
          && isa<DenseElementsAttr>(*globalOp.getInitialValue());
    }
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(defOp)) {
      // Dynamic offsets/sizes/strides cannot be folded into a constant view.
      if (!hasOnlyStaticParts(subviewOp))
        return false;
      value = subviewOp.getSource();
      continue;
    }
    if (auto castOp = dyn_cast<memref::CastOp>(defOp)) {
      value = castOp.getSource();
      continue;
    }
    return false;
  }
  // Block arguments (no defining op) are never constant globals.
  return false;
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
@@ -125,13 +167,17 @@ private:
bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
if (!getGlobalOp && !isConstantGlobalView(weight)) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as memref.get_global before JSON codegen";
<< " must be materialized as a constant memref.global or a static view of one before JSON "
"codegen";
hasFailure = true;
continue;
}
if (!getGlobalOp)
continue;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
@@ -185,7 +231,7 @@ private:
continue;
}
if (!isa<pim::PimEmptyManyOp, memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
@@ -197,7 +243,7 @@ private:
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource());
return verifyAddressOnlyBase(op, subviewOp.getSource());
if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource());
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
@@ -221,6 +267,14 @@ private:
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
return failure();
}
// Succeeds iff `source` ultimately resolves to addressable base storage;
// otherwise emits a diagnostic on `op` and fails.
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
  if (!isBaseAddressableValue(source)) {
    op->emitOpError("depends on a value that is not backed by addressable storage");
    return failure();
  }
  return success();
}
};
} // namespace
-10
View File
@@ -1,10 +0,0 @@
Rimuovere la gestione delle send e receive da spatialtopim (nuovo mergeNode)
Analisi DCP
Nuovo passo che inserisce le send e receive e gestisce gli input
Passo che fa il merge
Probabilmente questo romperà gli input e come venivano gestiti prima.