better MaterializeMergeSchedule.cpp with %lane indexed batch computes

support for tensors of index values
This commit is contained in:
NiccoloN
2026-05-22 21:52:28 +02:00
parent 495186503c
commit c77ffa9c56
20 changed files with 398 additions and 300 deletions
+1 -1
View File
@@ -264,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
return mlir::failure(); return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge); value = resolveAlias(subviewOp.getSource(), knowledge);
continue; continue;
} }
+25
View File
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements; return numElements;
} }
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
+11
View File
@@ -1,8 +1,13 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape); llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape); int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
+18 -28
View File
@@ -41,23 +41,10 @@ using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm; using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); size_t allocSize = getShapedTypeSizeInBytes(type);
MemEntry memEntry = {0, allocSize}; MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
@@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const { const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge); size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size(); size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
/ receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds())) for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize); emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
} }
@@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge); size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size(); size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
/ sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds())) for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
} }
@@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
int64_t axis = concatOp.getAxis(); int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge); size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
size_t outerCount = 1; size_t outerCount = 1;
@@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 1; instruction.r2OrImm = 1;
instruction.generic1 = 1; instruction.generic1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vrelu; instruction.opcode = pim_binary::Opcode::vrelu;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vtanh; instruction.opcode = pim_binary::Opcode::vtanh;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vsigm; instruction.opcode = pim_binary::Opcode::vsigm;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
instruction.opcode = pim_binary::Opcode::vsoftmax; instruction.opcode = pim_binary::Opcode::vsoftmax;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput())); instruction.generic3 =
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
auto srcType = cast<ShapedType>(transposeOp.getInput().getType()); auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape(); auto srcShape = srcType.getShape();
size_t rank = srcShape.size(); size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
size_t totalElements = srcType.getNumElements(); size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i]. // Read permutation. Destination dim i corresponds to source dim perm[i].
+1 -1
View File
@@ -208,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
int64_t numCols = shape[1]; int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
@@ -427,11 +427,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto inputArg = computeOp.getInputArgument(aHSliceId); auto inputArg = computeOp.getInputArgument(aHSliceId);
if (!weightArg || !inputArg) if (!weightArg || !inputArg)
return failure(); return failure();
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
gemmLoc,
currOutHSliceType,
*weightArg,
*inputArg));
} }
if (vmmOutputs.empty()) { if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
@@ -121,7 +121,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice, tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType, ShapedType destinationType,
IRMapping& mapper) { IRMapping& mapper) {
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8; int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides(destinationType.getRank(), 1); SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape(); ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim) for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue; return returnValue;
} }
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
} }
@@ -20,8 +20,6 @@ namespace onnx_mlir {
*/ */
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T> template <class T>
@@ -433,7 +433,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
markOpToRemove(op); markOpToRemove(op);
auto storedType = cast<ShapedType>(currentStoredValue.getType()); auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
if (auto storedOp = currentStoredValue.getDefiningOp()) if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp); rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
@@ -455,7 +455,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
rewriter.setInsertionPointAfterValue(storedValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter, emitHostCopy(rewriter,
@@ -471,7 +471,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
} }
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
for (Operation* concatOp : concatReturnUse->concatChain) for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(concatOp); markOpToRemove(concatOp);
@@ -325,9 +325,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType()); auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType(); Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat()) if (!hasByteSizedElementType(elementType))
return; return;
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = getElementTypeSizeInBytes(elementType);
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
return PimMemCopyOp::create(rewriter, return PimMemCopyOp::create(rewriter,
loc, loc,
@@ -1,9 +1,10 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType()); auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
return builder.getI32IntegerAttr(sizeInBytes); return builder.getI32IntegerAttr(sizeInBytes);
} }
@@ -9,6 +9,7 @@
#include <limits> #include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir; using namespace mlir;
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
} }
static bool isCandidateAllocType(MemRefType type) { static bool isCandidateAllocType(MemRefType type) {
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0; return type && type.hasStaticShape() && type.getLayout().isIdentity()
&& hasByteSizedElementType(type.getElementType());
} }
static uint64_t getTypeSizeBytes(MemRefType type) { static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
} }
static FailureOr<uint64_t> static FailureOr<uint64_t>
+15 -8
View File
@@ -34,7 +34,9 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
} // namespace } // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); } std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), idx);
}
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) { std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), getWeights().size() + idx); return getBatchBodyArgument(getBody(), getWeights().size() + idx);
@@ -74,11 +76,13 @@ SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Locat
resultTypes.insert(resultTypes.begin() + idx, type); resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs()); auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs()); newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes( setComputeOperandSegmentSizes(newCompute.getOperation(),
newCompute.getOperation(), static_cast<int32_t>(newCompute.getWeights().size()), static_cast<int32_t>(newCompute.getInputs().size())); static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end()); rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation()); rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute); return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
} }
@@ -110,7 +114,8 @@ std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
} }
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size(); unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size(); unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight}); getOperation()->insertOperands(idx, ValueRange {weight});
@@ -145,8 +150,9 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
auto newBatch = auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs()); SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs()); newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes( setComputeOperandSegmentSizes(newBatch.getOperation(),
newBatch.getOperation(), static_cast<int32_t>(newBatch.getWeights().size()), static_cast<int32_t>(newBatch.getInputs().size())); static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end()); rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) { if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch); rewriter.eraseOp(newBatch);
@@ -155,7 +161,8 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
auto blockArg = newBatch.getBody().front().insertArgument( auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc); 1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation()); rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch); return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
} }
@@ -4,7 +4,6 @@
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
@@ -131,7 +130,7 @@ struct MaterializerState {
DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance; DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance;
DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots; DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots;
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> producerDestClasses; DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues; DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues;
DenseMap<Value, Value> hostReplacements; DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps; DenseSet<Operation*> oldComputeOps;
@@ -574,6 +573,77 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened)); return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
} }
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
SmallVector<APInt, 8> elements;
elements.reserve(values.size());
for (int64_t value : values)
elements.push_back(APInt(64, value));
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
auto attr = DenseIntElementsAttr::get(type, elements);
return getOrCreateHostConstant(anchor, attr, type, state.constantFolder);
}
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
SmallVector<int64_t, 8> widened;
widened.reserve(values.size());
for (int32_t value : values)
widened.push_back(value);
return createIndexTensorConstant(state, anchor, ArrayRef<int64_t>(widened));
}
bool allEqual(ArrayRef<int64_t> values) {
assert(!values.empty() && "expected at least one value");
for (int64_t value : values.drop_front())
if (value != values.front())
return false;
return true;
}
bool allEqual(ArrayRef<int32_t> values) {
assert(!values.empty() && "expected at least one value");
for (int32_t value : values.drop_front())
if (value != values.front())
return false;
return true;
}
Value createLaneIndexedIndexValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<int64_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
if (allEqual(values))
return createIndexConstant(state, materializedClass.op, values.front());
auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected compute_batch lane argument");
Value table = createIndexTensorConstant(state, materializedClass.op, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
Value createLaneIndexedIndexValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<int32_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
if (allEqual(values))
return createIndexConstant(state, materializedClass.op, values.front());
auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected compute_batch lane argument");
Value table = createIndexTensorConstant(state, materializedClass.op, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
FailureOr<SmallVector<ComputeInstance, 8>> FailureOr<SmallVector<ComputeInstance, 8>>
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) { getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
SmallVector<ComputeInstance, 8> peers; SmallVector<ComputeInstance, 8> peers;
@@ -623,14 +693,12 @@ Value createOriginalLaneValue(MaterializerState& state,
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult(); return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
} }
SmallVector<APInt, 8> laneValues; SmallVector<int64_t, 8> laneValues;
laneValues.reserve(peers.size()); laneValues.reserve(peers.size());
for (const ComputeInstance& peer : peers) for (const ComputeInstance& peer : peers)
laneValues.push_back(APInt(64, peer.laneStart)); laneValues.push_back(peer.laneStart);
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType()); Value table = createIndexTensorConstant(state, materializedClass.op, laneValues);
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
} }
@@ -659,6 +727,12 @@ bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps)
return false; return false;
} }
void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) {
SmallVector<ClassId, 4>& destinations = state.producerDestClasses[key];
if (!llvm::is_contained(destinations, classId))
destinations.push_back(classId);
}
void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) { void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) {
SmallVector<OpOperand*> uses; SmallVector<OpOperand*> uses;
for (OpOperand& use : oldValue.getUses()) for (OpOperand& use : oldValue.getUses())
@@ -693,7 +767,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
if (sourceClass == targetClass) if (sourceClass == targetClass)
continue; continue;
state.producerDestClasses[producerKey].insert(targetClass); appendDestinationClass(state, producerKey, targetClass);
} }
} }
} }
@@ -714,29 +788,70 @@ bool haveSameDestinationClasses(MaterializerState& state, ArrayRef<ProducerKey>
return true; return true;
auto firstIt = state.producerDestClasses.find(keys.front()); auto firstIt = state.producerDestClasses.find(keys.front());
DenseSet<ClassId> empty; ArrayRef<ClassId> first = firstIt == state.producerDestClasses.end() ? ArrayRef<ClassId>() : firstIt->second;
const DenseSet<ClassId>& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second;
for (ProducerKey key : keys.drop_front()) { for (ProducerKey key : keys.drop_front()) {
auto it = state.producerDestClasses.find(key); auto it = state.producerDestClasses.find(key);
const DenseSet<ClassId>& current = it == state.producerDestClasses.end() ? empty : it->second; ArrayRef<ClassId> current = it == state.producerDestClasses.end() ? ArrayRef<ClassId>() : it->second;
if (first.size() != current.size()) if (first.size() != current.size())
return false; return false;
for (ClassId classId : first) for (auto [lhs, rhs] : llvm::zip(first, current))
if (!current.contains(classId)) if (lhs != rhs)
return false; return false;
} }
return true; return true;
} }
SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, ProducerKey key) { ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey key) {
SmallVector<ClassId, 4> destinations;
auto it = state.producerDestClasses.find(key); auto it = state.producerDestClasses.find(key);
if (it == state.producerDestClasses.end()) if (it == state.producerDestClasses.end())
return destinations; return {};
for (ClassId classId : it->second) return it->second;
destinations.push_back(classId); }
llvm::sort(destinations);
return destinations; void appendSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one send");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (sourceClass.isBatch) {
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
return;
}
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
}
}
Value appendScalarReceive(MaterializerState& state,
MaterializedClass& targetClass,
Type type,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class");
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, targetClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId);
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue)
.getOutput();
} }
Value appendReceive(MaterializerState& state, Value appendReceive(MaterializerState& state,
@@ -746,50 +861,169 @@ Value appendReceive(MaterializerState& state,
ArrayRef<int32_t> sourceCoreIds, ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds, ArrayRef<int32_t> targetCoreIds,
Location loc) { Location loc) {
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive");
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds);
if (targetClass.isBatch) { if (targetClass.isBatch) {
return SpatChannelReceiveBatchOp::create( Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc);
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc);
.getOutput(); Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc);
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput();
} }
if (channelIds.size() != 1) { assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time");
return SpatChannelReceiveTensorOp::create( return appendScalarReceive(
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
.getOutput();
}
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
.getOutput();
} }
Value appendHostReceive(MaterializerState& state, Value appendPackedScalarReceives(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& targetClass,
Type type, Type fragmentType,
ArrayRef<int64_t> channelIds, Type packedType,
ArrayRef<int32_t> sourceCoreIds, ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> targetCoreIds, ArrayRef<int32_t> sourceCoreIds,
Location loc) { ArrayRef<int32_t> targetCoreIds,
state.rewriter.setInsertionPointAfter(sourceClass.op); Location loc) {
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class");
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive");
if (sourceClass.isBatch) { SmallVector<Value, 8> fragments;
return SpatChannelReceiveTensorOp::create( fragments.reserve(channelIds.size());
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
.getOutput(); fragments.push_back(appendScalarReceive(
state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc));
} }
assert(channelIds.size() == 1 && "scalar host receive expects one channel"); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front()) Value packed = fragments.front();
.getOutput(); if (fragments.size() != 1)
packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult();
if (packed.getType() != packedType)
packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult();
return packed;
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = payload;
return success();
}
if (!sourceClass.isBatch && !targetClass.isBatch) {
int64_t channelId = state.nextChannelId++;
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
SmallVector<int64_t, 1> channelIds {channelId};
SmallVector<int32_t, 1> sourceCoreIds {sourceCpu};
SmallVector<int32_t, 1> targetCoreIds {targetCpu};
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (!sourceClass.isBatch && targetClass.isBatch) {
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(targetClass.cpus.size());
sourceCoreIds.reserve(targetClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
for (CpuId targetCpu : targetClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(sourceCpu);
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey)
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication because source lanes are not contiguous");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType))
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication for non-static ranked tensor payload");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(targetCpu);
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendPackedScalarReceives(
state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.availableValues[*packedKey][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
} }
LogicalResult LogicalResult
@@ -821,207 +1055,50 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
if (!payloadType || !payloadType.hasStaticShape()) if (!payloadType || !payloadType.hasStaticShape())
return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor");
auto laneArg = batch.getLaneArgument();
if (!laneArg)
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
auto outputArg = batch.getOutputArgument(resultIndex);
if (!outputArg)
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult, 4> offsets; SmallVector<OpFoldResult, 4> offsets;
SmallVector<OpFoldResult, 4> sizes; SmallVector<OpFoldResult, 4> sizes;
SmallVector<OpFoldResult, 4> strides; SmallVector<OpFoldResult, 4> strides;
offsets.reserve(payloadType.getRank()); offsets.reserve(payloadType.getRank());
sizes.reserve(payloadType.getRank()); sizes.reserve(payloadType.getRank());
strides.reserve(payloadType.getRank()); strides.reserve(payloadType.getRank());
auto laneArg = batch.getLaneArgument();
if (!laneArg)
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
offsets.push_back(*laneArg); offsets.push_back(*laneArg);
sizes.push_back(state.rewriter.getIndexAttr(1)); sizes.push_back(state.rewriter.getIndexAttr(1));
strides.push_back(state.rewriter.getIndexAttr(1)); strides.push_back(state.rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
offsets.push_back(state.rewriter.getIndexAttr(0)); offsets.push_back(state.rewriter.getIndexAttr(0));
sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim)));
strides.push_back(state.rewriter.getIndexAttr(1)); strides.push_back(state.rewriter.getIndexAttr(1));
} }
auto outputArg = batch.getOutputArgument(resultIndex);
if (!outputArg)
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides); tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
return success(); return success();
} }
void appendScalarSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId);
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
}
void appendBatchSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
SpatChannelSendBatchOp::create(state.rewriter, loc, channelIdValues, sourceCoreIdValues, targetCoreIdValues, payload);
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = payload;
return success();
}
if (!sourceClass.isBatch && !targetClass.isBatch) {
int64_t channelId = state.nextChannelId++;
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
Value received = appendReceive(state,
targetClass,
payload.getType(),
ArrayRef<int64_t>(channelId),
ArrayRef<int32_t>(sourceCpu),
ArrayRef<int32_t>(targetCpu),
loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (!sourceClass.isBatch && targetClass.isBatch) {
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(targetClass.cpus.size());
sourceCoreIds.reserve(targetClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (CpuId targetCpu : targetClass.cpus) {
int64_t channelId = state.nextChannelId++;
channelIds.push_back(channelId);
sourceCoreIds.push_back(static_cast<int32_t>(sourceClass.cpus.front()));
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
appendScalarSend(state,
sourceClass,
payload,
channelId,
static_cast<int32_t>(sourceClass.cpus.front()),
static_cast<int32_t>(targetCpu),
loc);
}
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey)
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
"lanes are not contiguous in send order");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType))
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus.front()));
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.availableValues[*packedKey][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
return sourceClass.op->emitError("unhandled materialized communication pattern");
}
LogicalResult emitHostCommunication(MaterializerState& state, LogicalResult emitHostCommunication(MaterializerState& state,
MaterializedClass& sourceClass, MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys, ArrayRef<ProducerKey> keys,
Value payload, Value payload,
Value originalOutput, Value originalOutput,
Location loc) { Location loc) {
(void) keys;
(void) loc;
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
return success(); return success();
if (!sourceClass.hostOutputs.empty()) return setHostOutputValue(state, sourceClass, originalOutput, payload);
return setHostOutputValue(state, sourceClass, originalOutput, payload);
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(0);
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
state.hostReplacements[originalOutput] = received;
return success();
} }
LogicalResult emitOutputFanout(MaterializerState& state, LogicalResult emitOutputFanout(MaterializerState& state,
@@ -1034,7 +1111,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success(); return success();
if (!sourceClass.isBatch) { if (!sourceClass.isBatch) {
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
if (failed( if (failed(
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure(); return failure();
@@ -1048,7 +1125,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return sourceClass.op->emitError( return sourceClass.op->emitError(
"cannot materialize batched output whose lanes have different destination equivalence classes"); "cannot materialize batched output whose lanes have different destination equivalence classes");
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure(); return failure();
@@ -660,12 +660,12 @@ public:
emitMergeIrCounts("after-materialization", func); emitMergeIrCounts("after-materialization", func);
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) { /*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
emitMergeIrCounts("after-post-merge-compaction", func); emitMergeIrCounts("after-post-merge-compaction", func);*/
{ {
ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
@@ -120,7 +120,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; auto sizeInBytes = getShapedTypeSizeInBytes(initType);
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(), mapOp.getLoc(),
initType, initType,
@@ -176,9 +176,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes)) if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure(); return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; if (!hasByteSizedElementType(sourceType.getElementType()))
if (elementByteWidth <= 0)
return failure(); return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes) if (size != totalBytes)
@@ -31,13 +31,6 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
return false; return false;
} }
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
template <typename CoreOpTy> template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp, static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter, IRRewriter& rewriter,
@@ -82,7 +75,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
continue; continue;
} }
int64_t totalBytes = getValueSizeInBytes(originalValue); int64_t totalBytes = -1;
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape())
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true; hasFailure = true;