304 lines
12 KiB
C++
304 lines
12 KiB
C++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace pim {
|
|
|
|
namespace {
|
|
|
|
/// Tells whether operand number `operandIndex` of `op` occupies the one slot
/// that is exempt from the external-constant rule on a PIM memory-copy op.
///
/// The exempt slot is 3 for PimMemCopyHostToDevOp, 1 for
/// PimMemCopyHostToDevBatchOp and 2 for PimMemCopyDevToHostOp; every other
/// operation has no exempt slot.
/// NOTE(review): the slot numbers presumably mirror the host-side operand
/// position in each op's definition — confirm against the op declarations.
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  // -1 means "this operation has no exempt operand".
  int exemptSlot = -1;
  if (isa<PimMemCopyHostToDevOp>(op))
    exemptSlot = 3;
  else if (isa<PimMemCopyHostToDevBatchOp>(op))
    exemptSlot = 1;
  else if (isa<PimMemCopyDevToHostOp>(op))
    exemptSlot = 2;

  return exemptSlot >= 0 && operandIndex == static_cast<unsigned>(exemptSlot);
}
|
|
|
|
/// Returns the region that holds the definition of `value`.
///
/// For a block argument this is the region of the owning block; otherwise it
/// is the region containing the defining operation. Returns nullptr when the
/// value has no defining operation.
static Region* getParentRegion(Value value) {
  auto asArgument = dyn_cast<BlockArgument>(value);
  if (asArgument)
    return asArgument.getParentRegion();

  if (Operation* producer = value.getDefiningOp())
    return producer->getParentRegion();

  return nullptr;
}
|
|
|
|
/// Returns true when `value` is defined in `region` itself or in any region
/// nested inside it.
///
/// A value whose parent region cannot be determined (getParentRegion returns
/// nullptr) is treated as not being defined inside `region`.
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  // Compare by address: `region` is the candidate ancestor, `parentRegion` is
  // where the value actually lives.
  return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
|
|
|
|
/// Returns true when `value` is produced by something this verifier accepts as
/// a constant:
///   * an operation carrying the ConstantLike trait, or
///   * a memref.get_global whose resolved global reports `getConstant()`.
///
/// Block arguments (no defining op) and every other producer yield false, as
/// does a get_global that is not nested in a ModuleOp or whose global cannot
/// be resolved.
static bool isConstantExternalValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;
  if (definingOp->hasTrait<OpTrait::ConstantLike>())
    return true;

  auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
  if (!getGlobalOp)
    return false;

  // getParentOfType yields a null handle when the op is not nested in a
  // module; do not hand that to the lookup helper.
  auto moduleOp = definingOp->getParentOfType<ModuleOp>();
  if (!moduleOp)
    return false;

  // NOTE(review): lookupGlobalForGetGlobal is declared outside this file; it
  // presumably resolves the symbol referenced by the get_global — confirm.
  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  return globalOp && globalOp.getConstant();
}
|
|
|
|
/// Checks that every operand used by operations inside `region` is either
/// defined inside `region`, a constant external value, or an explicitly
/// exempt host operand of a memory-copy op.
///
/// One op error (attached to `ownerOp`, prefixed with `kind`) plus a note at
/// the offending user is emitted per violating operand; the walk does not stop
/// at the first violation. Returns failure if any violation was reported.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  bool sawViolation = false;

  // Predicate order matches the cheapest-first short-circuit chain.
  auto isPermitted = [&](Operation* user, OpOperand& use) {
    Value used = use.get();
    if (isDefinedInsideRegion(used, region))
      return true;
    if (isConstantExternalValue(used))
      return true;
    return isExplicitHostOperand(user, use.getOperandNumber());
  };

  region.walk([&](Operation* user) {
    for (OpOperand& use : user->getOpOperands()) {
      if (isPermitted(user, use))
        continue;

      InFlightDiagnostic error = ownerOp->emitOpError();
      error << kind << " body may only directly reference external constants";
      Diagnostic& note = error.attachNote(user->getLoc());
      note << "non-constant external operand #" << use.getOperandNumber() << " is used by " << user->getName();
      sawViolation = true;
    }
  });

  return failure(sawViolation);
}
|
|
|
|
/// Returns true when both types are RankedTensorType or both are MemRefType.
/// Any other combination (including two types of some third kind) is false.
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
  if (isa<RankedTensorType>(lhs))
    return isa<RankedTensorType>(rhs);
  if (isa<MemRefType>(lhs))
    return isa<MemRefType>(rhs);
  return false;
}
|
|
|
|
/// Verifies that `lhs` and `rhs` are shaped types of the same container kind
/// (both ranked tensors or both memrefs) with identical element type and
/// shape. On any mismatch, emits `message` as an error on `op` and fails.
static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) {
  // Every rejection path reports the same caller-supplied message.
  auto reject = [&]() -> LogicalResult { return op->emitError(message); };

  auto first = dyn_cast<ShapedType>(lhs);
  auto second = dyn_cast<ShapedType>(rhs);
  const bool bothShaped = first && second;
  if (!bothShaped || !haveSameShapedContainerKind(lhs, rhs))
    return reject();

  if (first.getElementType() != second.getElementType())
    return reject();
  if (first.getShape() != second.getShape())
    return reject();

  return success();
}
|
|
|
|
/// Shared verifier for the non-batched tensor send/receive ops.
///
/// Requirements checked, in order (each failure emits an error on `op`
/// prefixed with `kind`):
///   1. `coreIds` is non-empty;
///   2. `type` is a shaped type with a fully static shape;
///   3. the element type is an integer or float whose bit width is a positive
///      multiple of 8;
///   4. the total byte size is divisible by `coreIds.size()`.
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";

  // Only integer and float types have a queryable bit width; asking any other
  // element type (index, complex, ...) trips an assertion, so reject it with
  // a diagnostic instead.
  Type elementType = shapedType.getElementType();
  if (!elementType.isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = elementType.getIntOrFloatBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // elementBits is a multiple of 8 here, so dividing first is exact and keeps
  // the intermediate product smaller.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";

  return success();
}
|
|
|
|
/// Shared verifier for the batched tensor send/receive ops.
///
/// Requirements checked, in order (each failure emits an error on `op`
/// prefixed with `kind`):
///   1. `coreIds` is non-empty;
///   2. `op` is nested inside a PimCoreBatchOp whose lane count is positive;
///   3. `coreIds.size()` is divisible by that lane count;
///   4. `type` is a shaped type with a fully static shape;
///   5. the element type is an integer or float whose bit width is a positive
///      multiple of 8;
///   6. the total byte size is divisible by `coreIds.size() / laneCount`.
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";

  auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
  if (!coreBatchOp)
    return op->emitError() << kind << " must be nested inside pim.core_batch";

  int32_t laneCount = coreBatchOp.getLaneCount();
  if (laneCount <= 0)
    return op->emitError() << kind << " requires a positive parent laneCount";
  if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
    return op->emitError() << kind << " core id count must be divisible by the parent laneCount";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";

  // Only integer and float types have a queryable bit width; asking any other
  // element type (index, complex, ...) trips an assertion, so reject it with
  // a diagnostic instead.
  Type elementType = shapedType.getElementType();
  if (!elementType.isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = elementType.getIntOrFloatBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // coreIds is non-empty and a multiple of laneCount, so chunkCount >= 1 and
  // the modulo below cannot divide by zero.
  int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
  // elementBits is a multiple of 8 here, so dividing first is exact and keeps
  // the intermediate product smaller.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";

  return success();
}
|
|
|
|
/// Returns the shape of `weight`, or failure when its type is not a ranked
/// shaped type.
///
/// Unranked shaped types are rejected here because querying their shape is
/// not valid (ShapedType::getShape requires a rank).
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
  auto shapedType = dyn_cast<ShapedType>(weight.getType());
  if (!shapedType || !shapedType.hasRank())
    return failure();
  return shapedType.getShape();
}
|
|
|
|
} // namespace
|
|
|
|
/// Verifies a pim.core op: the entry block must have exactly one argument per
/// weight operand, each weight block argument must have the same type as its
/// weight operand, and the body may only reference allowed external values
/// (see verifyOnlyConstantExternalValues).
LogicalResult PimCoreOp::verify() {
  Block& entryBlock = getBody().front();
  auto weightOperands = getWeights();

  if (entryBlock.getNumArguments() != weightOperands.size())
    return emitError("core body must have one block argument per weight");

  for (size_t idx = 0, count = weightOperands.size(); idx < count; ++idx) {
    Type argumentType = getWeightArgument(idx).getType();
    if (argumentType != weightOperands[idx].getType())
      return emitError("core weight block argument types must match weight operand types exactly");
  }

  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
|
|
|
|
/// Verifies a pim.core_batch op:
///   * laneCount is positive;
///   * the entry block has 1 + #weights + #inputs arguments;
///   * the lane argument has index type;
///   * weight and input block arguments match their operand types exactly;
///   * the body only references allowed external values
///     (see verifyOnlyConstantExternalValues).
LogicalResult PimCoreBatchOp::verify() {
  if (getLaneCount() <= 0)
    return emitError("laneCount must be positive");

  Block& entryBlock = getBody().front();
  auto weightOperands = getWeights();
  auto inputOperands = getInputs();

  // One leading lane argument, then the weights, then the inputs.
  const unsigned expectedArgCount = 1 + weightOperands.size() + inputOperands.size();
  if (entryBlock.getNumArguments() != expectedArgCount)
    return emitError("core_batch body must have lane, weight, and input block arguments");

  if (!getLaneArgument().getType().isIndex())
    return emitError("core_batch first block argument must have index type");

  for (size_t idx = 0, count = weightOperands.size(); idx < count; ++idx) {
    if (getWeightArgument(idx).getType() != weightOperands[idx].getType())
      return emitError("core_batch weight block argument types must match weight operand types exactly");
  }

  for (size_t idx = 0, count = inputOperands.size(); idx < count; ++idx) {
    if (getInputArgument(idx).getType() != inputOperands[idx].getType())
      return emitError("core_batch input block argument types must match input operand types exactly");
  }

  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
|
|
|
|
/// Verifies a send_tensor op by applying the shared tensor-communication
/// checks to the input type and the target core ids.
LogicalResult PimSendTensorOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getInput().getType();
  return verifyTensorCommunication(self, payloadType, getTargetCoreIds(), "send_tensor");
}
|
|
|
|
/// Verifies a send_tensor_batch op by applying the shared batched
/// tensor-communication checks to the input type and the target core ids.
LogicalResult PimSendTensorBatchOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getInput().getType();
  return verifyTensorBatchCommunication(self, payloadType, getTargetCoreIds(), "send_tensor_batch");
}
|
|
|
|
/// Verifies a receive_tensor op: the output buffer and the output must be
/// compatible shaped types, then the shared tensor-communication checks are
/// applied to the output type and the source core ids.
LogicalResult PimReceiveTensorOp::verify() {
  Operation* self = getOperation();
  Type resultType = getOutput().getType();

  LogicalResult buffersAgree =
      verifyCompatibleShapedTypes(self, getOutputBuffer().getType(), resultType, "output buffer and output must match");
  if (failed(buffersAgree))
    return failure();

  return verifyTensorCommunication(self, resultType, getSourceCoreIds(), "receive_tensor");
}
|
|
|
|
/// Verifies a receive_tensor_batch op: the output buffer and the output must
/// be compatible shaped types, then the shared batched tensor-communication
/// checks are applied to the output type and the source core ids.
LogicalResult PimReceiveTensorBatchOp::verify() {
  Operation* self = getOperation();
  Type resultType = getOutput().getType();

  LogicalResult buffersAgree =
      verifyCompatibleShapedTypes(self, getOutputBuffer().getType(), resultType, "output buffer and output must match");
  if (failed(buffersAgree))
    return failure();

  return verifyTensorBatchCommunication(self, resultType, getSourceCoreIds(), "receive_tensor_batch");
}
|
|
|
|
/// Verifies a vector-matrix-multiply op.
///
/// Checks, in order:
///   * output buffer and output are compatible shaped types;
///   * the weight is a shaped value (as decided by getWeightShapeForVMM);
///   * input and output are shaped and ranked;
///   * weight, input and output all have rank 2;
///   * the weight shape (N, M) is strictly positive in both dimensions and
///     neither exceeds `crossbarSize`;
///   * input and output both have shape (1, crossbarSize).
LogicalResult PimVMMOp::verify() {
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();

  auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
  if (failed(matrixShapeOpt))
    return emitError("weight must be a shaped value");
  ArrayRef<int64_t> matrixShape = *matrixShapeOpt;

  auto vectorType = dyn_cast<ShapedType>(getInput().getType());
  auto outputType = dyn_cast<ShapedType>(getOutput().getType());
  if (!vectorType || !outputType)
    return emitError("input and output must be shaped types");

  // getShape() is only valid on ranked types; an unranked operand cannot
  // satisfy the rank-2 requirement, so report it as such instead of asserting.
  if (!vectorType.hasRank() || !outputType.hasRank())
    return emitError("matrix, vector and output must have rank 2");

  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  // `crossbarSize` is declared outside this file; convert once for the signed
  // comparisons below.
  const int64_t crossbarDim = static_cast<int64_t>(crossbarSize);

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
  if (N > crossbarDim || M > crossbarDim)
    return emitError("matrix dimensions must fit in one crossbar");

  int64_t vector1 = vectorShape[0];
  int64_t vectorWidth = vectorShape[1];
  if (vector1 != 1 || vectorWidth != crossbarDim)
    return emitError("vector shape must be (1, crossbar-size)");

  int64_t output1 = outputShape[0];
  int64_t outputWidth = outputShape[1];
  if (output1 != 1 || outputWidth != crossbarDim)
    return emitError("output shape must be (1, crossbar-size)");

  return success();
}
|
|
|
|
/// Verifies a concat op.
///
/// Checks, in order:
///   * there is at least one input;
///   * output buffer and output are compatible shaped types;
///   * the output is a ranked shaped type and `axis` lies in [0, rank);
///   * every input is ranked, uses the same container kind as the output, and
///     has the output's rank and element type;
///   * every non-concatenated dimension that is static on both sides matches
///     the output;
///   * when no input is dynamic along `axis` and the output is static along
///     `axis`, the input extents along `axis` sum to the output extent.
LogicalResult PimConcatOp::verify() {
  auto operands = getInputs();
  if (operands.empty())
    return emitError("requires at least one input");

  Type declaredResultType = getOutput().getType();
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), declaredResultType, "output buffer and output must match")))
    return failure();

  auto resultType = dyn_cast<ShapedType>(declaredResultType);
  if (!resultType || !resultType.hasRank())
    return emitError("output must be a ranked shaped type");

  const int64_t concatAxis = getAxis();
  const int64_t resultRank = resultType.getRank();
  if (concatAxis < 0 || concatAxis >= resultRank)
    return emitError("axis must be within the output rank");

  const Type resultElementType = resultType.getElementType();

  // Running sum of the static extents along the concat axis; meaningless once
  // any input turns out to be dynamic along that axis.
  int64_t staticConcatExtent = 0;
  bool sawDynamicConcatExtent = false;

  for (Value operand : operands) {
    auto operandType = dyn_cast<ShapedType>(operand.getType());
    if (!(operandType && operandType.hasRank()))
      return emitError("inputs must be ranked shaped types");
    if (!haveSameShapedContainerKind(operand.getType(), declaredResultType))
      return emitError("inputs and output must use the same shaped container kind");
    if (operandType.getRank() != resultRank)
      return emitError("all inputs must have the same rank as the output");
    if (operandType.getElementType() != resultElementType)
      return emitError("all inputs must have the same element type as the output");

    for (int64_t dim = 0; dim < resultRank; ++dim) {
      if (dim == concatAxis)
        continue;

      const int64_t operandExtent = operandType.getDimSize(dim);
      const int64_t resultExtent = resultType.getDimSize(dim);
      const bool bothStatic = !ShapedType::isDynamic(operandExtent) && !ShapedType::isDynamic(resultExtent);
      if (bothStatic && operandExtent != resultExtent)
        return emitError("non-concatenated dimensions must match the output shape");
    }

    const int64_t operandConcatExtent = operandType.getDimSize(concatAxis);
    if (ShapedType::isDynamic(operandConcatExtent))
      sawDynamicConcatExtent = true;
    else
      staticConcatExtent += operandConcatExtent;
  }

  const int64_t resultConcatExtent = resultType.getDimSize(concatAxis);
  const bool extentsComparable = !sawDynamicConcatExtent && !ShapedType::isDynamic(resultConcatExtent);
  if (extentsComparable && staticConcatExtent != resultConcatExtent)
    return emitError("output concatenated dimension must equal the sum of input sizes");

  return success();
}
|
|
|
|
} // namespace pim
|
|
} // namespace onnx_mlir
|