#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

#include <cassert>
#include <string>

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

//===----------------------------------------------------------------------===//
// SpatCompute
//===----------------------------------------------------------------------===//
//
// Entry-block argument layout, as encoded by the accessors below:
//   [ weights... | inputs... ]

/// Returns the entry-block argument that corresponds to weight operand `idx`.
BlockArgument SpatCompute::getWeightArgument(unsigned idx) {
  assert(idx < getWeights().size() && "weight index out of range");
  return getBody().front().getArgument(idx);
}

/// Returns the entry-block argument that corresponds to input operand `idx`.
/// Input arguments are placed after all weight arguments.
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
  // Without this check an out-of-range index would silently resolve to a
  // different argument (or trip only the generic getArgument bound).
  assert(idx < getInputs().size() && "input index out of range");
  return getBody().front().getArgument(getWeights().size() + idx);
}

/// Assigns readable names to the body's block arguments in printed IR:
/// `%w<i>` for weights and `%in<i>` for inputs.
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
  // An empty region has no entry block whose arguments could be named.
  if (region.empty())
    return;
  for (unsigned index = 0, e = getWeights().size(); index < e; ++index)
    setNameFn(getWeightArgument(index), "w" + std::to_string(index));
  for (unsigned index = 0, e = getInputs().size(); index < e; ++index)
    setNameFn(getInputArgument(index), "in" + std::to_string(index));
}

//===----------------------------------------------------------------------===//
// SpatComputeBatch
//===----------------------------------------------------------------------===//
//
// Entry-block argument layout, as encoded by the accessors below:
//   [ lane | weights... | inputs... | outputs... ]

/// Returns the leading entry-block argument (the lane argument).
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }

/// Returns the entry-block argument that corresponds to weight operand `idx`.
/// It is offset by one to skip the lane argument.
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) {
  assert(idx < getWeights().size() && "weight index out of range");
  return getBody().front().getArgument(1 + idx);
}

/// Returns the entry-block argument that corresponds to input operand `idx`.
/// Input arguments follow the lane argument and all weight arguments.
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
  assert(idx < getInputs().size() && "input index out of range");
  return getBody().front().getArgument(1 + getWeights().size() + idx);
}

/// Returns the `idx`-th output block argument. Output arguments follow the
/// lane, weight and input arguments.
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
  return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
}

/// Assigns readable names to the body's block arguments in printed IR:
/// `%lane`, `%w<i>`, `%in<i>`, and `%out` / `%out<i>` for outputs.
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
  // An empty region has no entry block whose arguments could be named.
  if (region.empty())
    return;
  setNameFn(getLaneArgument(), "lane");
  for (unsigned index = 0, e = getWeights().size(); index < e; ++index)
    setNameFn(getWeightArgument(index), "w" + std::to_string(index));
  for (unsigned index = 0, e = getInputs().size(); index < e; ++index)
    setNameFn(getInputArgument(index), "in" + std::to_string(index));
  // This loop assumes one output block argument per op result. The first
  // output is named plainly `out`, and later ones carry their index.
  for (unsigned index = 0, e = getNumResults(); index < e; ++index)
    setNameFn(getOutputArgument(index), index == 0 ? std::string("out") : "out" + std::to_string(index));
}

//===----------------------------------------------------------------------===//
// SpatInParallelOp
//===----------------------------------------------------------------------===//

/// Builds the op with a single region containing one empty block.
/// createBlock moves the builder's insertion point into the new block. The
/// guard restores the caller's insertion point on return.
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
  OpBuilder::InsertionGuard guard(builder);
  Region* bodyRegion = result.addRegion();
  builder.createBlock(bodyRegion);
}

/// Returns result `idx` of the operation that directly encloses this op.
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
  return getOperation()->getParentOp()->getResult(idx);
}

/// Returns the operations held in the first block of this op's region.
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
  return getRegion().front().getOperations();
}

//===----------------------------------------------------------------------===//
// SpatialDialect
//===----------------------------------------------------------------------===//

/// Registers every TableGen-generated type and operation of the dialect.
void SpatialDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
      >();
}

} // namespace spatial
} // namespace onnx_mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"