DeadLock
This commit is contained in:
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(op.getWeights().size());
|
||||
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||
auto weightArg = op.getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(op.getInputs().size());
|
||||
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||
auto inputArg = op.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
|
||||
|
||||
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreId)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
|
||||
auto laneArg = op.getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(op.getWeights().size());
|
||||
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||
auto weightArg = op.getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(op.getInputs().size());
|
||||
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||
auto inputArg = op.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
if (op.getNumResults() != 0) {
|
||||
outputArgs.reserve(op.getNumResults());
|
||||
for (unsigned index = 0; index < op.getNumResults(); ++index) {
|
||||
auto outputArg = op.getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << op.getLaneCount();
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||
if (op.getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
|
||||
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
template <typename ComputeBatchOpTy>
|
||||
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreId)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
auto laneArg = getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
if (getNumResults() != 0) {
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
|
||||
}
|
||||
|
||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||
|
||||
Reference in New Issue
Block a user