big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -71,38 +71,6 @@ def PimYieldOp : PimOp<"yield", [Terminator]> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimMapOp : PimOp<"map", [SingleBlock]> {
|
||||
let summary = "Apply the same lane-local region to many independent tensors";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimEmptyManyOp : PimOp<"empty_many", []> {
|
||||
let summary = "Create many identical empty tensors";
|
||||
|
||||
let results = (outs
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -121,18 +89,6 @@ def PimSendOp : PimOp<"send", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimSendManyOp : PimOp<"send_many", []> {
|
||||
let summary = "Send multiple tensors to target cores";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendTensorOp : PimOp<"send_tensor", []> {
|
||||
let summary = "Send equal contiguous chunks of one tensor to target cores";
|
||||
|
||||
@@ -157,18 +113,6 @@ def PimSendBatchOp : PimOp<"send_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimSendManyBatchOp : PimOp<"send_many_batch", []> {
|
||||
let summary = "Send multiple per-lane tensors to target cores from a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive a tensor from another core";
|
||||
|
||||
@@ -193,28 +137,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive multiple tensors from source cores";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$outputBuffers,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBuffersMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive equal contiguous chunks from source cores into one tensor";
|
||||
|
||||
@@ -259,28 +181,6 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimReceiveManyBatchOp : PimOp<"receive_many_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive multiple per-lane tensors from source cores into a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$outputBuffers,
|
||||
DenseI32ArrayAttr:$sourceCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBuffersMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
@@ -385,32 +285,6 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimExtractRowsOp : PimOp<"extract_rows", [DestinationStyleOpInterface]> {
|
||||
let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
Variadic<PimTensor>:$outputBuffers
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBuffersMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimConcatOp : PimOp<"concat", [DestinationStyleOpInterface]> {
|
||||
let summary = "Concatenate tensors";
|
||||
|
||||
|
||||
@@ -147,69 +147,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void PimEmptyManyOp::print(OpAsmPrinter& printer) {
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " x" << getOutputs().size();
|
||||
}
|
||||
|
||||
ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
int64_t resultCount = 0;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)
|
||||
|| parser.parseKeyword("x") || parser.parseInteger(resultCount))
|
||||
return failure();
|
||||
|
||||
if (resultCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive");
|
||||
|
||||
SmallVector<Type> resultTypes(resultCount, outputType);
|
||||
result.addTypes(resultTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -237,36 +174,6 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimSendTensorOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -294,72 +201,6 @@ ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
@@ -434,77 +275,6 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimExtractRowsOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
Type inputType;
|
||||
SmallVector<Type> outputTypes;
|
||||
|
||||
if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (parser.resolveOperand(input, inputType, result.operands)
|
||||
|| parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis() << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
|
||||
@@ -13,12 +13,6 @@ namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||
if (coreIds.size() != valueCount)
|
||||
return op->emitError("core id metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||
}
|
||||
@@ -33,28 +27,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
|
||||
Type firstType = types.front();
|
||||
auto firstShapedType = dyn_cast<ShapedType>(firstType);
|
||||
bool firstIsTensor = isa<RankedTensorType>(firstType);
|
||||
bool firstIsMemRef = isa<MemRefType>(firstType);
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!firstShapedType || !shapedType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
||||
return op->emitError() << kind << " values must all use the same shaped container kind";
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType()
|
||||
|| firstShapedType.getShape() != shapedType.getShape())
|
||||
return op->emitError() << kind << " values must all have the same shape and element type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
@@ -74,109 +46,12 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
return failure();
|
||||
return coreBatchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside pim.core_batch");
|
||||
if (coreIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||
return op->emitError("core id metadata length must match the number of values times parent laneCount");
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimEmptyManyOp::verify() {
|
||||
if (getOutputs().empty())
|
||||
return emitError("must produce at least one output");
|
||||
|
||||
Type firstType = getOutputs().front().getType();
|
||||
auto firstShapedType = dyn_cast<ShapedType>(firstType);
|
||||
if (!firstShapedType || !firstShapedType.hasRank())
|
||||
return emitError("outputs must all be ranked shaped types");
|
||||
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != firstType)
|
||||
return emitError("outputs must all have the same type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimMapOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
if (getOutputs().size() != getInputs().size())
|
||||
return emitError("number of outputs must match number of inputs");
|
||||
|
||||
Type inputType = getInputs().front().getType();
|
||||
for (Value input : getInputs().drop_front())
|
||||
if (input.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
|
||||
Type outputType = getOutputs().front().getType();
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("body must terminate with pim.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyOp::verify() {
|
||||
if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many")))
|
||||
return failure();
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many")))
|
||||
return failure();
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
|
||||
if (outputBuffer.getType() != output.getType())
|
||||
return emitError("output buffers and outputs must have matching types");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
@@ -185,61 +60,6 @@ LogicalResult PimReceiveTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyBatchOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch")))
|
||||
return failure();
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch")))
|
||||
return failure();
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
|
||||
if (outputBuffer.getType() != output.getType())
|
||||
return emitError("output buffers and outputs must have matching types");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimExtractRowsOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
|
||||
auto inputType = dyn_cast<ShapedType>(getInput().getType());
|
||||
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
|
||||
return emitError("input must be a rank-2 shaped type");
|
||||
|
||||
int64_t numRows = inputType.getShape()[0];
|
||||
int64_t numCols = inputType.getShape()[1];
|
||||
Type elementType = inputType.getElementType();
|
||||
|
||||
if (numRows >= 0 && static_cast<int64_t>(getOutputs().size()) != numRows)
|
||||
return emitError("number of outputs must match the number of input rows");
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match")))
|
||||
return failure();
|
||||
|
||||
auto outputType = dyn_cast<ShapedType>(output.getType());
|
||||
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
|
||||
return emitError("outputs must all be rank-2 shaped types");
|
||||
if (!haveSameShapedContainerKind(getInput().getType(), output.getType()))
|
||||
return emitError("outputs must use the same shaped container kind as the input");
|
||||
if (outputType.getElementType() != elementType)
|
||||
return emitError("output element types must match input element type");
|
||||
auto outputShape = outputType.getShape();
|
||||
if (outputShape[0] != 1)
|
||||
return emitError("each output must have exactly one row");
|
||||
if (numCols >= 0 && outputShape[1] != numCols)
|
||||
return emitError("output column count must match input column count");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
@@ -180,39 +180,6 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<Receive
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveManyOpInterface, PimReceiveManyOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveManyOp>(op);
|
||||
SmallVector<Value> outputBuffers;
|
||||
SmallVector<Type> resultTypes;
|
||||
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
||||
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
||||
|
||||
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
outputBuffers.push_back(*outputBufferOpt);
|
||||
resultTypes.push_back(outputBufferOpt->getType());
|
||||
}
|
||||
|
||||
auto newOp = PimReceiveManyOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
TypeRange(resultTypes),
|
||||
ValueRange(outputBuffers),
|
||||
receiveOp.getSourceCoreIdsAttr());
|
||||
rewriter.replaceOp(receiveOp, newOp.getOutputs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveTensorOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
@@ -234,77 +201,6 @@ struct ReceiveTensorOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveManyBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveManyBatchOp>(op);
|
||||
SmallVector<Value> outputBuffers;
|
||||
SmallVector<Type> resultTypes;
|
||||
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
||||
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
||||
|
||||
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
outputBuffers.push_back(*outputBufferOpt);
|
||||
resultTypes.push_back(outputBufferOpt->getType());
|
||||
}
|
||||
|
||||
auto newOp = PimReceiveManyBatchOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
TypeRange(resultTypes),
|
||||
ValueRange(outputBuffers),
|
||||
receiveOp.getSourceCoreIdsAttr());
|
||||
rewriter.replaceOp(receiveOp, newOp.getOutputs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractRowsOpInterface, PimExtractRowsOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto extractRowsOp = cast<PimExtractRowsOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> outputBuffers;
|
||||
SmallVector<Type> resultTypes;
|
||||
outputBuffers.reserve(extractRowsOp.getOutputBuffers().size());
|
||||
resultTypes.reserve(extractRowsOp.getOutputBuffers().size());
|
||||
|
||||
for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
outputBuffers.push_back(*outputBufferOpt);
|
||||
resultTypes.push_back(outputBufferOpt->getType());
|
||||
}
|
||||
|
||||
auto newOp = PimExtractRowsOp::create(rewriter,
|
||||
extractRowsOp.getLoc(),
|
||||
TypeRange(resultTypes),
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
ValueRange(outputBuffers));
|
||||
rewriter.replaceOp(extractRowsOp, newOp.getOutputs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -334,31 +230,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
||||
}
|
||||
};
|
||||
|
||||
struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOpInterface, PimEmptyManyOp> {
|
||||
bool bufferizesToAllocation(Operation* op, Value value) const { return true; }
|
||||
|
||||
bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto emptyManyOp = cast<PimEmptyManyOp>(op);
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(emptyManyOp.getOutputs().size());
|
||||
for (Value output : emptyManyOp.getOutputs()) {
|
||||
auto shapedType = cast<ShapedType>(output.getType());
|
||||
resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
|
||||
}
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimEmptyManyOp>(rewriter, emptyManyOp, TypeRange(resultTypes));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
@@ -383,7 +254,7 @@ struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensor
|
||||
}
|
||||
};
|
||||
|
||||
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
|
||||
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
@@ -392,75 +263,93 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
||||
return {};
|
||||
}
|
||||
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto mapOp = cast<PimMapOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||
|| mapOp.getInputs().empty())
|
||||
return {};
|
||||
|
||||
return {
|
||||
{&mapOp->getOpOperand(0), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto mapOp = cast<PimMapOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||
|| mapOp.getInputs().empty())
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto sendOp = cast<PimSendOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
|
||||
if (inputType)
|
||||
return inputType;
|
||||
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
auto shapedType = cast<ShapedType>(mapOp.getInputs().front().getType());
|
||||
return BufferLikeType(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
|
||||
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto mapOp = cast<PimMapOp>(op);
|
||||
auto sendOp = cast<PimSendBatchOp>(op);
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
SmallVector<Type> resultTypes;
|
||||
inputs.reserve(mapOp.getInputs().size());
|
||||
resultTypes.reserve(mapOp.getOutputs().size());
|
||||
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
for (Value input : mapOp.getInputs()) {
|
||||
if (isa<TensorType>(input.getType())) {
|
||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||
if (failed(inputOpt))
|
||||
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(coreOp.getWeights().size());
|
||||
for (Value weight : coreOp.getWeights()) {
|
||||
if (isa<TensorType>(weight.getType())) {
|
||||
auto weightOpt = getBufferOrValue(rewriter, weight, options, state);
|
||||
if (failed(weightOpt))
|
||||
return failure();
|
||||
inputs.push_back(*inputOpt);
|
||||
weights.push_back(*weightOpt);
|
||||
}
|
||||
else {
|
||||
inputs.push_back(input);
|
||||
weights.push_back(weight);
|
||||
}
|
||||
}
|
||||
|
||||
for (Value output : mapOp.getOutputs()) {
|
||||
auto shapedType = cast<ShapedType>(output.getType());
|
||||
resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
auto newOp = PimMapOp::create(rewriter, mapOp.getLoc(), TypeRange(resultTypes), ValueRange(inputs));
|
||||
rewriter.inlineRegionBefore(mapOp.getBody(), newOp.getBody(), newOp.getBody().begin());
|
||||
rewriter.setInsertionPoint(coreOp);
|
||||
auto newOp = PimCoreOp::create(rewriter, coreOp.getLoc(), ValueRange(weights), coreOp.getCoreIdAttr());
|
||||
rewriter.inlineRegionBefore(coreOp.getBody(), newOp.getBody(), newOp.getBody().begin());
|
||||
for (Block& block : newOp.getBody())
|
||||
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(mapOp, newOp.getOutputs());
|
||||
rewriter.eraseOp(coreOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -730,16 +619,14 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimEmptyManyOp::attachInterface<EmptyManyOpInterface>(*ctx);
|
||||
PimMapOp::attachInterface<MapOpInterface>(*ctx);
|
||||
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
|
||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
|
||||
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
|
||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
|
||||
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
||||
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
|
||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Threading.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
@@ -47,53 +46,18 @@ private:
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
auto moduleOp = getOperation();
|
||||
// Refactor this into a function
|
||||
{
|
||||
auto funcOp = *getPimEntryFunc(moduleOp);
|
||||
auto funcOp = *getPimEntryFunc(moduleOp);
|
||||
|
||||
SmallVector<Operation*> coreOps;
|
||||
funcOp->walk<WalkOrder::PreOrder>([&](Operation* op) {
|
||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||
coreOps.push_back(op);
|
||||
});
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
|
||||
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) {
|
||||
// Again, allocate state LOCALLY per thread/function
|
||||
bufferization::OneShotBufferizationOptions options;
|
||||
options.allowUnknownOps = true;
|
||||
if (isa<pim::PimCoreBatchOp>(coreOp))
|
||||
options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; });
|
||||
bufferization::BufferizationState state;
|
||||
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
|
||||
coreOp->emitError("Failed to bufferize PIM and Spatial ops");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
});
|
||||
bufferization::OneShotBufferizationOptions options;
|
||||
options.allowUnknownOps = true;
|
||||
options.bufferizeFunctionBoundaries = true;
|
||||
options.setFunctionBoundaryTypeConversion(bufferization::LayoutMapOption::IdentityLayoutMap);
|
||||
bufferization::BufferizationState state;
|
||||
|
||||
if (failed(result)) {
|
||||
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
|
||||
if (llvm::isa_and_present<pim::PimCoreOp, pim::PimCoreBatchOp>(toTensorOp->getParentOp()))
|
||||
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
|
||||
});
|
||||
|
||||
// One-Shot-Bufferization
|
||||
bufferization::OneShotBufferizationOptions options;
|
||||
options.allowUnknownOps = true;
|
||||
options.opFilter.denyOperation([](Operation* op) {
|
||||
return op->getParentOfType<pim::PimCoreOp>() || op->getParentOfType<pim::PimCoreBatchOp>();
|
||||
});
|
||||
bufferization::BufferizationState state;
|
||||
|
||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||
signalPassFailure();
|
||||
}
|
||||
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, state))) {
|
||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
@@ -119,30 +83,6 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove toTensor operations: leave memrefs instead
|
||||
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
||||
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
||||
toTensorOp.erase();
|
||||
});
|
||||
|
||||
// Change main function return types from tensors to memrefs
|
||||
func::FuncOp funcOp;
|
||||
for (Operation& op : moduleOp.getBody()->getOperations())
|
||||
if ((funcOp = dyn_cast<func::FuncOp>(&op)))
|
||||
break;
|
||||
auto oldFuncType = funcOp.getFunctionType();
|
||||
SmallVector<Type> newResults;
|
||||
bool changed = false;
|
||||
for (Type type : oldFuncType.getResults())
|
||||
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||
newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
|
||||
changed = true;
|
||||
}
|
||||
else
|
||||
newResults.push_back(type);
|
||||
if (changed)
|
||||
funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults));
|
||||
|
||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||
|
||||
// Dump to file for debug
|
||||
|
||||
Reference in New Issue
Block a user