diff --git a/docs/developers/pass_pipeline.md b/docs/developers/pass_pipeline.md index 9215675a191..84e48a60c67 100644 --- a/docs/developers/pass_pipeline.md +++ b/docs/developers/pass_pipeline.md @@ -556,11 +556,11 @@ specialized into: scf.if %cond { "lmhlo.fusion"() ( { ... - }) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_thread_per_block_hint = 256 : i32, disc_vectorize_or_tile_hint = 2 : i32} : () -> () + }) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_cta_size_hint = 256 : i32, disc_vectorize_or_tile_hint = 2 : i32} : () -> () } else { "lmhlo.fusion"() ( { ... - }) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2X_no_vectile", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_thread_per_block_hint = 256 : i32, disc_vectorize_or_tile_hint = 1 : i32} : () -> () + }) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2X_no_vectile", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_cta_size_hint = 256 : i32, disc_vectorize_or_tile_hint = 1 : i32} : () -> () ``` The different "disc_vectorize_or_tile_hint" attributes will guide the codegen passes to diff --git a/tao_compiler/mlir/disc/BUILD b/tao_compiler/mlir/disc/BUILD index 8d2e00a5032..64b411ed428 100644 --- a/tao_compiler/mlir/disc/BUILD +++ b/tao_compiler/mlir/disc/BUILD @@ -1351,6 +1351,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", "@llvm-project//llvm:Support", "@llvm-project//mlir:Transforms", @@ -2241,6 +2242,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "disc_erase_buffer_deallocation", + srcs = ["transforms/disc_erase_buffer_deallocation.cc"], + deps = [ + ":lmhlo_disc", + ":disc_util", + ":pass_details", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:BufferizationTransforms", + ], + alwayslink = 1, +) + cc_library( name = "all_passes", hdrs = [ @@ -2264,12 +2284,12 @@ cc_library( ":disc_dense_to_sparse", ":disc_convert_const_to_ral", ":disc_convert_fake_quant_op", - ":disc_custom_call_rewriter", ":disc_cpu_map_parallel_loop", + ":disc_custom_call_rewriter", ":disc_duplicate_computation_after_fusion", ":disc_duplicate_computation_for_fusion", ":disc_dynamic_slice_converter", - ":disc_sparse_op_rewriter", + ":disc_erase_buffer_deallocation", ":disc_flatten_memref_access", ":disc_for_loop_unroll_interleave", ":disc_fuse_splat_const", @@ -2294,6 +2314,7 @@ cc_library( ":disc_shape_optimization", ":disc_shape_simplifier", ":disc_shape_to_std", + ":disc_sparse_op_rewriter", ":disc_specialize_fusion_with_speculation", ":disc_std_bufferize", ":disc_stitch_fusion", diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index 2017fcf59ae..0f1688bc453 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -49,6 +49,7 @@ 
limitations under the License. #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/AsmState.h" @@ -518,7 +519,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { tensorflow::ReadStringFromEnvVar("DISC_TRANSFORM_SCHEDULE_FILE", "", &transform_schedule); pm.addNestedPass(disc_ral::createDiscTransformLegalizeToLoopPass( - gpu_enabled, transform_schedule)); + gpu_enabled, transform_schedule, options.gpu_info.cc_major, + options.gpu_info.cc_minor)); } pm.addNestedPass(createCanonicalizerPass()); @@ -599,7 +601,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { // TODO: adopt tileSize from attributes of speculation pass with a // wrapper of the original ParallelLoopTilingPass pm.addNestedPass( - disc_ral::createParallelLoopTilingPass({kThreadsRowReduction}, true)); + disc_ral::createParallelLoopTilingPass({kCTASizeDefault}, true)); // pm.addNestedPass(disc_ral::createMapParallelLoopsPass()); pm.addNestedPass(mlir::createGpuMapParallelLoopsPass()); @@ -640,6 +642,10 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { kernelPm.addPass(createLoopInvariantCodeMotionPass()); kernelPm.addPass(createCSEPass()); } + kernelPm.addNestedPass( + disc_ral::createDiscEraseBufferDeallocationPass()); + kernelPm.addNestedPass( + memref::createExpandStridedMetadataPass()); kernelPm.addPass(createConvertSCFToCFPass()); kernelPm.addPass(createLowerAffinePass()); kernelPm.addNestedPass(createCanonicalizerPass()); diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/default_schedule_matmul_nn_s_256x256x128_f16.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/default_schedule_matmul_nn_s_256x256x128_f16.mlir new file mode 100644 index 00000000000..24ffa196b5c --- /dev/null +++ b/tao_compiler/mlir/disc/tests/disc-transform/data/default_schedule_matmul_nn_s_256x256x128_f16.mlir @@ -0,0 +1,9 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} { + func.func @main(%arg0: tensor<256x128xf16>, %arg1: tensor<128x256xf16>) -> (tensor<256x256xf16>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} { + %graph = tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<256x128xf16>, tensor<128x256xf16>) -> (tensor<256x256xf16>) + tf_executor.fetch %0 : tensor<256x256xf16> + } + return %graph : tensor<256x256xf16> + } +} diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_256x256x128_f16.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_256x256x128_f16.mlir new file mode 100644 index 00000000000..24ffa196b5c --- /dev/null +++ b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_256x256x128_f16.mlir @@ -0,0 +1,9 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} { + func.func @main(%arg0: tensor<256x128xf16>, %arg1: tensor<128x256xf16>) -> (tensor<256x256xf16>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} { + 
%graph = tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<256x128xf16>, tensor<128x256xf16>) -> (tensor<256x256xf16>) + tf_executor.fetch %0 : tensor<256x256xf16> + } + return %graph : tensor<256x256xf16> + } +} diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule.mlir new file mode 100644 index 00000000000..5eb0bfb52f4 --- /dev/null +++ b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule.mlir @@ -0,0 +1,45 @@ +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match attributes {disc.transform.name = "dot_general"} in %arg0 : (!transform.any_op) -> !transform.any_op + %1:2 = split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %forall_op, %tiled_op = transform.structured.tile_to_forall_op %1#1 num_threads [] tile_sizes [128, 128](mapping = [#gpu.block, #gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1#0 into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %for_op, %splitted_op = transform.disc.split_reduction_serial %tiled_op by tile_sizes = [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %promoted_dot, %lhs_alloc, %rhs_alloc = transform.disc.promote_dot_operands %for_op [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %forall_op_0, %tiled_op_1 = transform.structured.tile_to_forall_op %promoted_dot num_threads [] tile_sizes [64, 64](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %for_op_2, %splitted_op_3 = transform.disc.split_reduction_serial %tiled_op_1 by tile_sizes = [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_linalg_op, %loops:3 = transform.structured.tile %for_op_2[16, 8, 16] {interchange = [0, 1, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.disc.apply_licm %arg0 : !transform.any_op + transform.disc.apply_dce %arg0 : !transform.any_op + transform.disc.apply_cse %arg0 : !transform.any_op + %2 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = transform.disc.apply_patterns %2 {canonicalization} : (!transform.any_op) -> !transform.any_op + %4 = transform.structured.vectorize %3 {vectorize_padding} : (!transform.any_op) -> !transform.any_op + transform.disc.apply_dce %arg0 : !transform.any_op + transform.disc.apply_cse %arg0 : !transform.any_op + %5 = transform.disc.bufferize {target_gpu} %arg0 : (!transform.any_op) -> !transform.any_op + %6 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.erase_dealloc %6 : (!transform.any_op) -> () + %7 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.transfer_write_zero_to_scf %7 : (!transform.any_op) -> () + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %8 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.block, #gpu.block]} in %5 : (!transform.any_op) -> !transform.any_op + %9 = 
transform.disc.forall_to_gpu_ctas %8 : (!transform.any_op) -> !transform.any_op + %10 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.warp, #gpu.warp]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.forall_to_gpu_warps %10 : (!transform.any_op) -> () + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %11 = transform.structured.match ops{["linalg.generic"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.gmem_to_smem %11 : (!transform.any_op) -> () + %12 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.vector.vector_to_mma_conversion %12 : (!transform.any_op) -> () + transform.disc.apply_licm %5 : !transform.any_op + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %13 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.inline_and_convert_gpu_ids %13 : (!transform.any_op) -> () + transform.disc.apply_licm %5 : !transform.any_op + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op +} diff --git a/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc index 5951a488a47..2f6f613abf2 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc @@ -270,4 +270,22 @@ TEST(PackedMatmul, F32_768x3072_Using_Default_Schedule) { /*profiling*/ true)); } +TEST(Matmul, F16_256x256x128_Using_Default_Schedule) { + EnvSetting setting = {{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, + {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, + {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; + EnvSettingContext ctx(setting); + EXPECT_TRUE(feature_test_main( + /*mlir_file_path*/ c_ft_path + + "default_schedule_matmul_nn_s_256x256x128_f16.mlir", + /*backend_types*/ {BackendType::kCuda}, + /*num_inputs*/ 2, + /*num_outputs*/ 1, + /*input_descriptors*/ {"256x128xf16_X", "128x256xf16_X"}, + /*output_descriptors*/ {"f16_X"}, + /*input_vals*/ {}, + /*expected_output_vals*/ {}, + /*profiling*/ true)); +} + } // namespace mlir_test diff --git a/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc index 88f1841dd8c..b9298675be2 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/matmul.cc @@ -33,7 +33,7 @@ static bool init_threads = []() { TEST(SimpleTest, MatMulF32_11x13x12) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( @@ -48,7 +48,7 @@ TEST(SimpleTest, MatMulF32_11x13x12) { TEST(SimpleTest, MatMulF32_111x131x121) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}}; EnvSettingContext ctx(setting); EXPECT_TRUE(feature_test_main( @@ -63,7 +63,8 @@ TEST(SimpleTest, 
MatMulF32_111x131x121) { TEST(SimpleTest, MatMulF32_304x1024x256) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -83,7 +84,8 @@ TEST(SimpleTest, MatMulF32_304x1024x256) { TEST(SimpleTest, MatMulF32_1024x1024x1024) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}}, + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", + false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; @@ -103,7 +105,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024) { TEST(SimpleTest, MatMulF32_304x1024x256_2) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -124,7 +126,7 @@ TEST(SimpleTest, MatMulF32_304x1024x256_2) { TEST(SimpleTest, MatMulF32_1024x1024x1024_2) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -145,7 +147,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_2) { TEST(SimpleTest, MatMulF32_304x256x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -166,7 +168,7 @@ TEST(SimpleTest, MatMulF32_304x256x256_3) { TEST(SimpleTest, MatMulF32_304x512x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -187,7 +189,7 @@ TEST(SimpleTest, MatMulF32_304x512x256_3) { TEST(SimpleTest, MatMulF32_304x1024x256_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -208,7 +210,7 @@ TEST(SimpleTest, MatMulF32_304x1024x256_3) { TEST(SimpleTest, MatMulF32_304x1024x512_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, 
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -229,7 +231,7 @@ TEST(SimpleTest, MatMulF32_304x1024x512_3) { TEST(SimpleTest, MatMulF32_1024x1024x1024_3) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -250,7 +252,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_3) { TEST(SimpleTest, MatMulF32_304x1024x512_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -271,7 +273,7 @@ TEST(SimpleTest, MatMulF32_304x1024x512_4) { TEST(SimpleTest, MatMulF32_1024x1024x1024_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -292,7 +294,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_4) { TEST(SimpleTest, MatMulF32_1026x1024x1024_4) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -313,7 +315,7 @@ TEST(SimpleTest, MatMulF32_1026x1024x1024_4) { TEST(SimpleTest, MatMulF32_304x1024x512_5) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, @@ -331,4 +333,25 @@ TEST(SimpleTest, MatMulF32_304x1024x512_5) { /*profiling*/ true)); } +TEST(SimpleTest, MatMulF16_GPU_256x256x128) { + EnvSetting setting = { + {"DISC_TRANSFORM_SCHEDULE_FILE", + {"kGEMM::GPU:" + c_ft_path + "matmul_nn_s_f16_gpu_schedule.mlir", + false}}, + {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}}, + {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, + {"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}}; + EnvSettingContext ctx(setting); + EXPECT_TRUE(feature_test_main( + /*mlir_file_path*/ c_ft_path + "matmul_nn_s_256x256x128_f16.mlir", + /*backend_types*/ {BackendType::kCuda}, + /*num_inputs*/ 2, + /*num_outputs*/ 1, + /*input_descriptors*/ {"256x128xf16_X", "128x256xf16_X"}, + /*output_descriptors*/ {"f16_X"}, + /*input_vals*/ {}, + /*expected_output_vals*/ {}, + /*profiling*/ true)); +} + } // namespace mlir_test diff --git a/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc b/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc index 81c9a62b1a2..85880870aba 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/matmul_multithread.cc @@ -33,7 +33,7 @@ static bool init_threads = []() { TEST(SimpleMTTest, MatMulF32_111x131x121_Thread_8) { 
EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "matmul_multithread_nn_d_f32_schedule.mlir", + {"kGEMM::CPU:" + c_ft_path + "matmul_multithread_nn_d_f32_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}}; EnvSettingContext ctx(setting); @@ -49,7 +49,7 @@ TEST(SimpleMTTest, MatMulF32_111x131x121_Thread_8) { TEST(SimpleTest, MatMulF32_304x1024x256) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + + {"kGEMM::CPU:" + c_ft_path + "matmul_multithread_nn_d_f32_large_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}}; diff --git a/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc index 41a89657ef1..b95a29477ad 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/packed_matmul.cc @@ -33,7 +33,8 @@ static bool init_threads = []() { TEST(PackedMatmul, F32_304x1024x512) { EnvSetting setting = { {"DISC_TRANSFORM_SCHEDULE_FILE", - {"kGEMM::" + c_ft_path + "packed_matmul_nn_p_f32_large_schedule.mlir", + {"kGEMM::CPU:" + c_ft_path + + "packed_matmul_nn_p_f32_large_schedule.mlir", false}}, {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}, {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}, diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc index 9d6b671ba58..2f86b5c8be4 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc @@ -1383,7 +1383,7 @@ SmallVector getDimValues(OpBuilder& b, Location loc, Value v) { SmallVector dims; for (auto en : llvm::enumerate(ty.getShape())) { if (ty.isDynamicDim(en.index())) { - dims.push_back(b.create(loc, v, en.index())); + dims.push_back(b.createOrFold(loc, v, en.index())); } else { dims.push_back(b.create(loc, en.value())); } @@ -3486,9 +3486,8 @@ DiagnosedSilenceableFailure DISCForallToGPUWarpsOp::applyToOne( SmallVector gpuMapping = llvm::to_vector(forallOp.getMapping()->getValue()); if (!llvm::all_of(gpuMapping, [](Attribute map) { - return map.isa(); + return map.isa(); })) { - // TODO: Use thread mapping to indicate warp mapping currently. To use warp // attr after rebase. 
return mlir::emitDefiniteFailure(target, "gpu warp mapping must be present"); @@ -3560,9 +3559,9 @@ DiagnosedSilenceableFailure DISCSplitReductionSerialOp::applyToOne( Value lhs = matmulOp.getDpsInputOperand(0)->get(); Value rhs = matmulOp.getDpsInputOperand(1)->get(); Value output = matmulOp.getOutputs()[0]; - Value dimM = b.create(loc, lhs, zero); - Value dimN = b.create(loc, rhs, one); - Value dimK = b.create(loc, lhs, one); + Value dimM = b.createOrFold(loc, lhs, zero); + Value dimN = b.createOrFold(loc, rhs, one); + Value dimK = b.createOrFold(loc, lhs, one); scf::ForOp forOp = b.create(loc, zero, dimK, step, ValueRange{output}); @@ -3571,13 +3570,23 @@ DiagnosedSilenceableFailure DISCSplitReductionSerialOp::applyToOne( SmallVector lhsOffsets{zero, iv}; SmallVector lhsDimUppers{dimM, step}; SmallVector lhsStrides{one, one}; - Value lhsSlice = b.create(loc, lhs, lhsOffsets, - lhsDimUppers, lhsStrides); + auto toOpFoldResult = [](Value v) -> OpFoldResult { + auto op = v.getDefiningOp(); + if (!op) return v; + return op.getValue(); + }; + Value lhsSlice = b.createOrFold( + loc, lhs, llvm::to_vector(llvm::map_range(lhsOffsets, toOpFoldResult)), + llvm::to_vector(llvm::map_range(lhsDimUppers, toOpFoldResult)), + llvm::to_vector(llvm::map_range(lhsStrides, toOpFoldResult))); + SmallVector rhsOffsets{iv, zero}; SmallVector rhsDimUppers{step, dimN}; SmallVector rhsStrides{one, one}; - Value rhsSlice = b.create(loc, rhs, rhsOffsets, - rhsDimUppers, rhsStrides); + Value rhsSlice = b.createOrFold( + loc, rhs, llvm::to_vector(llvm::map_range(rhsOffsets, toOpFoldResult)), + llvm::to_vector(llvm::map_range(rhsDimUppers, toOpFoldResult)), + llvm::to_vector(llvm::map_range(rhsStrides, toOpFoldResult))); ShapedType resultType = output.getType().cast(); Value iterArg = forOp.getRegionIterArg(0); linalg::MatmulOp res = b.create( diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc b/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc index 331e9c28a85..d03e9d4bd7e 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc +++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc @@ -112,10 +112,6 @@ LogicalResult DiscRewritePayloadIRForRALPass::assignPlacementForFuncOp( } LogicalResult DiscRewritePayloadIRForRALPass::assignPlacement() { - if (gpuEnabled_) - return getOperation()->emitError() - << "not support assign placement info for gpu a.t.m.\n"; - for (FuncOp funcOp : llvm::to_vector<4>(getOperation().getOps())) { if (failed(assignPlacementForFuncOp(funcOp))) return failure(); diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/forall-to-gpu-warps.mlir b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/forall-to-gpu-warps.mlir index 13227d6788d..d0337b887a2 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/forall-to-gpu-warps.mlir +++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/forall-to-gpu-warps.mlir @@ -26,7 +26,7 @@ func.func @forall_to_gpu_warps(%arg0: memref<2x2xf16>) { scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c4, %c128) step (%c1, %c1) { scf.forall (%arg3, %arg4) in (%c2, %c2) { memref.store %cst, %arg0[%arg3, %arg4] : memref<2x2xf16> - } {mapping = [#gpu.thread, #gpu.thread]} + } {mapping = [#gpu.warp, #gpu.warp]} } {mapping = "cta-thread-mapping"} return } diff --git a/tao_compiler/mlir/disc/transforms/codegen_utils.cc 
b/tao_compiler/mlir/disc/transforms/codegen_utils.cc index dc174289ac6..22eb07ddc69 100755 --- a/tao_compiler/mlir/disc/transforms/codegen_utils.cc +++ b/tao_compiler/mlir/disc/transforms/codegen_utils.cc @@ -62,15 +62,15 @@ int getVectorizeOrTileHint(Operation* op) { return attr.getInt(); } -int getThreadPerBlock(Operation* op) { - int thread_per_block = kThreadsRowReduction; +int getCTASize(Operation* op) { + int thread_per_block = kCTASizeDefault; if (!op) return thread_per_block; lmhlo::FusionOp fusion = dyn_cast(op); if (!fusion) { fusion = op->getParentOfType(); } if (!fusion) return thread_per_block; - IntegerAttr attr = fusion->getAttrOfType(kThreadPerBlockHint); + IntegerAttr attr = fusion->getAttrOfType(kCTASizeHint); if (!attr) return thread_per_block; return attr.getInt(); } diff --git a/tao_compiler/mlir/disc/transforms/codegen_utils.h b/tao_compiler/mlir/disc/transforms/codegen_utils.h index 2ebad598e8b..e5c603f11b4 100755 --- a/tao_compiler/mlir/disc/transforms/codegen_utils.h +++ b/tao_compiler/mlir/disc/transforms/codegen_utils.h @@ -88,18 +88,17 @@ using DiscColReductionScheduleType = enum : int { }; // number of therads per block when doing codegen on GPU. -constexpr const char* kThreadPerBlockHint = "disc_thread_per_block_hint"; -// constexpr const char* kThreadPerBlockHint512 = 512; +constexpr const char* kCTASizeHint = "disc_cta_size_hint"; // empirical column size used to choose different row reduce schedule. constexpr const int kRowReductionScheduleTurningSize = 512; // default num of threads per block used when doing codegen #if TENSORFLOW_USE_ROCM -constexpr const int kThreadsRowReduction = 512; +constexpr const int kCTASizeDefault = 512; #else -constexpr const int kThreadsRowReduction = 256; -constexpr const int kThreadsRowReduction512 = 512; +constexpr const int kCTASizeDefault = 256; +constexpr const int kCTASize512 = 512; #endif constexpr const int kVectorizeOrTileSize = 2; @@ -119,7 +118,7 @@ int getRowReductionScheduleHint(Operation* op); int getVectorizeOrTileHint(Operation* op); -int getThreadPerBlock(Operation* op); +int getCTASize(Operation* op); int getColReductionScheduleHint(Operation* op); diff --git a/tao_compiler/mlir/disc/transforms/disc_erase_buffer_deallocation.cc b/tao_compiler/mlir/disc/transforms/disc_erase_buffer_deallocation.cc new file mode 100644 index 00000000000..43848b50d5e --- /dev/null +++ b/tao_compiler/mlir/disc/transforms/disc_erase_buffer_deallocation.cc @@ -0,0 +1,54 @@ +/* Copyright 2023 The BladeDISC Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/disc/IR/lhlo_disc_ops.h" +#include "mlir/disc/disc_util.h" +#include "mlir/disc/transforms/PassDetail.h" + +// This file implements logic to erase dealloc op for GPU func op. + +namespace mlir { +namespace disc_ral { + +namespace { + +struct DiscEraseBufferDeallocationPass + : public DiscEraseBufferDeallocationPassBase< + DiscEraseBufferDeallocationPass> { + void runOnOperation() override { + auto funcOp = cast(getOperation()); + + SmallVector deallocOps; + funcOp.walk([&](memref::DeallocOp op) { deallocOps.push_back(op); }); + + for (auto op : deallocOps) { + op->erase(); + } + } +}; + +} // namespace + +std::unique_ptr> +createDiscEraseBufferDeallocationPass() { + return std::make_unique(); +} + +} // namespace disc_ral +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc b/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc index e09df02aa5a..9bf3aa597b2 100644 --- a/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc @@ -32,6 +32,8 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -39,6 +41,8 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/IR/IRMapping.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/DialectConversion.h" @@ -123,12 +127,15 @@ struct DiscLowerGpuOpsToNVVMOpsPass llvmPatterns.add(converter); arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); + populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); + populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns); // Put the math conversioin after GpuToNVVM conversions as some math ops // are intended to be converted to nvvm intrinsics. 
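    // The added Vector/NVGPU pattern sets and the
    // memref::populateExpandStridedMetadataPatterns call below appear to be
    // needed for the new GPU GEMM transform schedule: it emits nvgpu/vector
    // MMA ops (via transform.disc.vector.vector_to_mma_conversion) and strided
    // memref views of promoted buffers, which must be rewritten into
    // extract_strided_metadata/reinterpret_cast form before the
    // MemRef-to-LLVM patterns run.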
populateMathToLLVMConversionPatterns(converter, llvmPatterns); + memref::populateExpandStridedMetadataPatterns(llvmPatterns); populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); if (this->hasRedux) populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); diff --git a/tao_compiler/mlir/disc/transforms/disc_passes.td b/tao_compiler/mlir/disc/transforms/disc_passes.td index a9623c00a0e..d8889d1e52e 100644 --- a/tao_compiler/mlir/disc/transforms/disc_passes.td +++ b/tao_compiler/mlir/disc/transforms/disc_passes.td @@ -632,6 +632,10 @@ def DiscTransformLegalizeToLoopPass : Pass<"disc-transform-legalize-to-loop", "m /*default=*/"false", "whether gpu is available.">, Option<"transformFileName_", "transform-file-name", "std::string", /*default=*/"\"\"", "Filename of the transform schedule.">, + Option<"cc_major_", "gpu-sm-cc-major", "int", + /*default=*/"8", "gpu sm cc_major.">, + Option<"cc_minor_", "gpu-sm-cc-minor", "int", + /*default=*/"0", "gpu sm cc_minor.">, Option<"enableExpensiveChecks_", "enable-expensive-checks", "bool", /*default=*/"false", "perform expensive checks to better report errors in the transform IR.">, ]; @@ -649,3 +653,8 @@ def DiscDuplicateComputationAfterFusionPass : Pass<"disc-duplicate-computation-a "memref::MemRefDialect", ]; } + +def DiscEraseBufferDeallocationPass : Pass<"disc-erase-buffer-deallocation", "mlir::gpu::GPUFuncOp"> { + let summary = "Erase dealloc op for GPU func ops."; + let constructor = "createDiscEraseBufferDeallocationPass()"; +} diff --git a/tao_compiler/mlir/disc/transforms/disc_specialize_fusion_with_speculation.cc b/tao_compiler/mlir/disc/transforms/disc_specialize_fusion_with_speculation.cc index a2ca0d3bb0a..18fa6cb295b 100644 --- a/tao_compiler/mlir/disc/transforms/disc_specialize_fusion_with_speculation.cc +++ b/tao_compiler/mlir/disc/transforms/disc_specialize_fusion_with_speculation.cc @@ -389,7 +389,7 @@ struct DiscSpecializeFusionWithSpeculationPass Value operand = reduce_op->getOperand(0); // TODO(disc): Use 256 as default block size; turn this number for // different shapes - int block_size = kThreadsRowReduction; + int block_size = kCTASizeDefault; Value col_size = b.create(loc, operand, 1); Value pred; @@ -443,11 +443,11 @@ struct DiscSpecializeFusionWithSpeculationPass auto first_schedule = b.getIntegerAttr(b.getIntegerType(32), 1); auto second_schedule = b.getIntegerAttr(b.getIntegerType(32), 2); auto num_thread_attr = b.getIntegerAttr(b.getIntegerType(32), block_size); - fusion_op->setAttr(kThreadPerBlockHint, num_thread_attr); + fusion_op->setAttr(kCTASizeHint, num_thread_attr); fusion_op->setAttr(kRowReductionScheduleHint, first_schedule); // one block one row addFusionTag(b, fusion_op, "1b1r"); - cloned->setAttr(kThreadPerBlockHint, num_thread_attr); + cloned->setAttr(kCTASizeHint, num_thread_attr); cloned->setAttr(kRowReductionScheduleHint, second_schedule); // one warp one row addFusionTag(b, cloned, "1w1r"); @@ -495,7 +495,7 @@ struct DiscSpecializeFusionWithSpeculationPass Value col_size = b.create(loc, operand, 1); Value matrix_size = b.create(loc, row_size, col_size); - int thread_per_block = kThreadsRowReduction; + int thread_per_block = kCTASizeDefault; Value cur_threads = b.create(loc, thread_per_block); // b.create(loc, max_threads_per_block_); Value cur_blocks = @@ -517,15 +517,15 @@ struct DiscSpecializeFusionWithSpeculationPass b.getIntegerAttr(b.getIntegerType(32), DISC_BLOCK_TILE_H64); // block-size is 256 in the second schedule auto num_thread_full_attr256 = - 
b.getIntegerAttr(b.getIntegerType(32), kThreadsRowReduction); + b.getIntegerAttr(b.getIntegerType(32), kCTASizeDefault); // block-size is 512 in the first schedule auto num_thread_full_attr512 = - b.getIntegerAttr(b.getIntegerType(32), kThreadsRowReduction512); - fusion_op->setAttr(kThreadPerBlockHint, num_thread_full_attr512); + b.getIntegerAttr(b.getIntegerType(32), kCTASize512); + fusion_op->setAttr(kCTASizeHint, num_thread_full_attr512); fusion_op->setAttr(kColReductionScheduleHint, first_schedule); // use fisrt schedule if row_size < col_size addFusionTag(b, fusion_op, "thread_tile_h32"); - cloned->setAttr(kThreadPerBlockHint, num_thread_full_attr256); + cloned->setAttr(kCTASizeHint, num_thread_full_attr256); cloned->setAttr(kColReductionScheduleHint, second_schedule); // use second schedule if row_size >= col_size addFusionTag(b, cloned, "block_tile_h64"); @@ -600,7 +600,7 @@ struct DiscSpecializeFusionWithSpeculationPass if (fusion_type == FusionType::kRowReduction || fusion_type == FusionType::kStitch) { Operation* dominant_equivalent_op = GetCandidateRowReduceOp(fusion_op); - auto block_size = getThreadPerBlock(fusion_op.getOperation()); + auto block_size = getCTASize(fusion_op.getOperation()); int rowred_schedule = getRowReductionScheduleHint(fusion_op.getOperation()); diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc index 9beb7652b02..3f5e0508697 100644 --- a/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc +++ b/tao_compiler/mlir/disc/transforms/disc_transform_legalize_to_loop.cc @@ -98,11 +98,14 @@ struct DiscTransformLegalizeToLoopPass DiscTransformLegalizeToLoopPass> { explicit DiscTransformLegalizeToLoopPass(bool gpuEnabled, const std::string& transformFileName, + int cc_major, int cc_minor, bool enableExpensiveChecks) : DiscTransformLegalizeToLoopPassBase:: DiscTransformLegalizeToLoopPassBase() { this->gpuEnabled_ = gpuEnabled; this->transformFileName_ = transformFileName; + this->cc_major_ = cc_major; + this->cc_minor_ = cc_minor; this->enableExpensiveChecks_ = enableExpensiveChecks; } @@ -116,10 +119,14 @@ struct DiscTransformLegalizeToLoopPass ShapeAnalysis& shapeAnalysis, ScheduleDispatcher& scheduleDispatcher); + LogicalResult handleGpuFusionOp(OpBuilder& b, Operation* fusion, + ShapeAnalysis& shapeAnalysis, + ScheduleDispatcher& scheduleDispatcher); + // Inject schedule selection logic LogicalResult injectScheduleSelectionIR( OpBuilder& b, PatternDescription& pd, - SmallVectorImpl& clonedOps); + SmallVectorImpl& clonedOps, DeviceType deviceType); // Outlines the fusion op to a standalone module op. LogicalResult outlineFusionOp(lmhlo::FusionOp fusionOp, @@ -131,6 +138,11 @@ struct DiscTransformLegalizeToLoopPass LogicalResult inlineTransformedModule(OpBuilder& b, Operation* fusion, FusionPattern& fusionPattern, ModuleOp m); + + private: + // GPU compute capatility numbers. 
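+  // These mirror the CUDA SM compute capability (e.g. cc_major = 8,
+  // cc_minor = 0 for sm_80). Defaults come from the new pass options in
+  // disc_passes.td and are overridden from options.gpu_info when the
+  // pipeline is built in disc_compiler.cc.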
+ int cc_major_; + int cc_minor_; }; LogicalResult DiscTransformLegalizeToLoopPass::outlineFusionOp( @@ -235,7 +247,7 @@ LogicalResult DiscTransformLegalizeToLoopPass::inlineTransformedModule( LogicalResult DiscTransformLegalizeToLoopPass::injectScheduleSelectionIR( OpBuilder& b, PatternDescription& pd, - SmallVectorImpl& clonedOps) { + SmallVectorImpl& clonedOps, DeviceType deviceType) { auto fusionOp = pd.getFusionOp(); auto factories = ScheduleFactoryRegistry::get().getAllCandidateScheduleFactories(pd); @@ -262,8 +274,9 @@ LogicalResult DiscTransformLegalizeToLoopPass::injectScheduleSelectionIR( PatternDescription clonedPd(cloned, fusionPattern, pd.getShapeAnalysis()); Value pred; if (failed(factories[i]->buildGuardCondition(b, cloned->getLoc(), clonedPd, - pred))) + pred))) { return cloned->emitError() << "faield to build guard IR\n"; + } auto ifOp = b.create(cloned->getLoc(), TypeRange{}, pred, true); cloned->moveBefore(ifOp.thenBlock(), ifOp.thenBlock()->begin()); @@ -309,7 +322,93 @@ LogicalResult DiscTransformLegalizeToLoopPass::handleCpuFusionOp( // clone the fusion op, each for one candidate schedule. SmallVector clonedFusionOps; PatternDescription pd(fusionOp, fusionPattern, shapeAnalysis); - if (failed(injectScheduleSelectionIR(b, pd, clonedFusionOps))) { + if (failed(injectScheduleSelectionIR(b, pd, clonedFusionOps, + DeviceType::kCPU))) { + return fusionOp->emitError() << "failed to injectScheduleSelectionIR\n"; + } + LLVM_DEBUG(llvm::dbgs() << "After injectScheduleSelectionIR:\n" + << fusion->getParentOfType() << "\n"); + + for (auto fusion : clonedFusionOps) { + b.setInsertionPoint(fusion); + auto fusionOp = cast(fusion); + FusionPattern fusionPattern(fusionOp, &shapeAnalysis); + // 1, Outline the fusion to a standalone module op. + OwningOpRef m; + if (failed(outlineFusionOp(fusionOp, fusionPattern, m))) { + return fusionOp->emitError() << "failed to outlineFusionOp\n"; + } + LLVM_DEBUG(llvm::dbgs() << "After outline fusion op:\n" << m.get() << "\n"); + + // 2, assign a default schedule for each pattern here. + PatternDescription patternDescription(fusionOp, fusionPattern, + shapeAnalysis); + if (failed(scheduleDispatcher.dispatch(patternDescription, m.get()))) { + return fusionOp->emitError() << "failed to assignSchedule\n"; + } + LLVM_DEBUG(llvm::dbgs() << "After assign schedule for fusion op:\n" + << m.get() << "\n"); + + // 3, Build a nested pass pipeline to legalize the outlined fusion op. + if (failed(runTransformPipeline(m.get()))) { + return fusionOp->emitError() << "failed to run runTransformPipeline\n"; + } + LLVM_DEBUG(llvm::dbgs() << "After run transform pipeline:\n" + << m.get() << "\n"); + + // 4, Inline the lowered IR into the orignal module. + if (failed(inlineTransformedModule(b, fusion, fusionPattern, m.get()))) { + return fusion->emitError() << "failed to inlineTransformedModule\n"; + } + LLVM_DEBUG(llvm::dbgs() << "After inline transformed module:\n" + << *fusion << "\n"); + } + + return success(); +} + +LogicalResult DiscTransformLegalizeToLoopPass::handleGpuFusionOp( + OpBuilder& b, Operation* fusion, ShapeAnalysis& shapeAnalysis, + ScheduleDispatcher& scheduleDispatcher) { + b.setInsertionPoint(fusion); + auto fusionOp = cast(fusion); + assert(fusionOp); + FusionPattern fusionPattern(fusionOp, &shapeAnalysis); + if (!fusionPattern.isTransformBasedFusion()) { + // skip non-transform-based fusion pattern. + return success(); + } + + // GPU GEMM uses block size 128. 
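+  // This matches the GPU GEMM schedule exercised by the tests in this PR
+  // (matmul_nn_s_f16_gpu_schedule.mlir): a 128x128 CTA tile split into
+  // 64x64 warp tiles yields (128/64) * (128/64) = 4 warps, i.e.
+  // 4 * 32 = 128 threads per CTA, recorded here via kCTASizeHint.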
+ auto ctaSizeAttr = b.getIntegerAttr(b.getIntegerType(32), 128); + fusionOp->setAttr(kCTASizeHint, ctaSizeAttr); + + auto& bypassMap = bypassCodegenPatternNameMap(); + auto it = bypassMap.find(getFusionName(fusionOp).str()); + if (it != bypassMap.end()) { + OwningOpRef m; + if (failed(parseTransformModuleFromFile(b.getContext(), it->second, m))) { + llvm::dbgs() << "illegal bypass transform fusion pattern codegen " + "setting, unable to load module from: " + << it->second << "\n"; + return failure(); + } + // Inline the lowered IR into the orignal module. + if (failed(inlineTransformedModule(b, fusion, fusionPattern, m.get()))) { + return fusion->emitError() + << "failed to inline module load from bypass setting\n"; + } + return success(); + } + + // 0, inject schedule selection logic + // clone the fusion op, each for one candidate schedule. + SmallVector clonedFusionOps; + PatternDescription pd(fusionOp, fusionPattern, shapeAnalysis); + // TODO: The schedule selection IR will not rely on the guard function, but + // with a holistic switch structure. + if (failed(injectScheduleSelectionIR(b, pd, clonedFusionOps, + DeviceType::kGPU))) { return fusionOp->emitError() << "failed to injectScheduleSelectionIR\n"; } LLVM_DEBUG(llvm::dbgs() << "After injectScheduleSelectionIR:\n" @@ -374,6 +473,10 @@ void DiscTransformLegalizeToLoopPass::runOnOperation() { // Assign a transform schedule for the given fusion pattern. ScheduleDispatcher scheduleDispatcher{transformFileName_}; + DeviceInfo deviceInfo; + deviceInfo.cc_major = cc_major_; + deviceInfo.cc_minor = cc_minor_; + scheduleDispatcher.setDeviceInfo(deviceInfo); if (failed(scheduleDispatcher.parseModuleFromFile(b.getContext()))) { func->emitError() << "failed to parse transform module form " << transformFileName_ << " .\n"; @@ -381,8 +484,18 @@ void DiscTransformLegalizeToLoopPass::runOnOperation() { } for (Operation* fusion : gpu_fusion_worklist) { - // TODO(disc): handling stitch fusion on GPU. - return signalPassFailure(); + if (!useShapeConstraintIR()) { + // TODO: use FuncOp that contains `fusionOp` to construct + // shape-analysis, which will use global information for shape equality + // and decomposition analysis. + shapeAnalysisPtr.reset(new ShapeAnalysisDeprecated{fusion}); + } + + // Error message should be emitted inside the function. + if (failed(handleGpuFusionOp(b, fusion, *shapeAnalysisPtr, + scheduleDispatcher))) { + return signalPassFailure(); + } } for (Operation* fusion : cpu_fusion_worklist) { @@ -405,10 +518,10 @@ void DiscTransformLegalizeToLoopPass::runOnOperation() { std::unique_ptr> createDiscTransformLegalizeToLoopPass(bool gpuEnabled, - const std::string& filename, - bool expensiveCheck) { - return std::make_unique(gpuEnabled, filename, - expensiveCheck); + const std::string& filename, int cc_major, + int cc_minor, bool expensiveCheck) { + return std::make_unique( + gpuEnabled, filename, cc_major, cc_minor, expensiveCheck); } } // namespace disc_ral diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc index e0770ac3d39..95cad8156a1 100644 --- a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc +++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -74,6 +75,11 @@ std::unordered_map& getStringToPatternKindMap() { return stringToPatternKindMap; } +std::unordered_map& getStringToDeviceTypeMap() { + static std::unordered_map stringToDeviceTypeMap; + return stringToDeviceTypeMap; +} + bool PatternKindAndStringMapRegistrar = []() { auto& patternKindToStringMap = getPatternKindToStringMap(); auto& stringToPatternKindMap = getStringToPatternKindMap(); @@ -82,6 +88,9 @@ bool PatternKindAndStringMapRegistrar = []() { for (auto& pair : patternKindToStringMap) { stringToPatternKindMap[pair.second] = pair.first; } + auto& stringToDeviceTypeMap = getStringToDeviceTypeMap(); + stringToDeviceTypeMap.emplace("CPU", DeviceType::kCPU); + stringToDeviceTypeMap.emplace("GPU", DeviceType::kGPU); return true; }(); @@ -94,13 +103,14 @@ using transform::TileToForallOp; using transform::VectorizeOp; MatchOp buildMatchOp(OpBuilder& b, Location& loc, Value target, - ArrayRef ops, StringRef name = {}) { + ArrayRef ops, StringRef name = {}, + DictionaryAttr givenAttrs = nullptr) { ArrayAttr opNames; if (!ops.empty()) { opNames = b.getStrArrayAttr(ops); } - DictionaryAttr attrs; - if (!name.empty()) { + DictionaryAttr attrs = givenAttrs; + if (!name.empty() && givenAttrs == nullptr) { attrs = b.getDictionaryAttr( b.getNamedAttr(kDISCLinalgTransformName, b.getStringAttr(name))); } @@ -111,54 +121,69 @@ MatchOp buildMatchOp(OpBuilder& b, Location& loc, Value target, } TileToForallOp buildTileToForallOp(OpBuilder& b, Location& loc, Value target, - ArrayRef numThreads) { - return b.create(loc, target, numThreads, - transform::NumThreadsSpec(), ArrayAttr{}); + ArrayRef threads, + transform::NumThreadsSpec numThreadsSpec, + ArrayAttr mapping) { + return b.create(loc, target, threads, numThreadsSpec, + mapping); +} + +TileToForallOp buildTileToForallOp(OpBuilder& b, Location& loc, Value target, + ArrayRef tiles, + transform::TileSizesSpec tileSizesSpec, + ArrayAttr mapping) { + return b.create(loc, target, tiles, tileSizesSpec, mapping); } Value buildFuseIntoContainingOp(OpBuilder& b, Location& loc, Value target, Value anchor) { - auto pdlType = transform::AnyOpType::get(b.getContext()); - SmallVector resultTypes{pdlType, pdlType}; + auto transformOpType = transform::AnyOpType::get(b.getContext()); + SmallVector resultTypes{transformOpType, transformOpType}; return b.create(loc, resultTypes, target, anchor) .getFusedOp(); } FuseOp buildFuseOp(OpBuilder& b, Location& loc, Value target, ArrayRef tileSizes, ArrayRef interchange) { - auto pdlType = transform::AnyOpType::get(b.getContext()); + auto transformOpType = transform::AnyOpType::get(b.getContext()); SmallVector loopTypes; - for (int64_t tileSize : tileSizes) - if (tileSize) loopTypes.push_back(pdlType); - return b.create(loc, pdlType, loopTypes, target, + for (int64_t tileSize : tileSizes) { + if (tileSize) { + loopTypes.push_back(transformOpType); + } + } + return b.create(loc, transformOpType, loopTypes, target, b.getI64ArrayAttr(tileSizes), b.getI64ArrayAttr(interchange)); } TileOp buildTileOp(OpBuilder& b, Location& loc, Value target, ArrayRef tileSizes, ArrayRef interchange) { - auto pdlType = transform::AnyOpType::get(b.getContext()); + auto transformOpType = 
      transform::AnyOpType::get(b.getContext());
   SmallVector<Type> loopTypes;
-  for (int64_t tileSize : tileSizes)
-    if (tileSize) loopTypes.push_back(pdlType);
-  return b.create<transform::TileOp>(loc, pdlType, loopTypes, target, ValueRange{},
+  for (int64_t tileSize : tileSizes) {
+    if (tileSize) {
+      loopTypes.push_back(transformOpType);
+    }
+  }
+  return b.create<transform::TileOp>(loc, transformOpType, loopTypes, target, ValueRange{},
                                      tileSizes, interchange);
 }
 
 transform_dialect::ApplyPatternsOp buildRunCanonicalizer(OpBuilder& b,
                                                          Location& loc,
                                                          Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::ApplyPatternsOp>(loc, pdlType, target,
-                                                      true);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::ApplyPatternsOp>(loc, transformOpType,
+                                                      target, true);
 }
 
 transform::GetProducerOfOperand buildGetProducerOfOperand(OpBuilder& b,
                                                           Location& loc,
                                                           Value target,
                                                           int64_t operandIdx) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform::GetProducerOfOperand>(loc, pdlType, target,
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform::GetProducerOfOperand>(loc, transformOpType, target,
                                                    operandIdx);
 }
 
@@ -172,14 +197,14 @@ transform::PadOp buildPadOp(OpBuilder& b, Location& loc, Value target,
                             ArrayRef<int64_t> paddingDimensions,
                             int64_t numOperands,
                             ArrayRef<Type> paddingTypes = {}) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   // TODO(wyzero): support other types.
   SmallVector<Attribute> paddingAttrs(numOperands,
                                       b.getZeroAttr(b.getF32Type()));
   for (const auto& [idx, type] : llvm::enumerate(paddingTypes)) {
     paddingAttrs[idx] = b.getZeroAttr(type);
   }
-  return b.create<transform::PadOp>(loc, pdlType, target,
+  return b.create<transform::PadOp>(loc, transformOpType, target,
                                     b.getArrayAttr(paddingAttrs),
                                     b.getI64ArrayAttr(paddingDimensions),
                                     ArrayAttr{}, ArrayAttr{}, ArrayAttr{});
@@ -187,8 +212,9 @@ transform::PadOp buildPadOp(OpBuilder& b, Location& loc, Value target,
 transform::GetParentForOp buildGetParentForOp(OpBuilder& b, Location& loc,
                                               Value target, int64_t num_loops) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform::GetParentForOp>(loc, pdlType, target, num_loops);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform::GetParentForOp>(loc, transformOpType, target,
+                                             num_loops);
 }
 
 transform_dialect::CacheReadOp buildCacheRead(OpBuilder& b, Location& loc,
@@ -208,15 +234,17 @@ transform_dialect::LowerMultiLevelPackToLoopOp buildLowerMultiLevelPackToLoop(
 VectorizeOp buildVectorize(OpBuilder& b, Location& loc, Value target,
                            bool vectorizePad) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<VectorizeOp>(loc, pdlType, target, vectorizePad);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<VectorizeOp>(loc, transformOpType, target, vectorizePad);
 }
 
 transform_dialect::DISCBufferizeOp buildDISCBufferize(OpBuilder& b,
                                                       Location& loc,
-                                                      Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::DISCBufferizeOp>(loc, pdlType, target);
+                                                      Value target,
+                                                      bool targetGpu) {
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::DISCBufferizeOp>(loc, transformOpType,
+                                                      target, targetGpu);
 }
 
 vector::LowerVectorsOptions getDefaultLowerVectorsOptions() {
@@ -236,9 +264,9 @@ transform_dialect::DISCLowerVectorsOp buildLowerVectors(
     OpBuilder& b, Location& loc, Value target,
     const vector::LowerVectorsOptions& options =
         getDefaultLowerVectorsOptions()) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::DISCLowerVectorsOp>(loc, pdlType, target,
-                                                         options);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::DISCLowerVectorsOp>(loc, transformOpType,
+                                                         target, options);
 }
 
 SplitHandleOp buildSplitHandleOp(OpBuilder& b, Location& loc, Value target,
@@ -249,91 +277,93 @@ SplitHandleOp buildSplitHandleOp(OpBuilder& b, Location& loc, Value target,
 transform_dialect::InlineReductionInitializerOp
 buildInlineReductionInitializerOp(OpBuilder& b, Location& loc, Value initOp,
                                   Value loopOp, Value readerOp) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   return b.create<transform_dialect::InlineReductionInitializerOp>(
-      loc, pdlType, initOp, loopOp, readerOp);
+      loc, transformOpType, initOp, loopOp, readerOp);
 }
 
 transform_dialect::DecomposeVectorsOp buildDecomposeVectors(
     OpBuilder& b, Location& loc, Value target, int64_t vectorSize) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::DecomposeVectorsOp>(loc, pdlType, target,
-                                                         vectorSize);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::DecomposeVectorsOp>(loc, transformOpType,
+                                                         target, vectorSize);
 }
 
 transform_dialect::LinalgFuseProducersOp buildLinalgFuseProducersOp(
     OpBuilder& b, Location& loc, Value target, ValueRange producers) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::LinalgFuseProducersOp>(loc, pdlType,
-                                                            target, producers);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::LinalgFuseProducersOp>(
+      loc, transformOpType, target, producers);
 }
 
 transform_dialect::ReplaceConstPaddingValueOp buildReplaceConstPaddingValueOp(
     OpBuilder& b, Location& loc, Value target, StringRef mode) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::ReplaceConstPaddingValueOp>(loc, pdlType,
-                                                                 target, mode);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::ReplaceConstPaddingValueOp>(
+      loc, transformOpType, target, mode);
 }
 
 transform_dialect::ConvertPaddingPlaceholderToConstOp
 buildConvertPaddingPlaceholderToConstOp(OpBuilder& b, Location& loc,
                                         Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   return b.create<transform_dialect::ConvertPaddingPlaceholderToConstOp>(
-      loc, pdlType, target);
+      loc, transformOpType, target);
 }
 
 transform_dialect::LinalgEagerlyBackwardInitTensorOp
 buildLinalgEagerlyBackwardInitTensorOp(OpBuilder& b, Location& loc,
                                        Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   return b.create<transform_dialect::LinalgEagerlyBackwardInitTensorOp>(
-      loc, pdlType, target);
+      loc, transformOpType, target);
 }
 
 transform_dialect::DISCFuseIntoContainingOp buildDISCFuseIntoContainingOp(
     OpBuilder& b, Location& loc, Value target, Value anchor) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::DISCFuseIntoContainingOp>(loc, pdlType,
-                                                               target, anchor);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::DISCFuseIntoContainingOp>(
+      loc, transformOpType, target, anchor);
 }
 
 transform_dialect::ReductionOutputFuseOp buildReductionOutputFuseOp(
     OpBuilder& b, Location& loc, Value target, Value loop) {
-  SmallVector<Type> pdlTypes(2, transform::AnyOpType::get(b.getContext()));
-  return b.create<transform_dialect::ReductionOutputFuseOp>(loc, pdlTypes,
-                                                            target, loop);
+  SmallVector<Type> transformOpTypes(2,
+                                     transform::AnyOpType::get(b.getContext()));
+  return b.create<transform_dialect::ReductionOutputFuseOp>(
+      loc, transformOpTypes, target, loop);
 }
 
 transform_dialect::ReductionInputFuseOp buildReductionInputFuseOp(OpBuilder& b,
                                                                   Location& loc,
                                                                   Value target,
                                                                   Value loop) {
-  SmallVector<Type> pdlTypes(2, transform::AnyOpType::get(b.getContext()));
-  return b.create<transform_dialect::ReductionInputFuseOp>(loc, pdlTypes,
-                                                           target, loop);
+  SmallVector<Type> transformOpTypes(2,
+                                     transform::AnyOpType::get(b.getContext()));
+  return b.create<transform_dialect::ReductionInputFuseOp>(
+      loc, transformOpTypes, target, loop);
 }
 
 transform_dialect::VectorizeConditionalGenericOp
 buildVectorizeConditionalGenericOp(OpBuilder& b, Location& loc, Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   return b.create<transform_dialect::VectorizeConditionalGenericOp>(
-      loc, pdlType, target);
+      loc, transformOpType, target);
 }
 
 transform_dialect::SplitVectorTransferIntoFullAndPartialOp
 buildSplitVectorTransferIntoFullAndPartialOp(OpBuilder& b, Location& loc,
                                              Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
   return b.create<transform_dialect::SplitVectorTransferIntoFullAndPartialOp>(
-      loc, pdlType, target);
+      loc, transformOpType, target);
 }
 
 transform_dialect::LowerConditionalGenericOp buildLowerConditionalGenericOp(
     OpBuilder& b, Location& loc, Value target) {
-  auto pdlType = transform::AnyOpType::get(b.getContext());
-  return b.create<transform_dialect::LowerConditionalGenericOp>(loc, pdlType,
-                                                                target);
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::LowerConditionalGenericOp>(
+      loc, transformOpType, target);
 }
 
 transform_dialect::ApplyCommonSubexpressionEliminationOp buildCSEOp(
@@ -355,12 +385,70 @@ transform_dialect::ApplyLoopIndependentCodeMotionOp buildLICMOp(OpBuilder& b,
                                                                 target);
 }
 
+transform_dialect::DISCPromoteDotOperandsOp buildPromoteDotOperandsOp(
+    OpBuilder& b, Location& loc, Value target, ArrayRef<int64_t> indices) {
+  SmallVector<Type> transformOpTypes(3,
+                                     transform::AnyOpType::get(b.getContext()));
+  return b.create<transform_dialect::DISCPromoteDotOperandsOp>(
+      loc, transformOpTypes, target, indices);
+}
+
+transform_dialect::DISCSplitReductionSerialOp buildSplitReductionSerialOp(
+    OpBuilder& b, Location& loc, Value target, ArrayRef<int64_t> tileSizes) {
+  SmallVector<Type> transformOpTypes(2,
+                                     transform::AnyOpType::get(b.getContext()));
+  return b.create<transform_dialect::DISCSplitReductionSerialOp>(
+      loc, transformOpTypes, target, tileSizes);
+}
+
+transform_dialect::DISCVectorToMMAConversionOp buildVectorToMMAConversionOp(
+    OpBuilder& b, Location& loc, Value target) {
+  return b.create<transform_dialect::DISCVectorToMMAConversionOp>(loc, target);
+}
+
+transform_dialect::DISCForallToGPUCTAsOp buildForallToGPUCTAsOp(OpBuilder& b,
+                                                                Location& loc,
+                                                                Value target) {
+  auto transformOpType = transform::AnyOpType::get(b.getContext());
+  return b.create<transform_dialect::DISCForallToGPUCTAsOp>(
+      loc, transformOpType, target);
+}
+
+transform_dialect::DISCForallToGPUWarpsOp buildForallToGPUWarpsOp(
+    OpBuilder& b, Location& loc, Value target) {
+  return b.create<transform_dialect::DISCForallToGPUWarpsOp>(loc, target);
+}
+
+transform_dialect::DISCLowerGmemToSmemOp buildLowerGmemToSmemOp(OpBuilder& b,
+                                                                Location& loc,
+                                                                Value target) {
+  return b.create<transform_dialect::DISCLowerGmemToSmemOp>(loc, target);
+}
+
+transform_dialect::DISCTransferWriteZeroToSCFOp buildTransferWriteZeroToSCFOp(
+    OpBuilder& b, Location& loc, Value target) {
+  return b.create<transform_dialect::DISCTransferWriteZeroToSCFOp>(loc, target);
+}
+
+transform_dialect::DISCEraseDeallocOp buildEraseDeallocOp(OpBuilder& b,
+                                                          Location& loc,
+                                                          Value target) {
+  return b.create<transform_dialect::DISCEraseDeallocOp>(loc, target);
+}
+
+transform_dialect::DISCInlineAndConvertGPUIdsOp buildInlineAndConvertGPUIdsOp(
+    OpBuilder& b, Location& loc, Value target) {
+  return b.create<transform_dialect::DISCInlineAndConvertGPUIdsOp>(loc, target);
+}
+
 class ParsedFromFileScheduleFactory : public ScheduleFactoryWithNoGuard {
  public:
   explicit ParsedFromFileScheduleFactory(int64_t id, PatternKind kind,
                                          ArrayRef<StringRef> tags,
+                                         DeviceType deviceType,
                                          ModuleOp transformModule);
-  LogicalResult assignSchedule(PatternDescription&, ModuleOp) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
 
  private:
   ModuleOp transformModule_;
@@ -368,12 +456,12 @@ class ParsedFromFileScheduleFactory : public ScheduleFactoryWithNoGuard {
 
 ParsedFromFileScheduleFactory::ParsedFromFileScheduleFactory(
     int64_t id, PatternKind kind, ArrayRef<StringRef> tags,
-    ModuleOp transformModule)
-    : ScheduleFactoryWithNoGuard(id, kind, tags),
+    DeviceType deviceType, ModuleOp transformModule)
+    : ScheduleFactoryWithNoGuard(id, kind, tags, deviceType),
       transformModule_(transformModule) {}
 
 LogicalResult ParsedFromFileScheduleFactory::assignSchedule(
-    PatternDescription& pd, ModuleOp m) {
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
   OpBuilder b(m);
   for (auto& op : transformModule_.getBody()->getOperations()) {
     if (!isa(&op)) continue;
@@ -393,16 +481,22 @@ class Aarch64GEMMDefaultScheduleFactory : public ScheduleFactoryWithNoGuard {
  public:
   using ScheduleFactoryWithNoGuard::ScheduleFactoryWithNoGuard;
   bool checkFusionPatternProperties(PatternDescription&) override;
-  LogicalResult assignSchedule(PatternDescription&, ModuleOp) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
 };
 
 // TODO(wyzero): merge default schedule and default with epilogue schedule.
 bool Aarch64GEMMDefaultScheduleFactory::checkFusionPatternProperties(
     PatternDescription& pd) {
+  if (!ScheduleFactory::checkFusionPatternProperties(pd)) {
+    return false;
+  }
   auto& fusionPattern = pd.getFusionPattern();
   auto& rootOps = fusionPattern.getRootOps();
   // Only support single output a.t.m.
-  if (rootOps.size() != 1) return false;
+  if (rootOps.size() != 1) {
+    return false;
+  }
 
   // This schedule not support epilogue fusion
   auto dominantOp = fusionPattern.getDominantOp();
@@ -410,15 +504,15 @@ bool Aarch64GEMMDefaultScheduleFactory::checkFusionPatternProperties(
 }
 
 LogicalResult Aarch64GEMMDefaultScheduleFactory::assignSchedule(
-    PatternDescription& pd, ModuleOp m) {
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
   OpBuilder b(m);
   b.setInsertionPointToStart(&m.getBodyRegion().front());
   Location loc = m.getLoc();
   MLIRContext* ctx = m->getContext();
-  auto pdlOpType = transform::AnyOpType::get(ctx);
+  auto transformOpType = transform::AnyOpType::get(ctx);
   auto seqOp = b.create<transform::SequenceOp>(
-      loc, TypeRange{}, transform::FailurePropagationMode::Propagate, pdlOpType,
-      [&](OpBuilder& b, Location loc, Value variantH) {});
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate,
+      transformOpType, [&](OpBuilder& b, Location loc, Value variantH) {});
   auto& bodyBlock = seqOp.getBody().front();
   b.setInsertionPointToStart(&bodyBlock);
   Value variant = bodyBlock.getArgument(0);
@@ -460,7 +554,8 @@ LogicalResult Aarch64GEMMDefaultScheduleFactory::assignSchedule(
   Value matmul = matmulSplitOp->getResult(1);
 
   // transform.structured.tile_to_forall_op %matmul num_threads [1, 1]
-  auto forallOp = buildTileToForallOp(b, loc, matmul, {1, 1});
+  auto forallOp = buildTileToForallOp(b, loc, matmul, {1, 1},
+                                      transform::NumThreadsSpec(), ArrayAttr{});
   Value forallLoop = forallOp->getResult(0);
   Value tiledMatmul = forallOp->getResult(1);
@@ -589,7 +684,7 @@ LogicalResult Aarch64GEMMDefaultScheduleFactory::assignSchedule(
   buildVectorize(b, loc, func, true);
 
   variant = buildRunCanonicalizer(b, loc, variant);
-  variant = buildDISCBufferize(b, loc, variant);
+  variant = buildDISCBufferize(b, loc, variant, false);
   variant = buildLowerVectors(b, loc, variant);
   b.create<transform::YieldOp>(loc);
@@ -601,11 +696,15 @@ class Aarch64GEMMDefaultScheduleWithEpilogueFactory
  public:
   using ScheduleFactoryWithNoGuard::ScheduleFactoryWithNoGuard;
   bool checkFusionPatternProperties(PatternDescription&) override;
-  LogicalResult assignSchedule(PatternDescription&, ModuleOp) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
 };
 
 bool Aarch64GEMMDefaultScheduleWithEpilogueFactory::
     checkFusionPatternProperties(PatternDescription& pd) {
+  if (!ScheduleFactory::checkFusionPatternProperties(pd)) {
+    return false;
+  }
   auto& fusionPattern = pd.getFusionPattern();
   auto& rootOps = fusionPattern.getRootOps();
   // Only support single output a.t.m.
@@ -617,15 +716,15 @@ bool Aarch64GEMMDefaultScheduleWithEpilogueFactory::
 }
 
 LogicalResult Aarch64GEMMDefaultScheduleWithEpilogueFactory::assignSchedule(
-    PatternDescription& pd, ModuleOp m) {
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
   OpBuilder b(m);
   b.setInsertionPointToStart(&m.getBodyRegion().front());
   Location loc = m.getLoc();
   MLIRContext* ctx = m->getContext();
-  auto pdlOpType = transform::AnyOpType::get(ctx);
+  auto transformOpType = transform::AnyOpType::get(ctx);
   auto seqOp = b.create<transform::SequenceOp>(
-      loc, TypeRange{}, transform::FailurePropagationMode::Propagate, pdlOpType,
-      [&](OpBuilder& b, Location loc, Value variantH) {});
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate,
+      transformOpType, [&](OpBuilder& b, Location loc, Value variantH) {});
   auto& bodyBlock = seqOp.getBody().front();
   b.setInsertionPointToStart(&bodyBlock);
   Value variant = bodyBlock.getArgument(0);
@@ -679,7 +778,8 @@ LogicalResult Aarch64GEMMDefaultScheduleWithEpilogueFactory::assignSchedule(
     rootHandle = buildMatchOp(b, loc, variant, {}, nameMap[rootOp]);
   }
 
-  auto forallOp = buildTileToForallOp(b, loc, rootHandle, {1, 1});
+  auto forallOp = buildTileToForallOp(b, loc, rootHandle, {1, 1},
+                                      transform::NumThreadsSpec(), ArrayAttr{});
   Value forallLoop = forallOp->getResult(0);
   rootHandle = forallOp->getResult(1);
@@ -718,7 +818,7 @@ LogicalResult Aarch64GEMMDefaultScheduleWithEpilogueFactory::assignSchedule(
     // TODO(wyzero): finetune the schedule for small m or n
     buildTileOp(b, loc, matmul, {1, 1, 1}, {0, 2, 1});
     variant = buildRunCanonicalizer(b, loc, variant);
-    variant = buildDISCBufferize(b, loc, variant);
+    variant = buildDISCBufferize(b, loc, variant, false);
     b.create<transform::YieldOp>(loc);
     return success();
   }
@@ -840,7 +940,7 @@ LogicalResult Aarch64GEMMDefaultScheduleWithEpilogueFactory::assignSchedule(
   auto placeholderOps = buildMatchOp(
       b, loc, variant, {"disc_linalg_ext.padding_value_placeholder"}, {});
   buildConvertPaddingPlaceholderToConstOp(b, loc, placeholderOps);
-  variant = buildDISCBufferize(b, loc, variant);
+  variant = buildDISCBufferize(b, loc, variant, false);
   variant = buildLowerVectors(b, loc, variant);
 
   // de-compose large size vector operations
@@ -853,13 +953,17 @@ class Aarch64GEMMLargeKScheduleFactory : public ScheduleFactory {
  public:
   using ScheduleFactory::ScheduleFactory;
   bool checkFusionPatternProperties(PatternDescription&) override;
-  LogicalResult assignSchedule(PatternDescription&, ModuleOp) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
   LogicalResult buildGuardCondition(OpBuilder& b, Location loc,
                                     PatternDescription&, Value&) override;
 };
 
 bool Aarch64GEMMLargeKScheduleFactory::checkFusionPatternProperties(
     PatternDescription& pd) {
+  if (!ScheduleFactory::checkFusionPatternProperties(pd)) {
+    return false;
+  }
   auto& fusionPattern = pd.getFusionPattern();
   auto& rootOps = fusionPattern.getRootOps();
   // Only support single output a.t.m.
@@ -898,15 +1002,15 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::buildGuardCondition(
 }
 
 LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
-    PatternDescription& pd, ModuleOp m) {
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
   OpBuilder b(m);
   b.setInsertionPointToStart(&m.getBodyRegion().front());
   Location loc = m.getLoc();
   MLIRContext* ctx = m->getContext();
-  auto pdlOpType = transform::AnyOpType::get(ctx);
+  auto transformOpType = transform::AnyOpType::get(ctx);
   auto seqOp = b.create<transform::SequenceOp>(
-      loc, TypeRange{}, transform::FailurePropagationMode::Propagate, pdlOpType,
-      [&](OpBuilder& b, Location loc, Value variantH) {});
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate,
+      transformOpType, [&](OpBuilder& b, Location loc, Value variantH) {});
   auto& bodyBlock = seqOp.getBody().front();
   b.setInsertionPointToStart(&bodyBlock);
   Value variant = bodyBlock.getArgument(0);
@@ -948,7 +1052,8 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   Value matmul = matmulSplitOp->getResult(1);
 
   // transform.structured.tile_to_forall_op %matmul num_threads [1, 1]
-  auto forallOp = buildTileToForallOp(b, loc, matmul, {1, 1});
+  auto forallOp = buildTileToForallOp(b, loc, matmul, {1, 1},
+                                      transform::NumThreadsSpec(), ArrayAttr{});
   Value forallLoop = forallOp->getResult(0);
   Value tiledMatmul = forallOp->getResult(1);
@@ -1095,7 +1200,7 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   buildVectorize(b, loc, func, true);
 
   variant = buildRunCanonicalizer(b, loc, variant);
-  variant = buildDISCBufferize(b, loc, variant);
+  variant = buildDISCBufferize(b, loc, variant, false);
 
   if (!k0Skipped && !k1Skipped) {
     Value leftFillOp = buildMatchOp(b, loc, variant, {}, nameMap[dotOp]);
@@ -1121,11 +1226,15 @@ class Aarch64GEMMLargeKScheduleWithEpilogueFactory
  public:
   using Aarch64GEMMLargeKScheduleFactory::Aarch64GEMMLargeKScheduleFactory;
   bool checkFusionPatternProperties(PatternDescription&) override;
-  LogicalResult assignSchedule(PatternDescription&, ModuleOp) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
 };
 
 bool Aarch64GEMMLargeKScheduleWithEpilogueFactory::checkFusionPatternProperties(
     PatternDescription& pd) {
+  if (!ScheduleFactory::checkFusionPatternProperties(pd)) {
+    return false;
+  }
   auto& fusionPattern = pd.getFusionPattern();
   auto& rootOps = fusionPattern.getRootOps();
   // Only support single output a.t.m.
@@ -1136,15 +1245,15 @@ bool Aarch64GEMMLargeKScheduleWithEpilogueFactory::checkFusionPatternProperties(
 }
 
 LogicalResult Aarch64GEMMLargeKScheduleWithEpilogueFactory::assignSchedule(
-    PatternDescription& pd, ModuleOp m) {
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
   OpBuilder b(m);
   b.setInsertionPointToStart(&m.getBodyRegion().front());
   Location loc = m.getLoc();
   MLIRContext* ctx = m->getContext();
-  auto pdlOpType = transform::AnyOpType::get(ctx);
+  auto transformOpType = transform::AnyOpType::get(ctx);
   auto seqOp = b.create<transform::SequenceOp>(
-      loc, TypeRange{}, transform::FailurePropagationMode::Propagate, pdlOpType,
-      [&](OpBuilder& b, Location loc, Value variantH) {});
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate,
+      transformOpType, [&](OpBuilder& b, Location loc, Value variantH) {});
   auto& bodyBlock = seqOp.getBody().front();
   b.setInsertionPointToStart(&bodyBlock);
   Value variant = bodyBlock.getArgument(0);
@@ -1199,7 +1308,8 @@ LogicalResult Aarch64GEMMLargeKScheduleWithEpilogueFactory::assignSchedule(
   }
   rootHandle = buildLinalgEagerlyBackwardInitTensorOp(b, loc, rootHandle);
 
-  auto forallOp = buildTileToForallOp(b, loc, rootHandle, {1, 1});
+  auto forallOp = buildTileToForallOp(b, loc, rootHandle, {1, 1},
+                                      transform::NumThreadsSpec(), ArrayAttr{});
   Value forallLoop = forallOp->getResult(0);
   rootHandle = forallOp->getResult(1);
@@ -1262,7 +1372,7 @@ LogicalResult Aarch64GEMMLargeKScheduleWithEpilogueFactory::assignSchedule(
   if (m1Skipped || n1Skipped || k1Skipped) {
     buildTileOp(b, loc, matmul, {1, 1, 1}, {0, 2, 1});
     variant = buildRunCanonicalizer(b, loc, variant);
-    variant = buildDISCBufferize(b, loc, variant);
+    variant = buildDISCBufferize(b, loc, variant, false);
     b.create<transform::YieldOp>(loc);
     return success();
   }
@@ -1381,7 +1491,7 @@ LogicalResult Aarch64GEMMLargeKScheduleWithEpilogueFactory::assignSchedule(
   buildVectorize(b, loc, func, true);
 
   variant = buildRunCanonicalizer(b, loc, variant);
-  variant = buildDISCBufferize(b, loc, variant);
+  variant = buildDISCBufferize(b, loc, variant, false);
   variant = buildRunCanonicalizer(b, loc, variant);
 
   Value conditionalOps =
@@ -1395,24 +1505,276 @@ LogicalResult Aarch64GEMMLargeKScheduleWithEpilogueFactory::assignSchedule(
   b.create<transform::YieldOp>(loc);
   return success();
 }
+#endif  // ENABLE_AARCH64_SCHEDUELS
+
+class CUDAMMAGEMMDefaultScheduleFactory : public ScheduleFactoryWithNoGuard {
+ public:
+  using ScheduleFactoryWithNoGuard::ScheduleFactoryWithNoGuard;
+  bool checkFusionPatternProperties(PatternDescription&) override;
+  LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                               DeviceInfo) override;
+};
+
+bool CUDAMMAGEMMDefaultScheduleFactory::checkFusionPatternProperties(
+    PatternDescription& pd) {
+  if (!ScheduleFactory::checkFusionPatternProperties(pd)) {
+    return false;
+  }
+  auto& fusionPattern = pd.getFusionPattern();
+  auto& rootOps = fusionPattern.getRootOps();
+  // Only support single output a.t.m.
+  if (rootOps.size() != 1) {
+    return false;
+  }
+
+  // This schedule does not support epilogue fusion.
+  auto dominantOp = fusionPattern.getDominantOp();
+  return (rootOps[0] == dominantOp) && isa<lmhlo::DotGeneralOp>(dominantOp);
+}
+
+// The schedule structure of TensorCore GEMM:
+//
+// parallel (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) {
+//   parallel (int cta_m = 0; cta_m < GemmM; cta_m += CtaTileM) {
+//     for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) {
+//       // All the following loops should be fully unrolled.
+//       for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) {
+//         for (int warp_m = 0; warp_m < CtaTileM; warp_m += WarpTileM) {
+//           for (int warp_k = 0; warp_k < CtaTileK; warp_k += WarpTileK) {
+//             for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) {
+//               for (int mma_n = 0; mma_n < WarpTileN; mma_n += MmaN) {
+//                 for (int mma_m = 0; mma_m < WarpTileM; mma_m += MmaM) {
+//                   vector.contract(...)  // lowered to MMA intrinsic.
+//                 }  // for mma_m
+//               }  // for mma_n
+//             }  // for mma_k
+//           }  // for warp_k
+//         }  // for warp_m
+//       }  // for warp_n
+//     }  // for cta_k
+//   }  // parallel cta_m
+// }  // parallel cta_n
+LogicalResult CUDAMMAGEMMDefaultScheduleFactory::assignSchedule(
+    PatternDescription& pd, ModuleOp m, DeviceInfo deviceInfo) {
+  OpBuilder b(m);
+  b.setInsertionPointToStart(&m.getBodyRegion().front());
+  Location loc = m.getLoc();
+  MLIRContext* ctx = m->getContext();
+  auto transformOpType = transform::AnyOpType::get(ctx);
+  auto seqOp = b.create<transform::SequenceOp>(
+      loc, TypeRange{}, transform::FailurePropagationMode::Propagate,
+      transformOpType, [&](OpBuilder& b, Location loc, Value variantH) {});
+  auto& bodyBlock = seqOp.getBody().front();
+  b.setInsertionPointToStart(&bodyBlock);
+  Value variant = bodyBlock.getArgument(0);
+
+  auto& fusionPattern = pd.getFusionPattern();
+  auto nameMap = TransformNameAssigner(fusionPattern.getOpList()).getNameMap();
+  auto dotOp =
+      dyn_cast_or_null<lmhlo::DotGeneralOp>(fusionPattern.getDominantOp());
+  if (!dotOp) {
+    return m->emitError() << "expect dot_general op as dominant\n";
+  }
+  Value lhs = dotOp->getOperand(0);
+  Value rhs = dotOp->getOperand(1);
+  auto lhsTy = lhs.getType().cast<MemRefType>();
+  auto rhsTy = rhs.getType().cast<MemRefType>();
+  if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2) {
+    return m->emitError() << "only support rank 2 GEMM a.t.m.\n";
+  }
+
+  auto dimNumbers = dotOp.getDotDimensionNumbers();
+  auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
+  auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
+  if (lhsContractingDims.size() != 1 || rhsContractingDims.size() != 1) {
+    return m->emitError() << "only support exactly 1 contract dim\n";
+  }
+  auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
+  if (!lhsBatchingDims.empty()) {
+    return m->emitError() << "do not support batch matmul\n";
+  }
+  bool lhsTranspose = (lhsContractingDims[0] == lhsTy.getRank() - 2);
+  bool rhsTranspose = (rhsContractingDims[0] == rhsTy.getRank() - 1);
+  if (lhsTranspose || rhsTranspose) {
+    return m->emitError() << "only support row-major matmul now\n";
+  }
+  int64_t M = lhsTy.getShape()[lhsTy.getRank() - 2];
+  int64_t K = lhsTy.getShape()[lhsTy.getRank() - 1];
+  int64_t N = rhsTy.getShape()[rhsTy.getRank() - 1];
+
+  // Build handle to target dot op.
+  Value fillAndMatmul = buildMatchOp(b, loc, variant, {}, nameMap[dotOp]);
+  auto matmulSplitOp = buildSplitHandleOp(b, loc, fillAndMatmul, 2);
+  Value fill = matmulSplitOp->getResult(0);
+  Value matmul = matmulSplitOp->getResult(1);
+
+  // ========================== Multi-level tiling ==========================
+
+  // Thread-block level tiling. Fixed tile size 128 x 128.
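+  // For reference, the concrete hierarchy built below is:
+  //   CTA (thread-block) tile : 128 x 128 x 32, mapped to gpu.block x/y;
+  //   warp tile               :  64 x  64 x 32, mapped to gpu.warp x/y;
+  //   MMA shape               : picked from the compute capability further
+  //                             down (e.g. 16x8x16 on sm_80 and newer).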
+  const SmallVector<int64_t> ctaTileSizes{128, 128, 32};
+  SmallVector<Attribute> blockTileMapping{
+      gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimX),
+      gpu::GPUBlockMappingAttr::get(ctx, gpu::Blocks::DimY)};
+  auto blockTileMappingAttr = b.getArrayAttr(blockTileMapping);
+
+  auto forallOpBlock =
+      buildTileToForallOp(b, loc, matmul, {ctaTileSizes[0], ctaTileSizes[1]},
+                          transform::TileSizesSpec(), blockTileMappingAttr);
+  Value forallLoopBlock = forallOpBlock->getResult(0);
+  Value tiledMatmulBlock = forallOpBlock->getResult(1);
+
+  // Fuse fill op into the forall loop.
+  auto fuseIntoContainingOp =
+      buildFuseIntoContainingOp(b, loc, fill, forallLoopBlock);
+
+  // TODO: padding on block tile.
+
+  // K iteration on block tile.
+  auto splitReductionSerialOpBlock =
+      buildSplitReductionSerialOp(b, loc, tiledMatmulBlock, {ctaTileSizes[2]});
+  auto splitMatmulBlock = splitReductionSerialOpBlock->getResult(0);
+
+  // Promote operands for shared memory buffering.
+  // TODO: promote operands for register buffering.
+  auto promoteDotOperandsOp =
+      buildPromoteDotOperandsOp(b, loc, splitMatmulBlock, {0, 1});
+  auto promotedMatmul = promoteDotOperandsOp->getResult(0);
+
+  // TODO: software pipelining on k iteration.
+
+  // Warp tile.
+  const SmallVector<int64_t> warpTileSizes{64, 64, 32};
+  SmallVector<Attribute> warpTileMapping{
+      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimX),
+      gpu::GPUWarpMappingAttr::get(ctx, gpu::Warps::DimY)};
+  auto warpTileMappingAttr = b.getArrayAttr(warpTileMapping);
+
+  auto forallOpWarp = buildTileToForallOp(
+      b, loc, promotedMatmul, {warpTileSizes[0], warpTileSizes[1]},
+      transform::TileSizesSpec(), warpTileMappingAttr);
+  Value forallLoopWarp = forallOpWarp->getResult(0);
+  Value tiledMatmulWarp = forallOpWarp->getResult(1);
+
+  // K iteration on warp tile.
+  auto splitReductionSerialOpWarp =
+      buildSplitReductionSerialOp(b, loc, tiledMatmulWarp, {warpTileSizes[2]});
+  auto splitMatmulWarp = splitReductionSerialOpWarp->getResult(0);
+
+  // Vector op tile.
+  SmallVector<int64_t> vectorTileSizes;
+  // The MMA instruction configuration for fp16.
+  if (deviceInfo.cc_major >= 8) {
+    vectorTileSizes = {16, 8, 16};
+  } else if (deviceInfo.cc_major == 7 && deviceInfo.cc_minor == 5) {
+    vectorTileSizes = {16, 8, 8};
+  } else if (deviceInfo.cc_major == 7 && deviceInfo.cc_minor == 0) {
+    vectorTileSizes = {8, 8, 4};
+  } else {
+    return m->emitError() << "unsupported GPU compute capability\n";
+  }
+  auto tileOpVector =
+      buildTileOp(b, loc, splitMatmulWarp, vectorTileSizes, {0, 1, 2});
+
+  buildLICMOp(b, loc, variant);
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  // TODO: fully unroll the vector tiled loops.
+
+  // ============================= Vectorization =============================
+
+  Value func4Vec = buildMatchOp(b, loc, variant, {"func.func"});
+  func4Vec = buildRunCanonicalizer(b, loc, func4Vec);
+  auto vectorizeOp = buildVectorize(b, loc, func4Vec, true);
+
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  // ============================= Bufferization =============================
+
+  variant = buildDISCBufferize(b, loc, variant, true);
+  Value funcAfterBufferize = buildMatchOp(b, loc, variant, {"func.func"});
+  buildEraseDeallocOp(b, loc, funcAfterBufferize);
+  Value func2ConvertTransfer = buildMatchOp(b, loc, variant, {"func.func"});
+  // TODO: init with 0 in parallel.
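+  // Note: DISCTransferWriteZeroToSCFOp (built next) rewrites the
+  // vector.transfer_write ops that store the zero-splat initialization into
+  // explicit scf loops; performing that initialization in parallel is the
+  // TODO above.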
+  buildTransferWriteZeroToSCFOp(b, loc, func2ConvertTransfer);
+
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  // ==================== ForallOp to GPU mappings ====================
+
+  auto blockTileMappingDictAttr =
+      b.getDictionaryAttr({b.getNamedAttr("mapping", blockTileMappingAttr)});
+  Value forallBlock = buildMatchOp(b, loc, variant, {"scf.forall"}, {},
+                                   blockTileMappingDictAttr);
+  auto parallelOp = buildForallToGPUCTAsOp(b, loc, forallBlock);
+
+  auto warpTileMappingDictAttr =
+      b.getDictionaryAttr({b.getNamedAttr("mapping", warpTileMappingAttr)});
+  Value forallWarp = buildMatchOp(b, loc, variant, {"scf.forall"}, {},
+                                  warpTileMappingDictAttr);
+  buildForallToGPUWarpsOp(b, loc, forallWarp);
+
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  // ======================== Gmem to Smem conversion ========================
+
+  Value genericOp = buildMatchOp(b, loc, variant, {"linalg.generic"});
+  buildLowerGmemToSmemOp(b, loc, genericOp);
+
+  // TODO: shared memory swizzle to avoid bank conflict.
+
+  // ========================= Convert vector to mma =========================
+
+  Value func4MMA = buildMatchOp(b, loc, variant, {"func.func"});
+  auto vectorToMMAConversionOp = buildVectorToMMAConversionOp(b, loc, func4MMA);
+  buildLICMOp(b, loc, variant);
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  // ============================ Post processing ============================
+
+  Value func4PostProcess = buildMatchOp(b, loc, variant, {"func.func"});
+  buildInlineAndConvertGPUIdsOp(b, loc, func4PostProcess);
+
+  buildLICMOp(b, loc, variant);
+  buildDCEOp(b, loc, variant);
+  buildCSEOp(b, loc, variant);
+
+  b.create<transform::YieldOp>(loc);
+
+  return success();
+}
+#if ENABLE_AARCH64_SCHEDUELS
 DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, kDefaultScheduleFactoryPriority,
                         Aarch64GEMMDefaultScheduleFactory,
-                        ArrayRef<StringRef>{kDefaultScheduleFactoryTag});
+                        ArrayRef<StringRef>{kDefaultScheduleFactoryTag},
+                        DeviceType::kCPU);
 DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, 10,
                         Aarch64GEMMDefaultScheduleWithEpilogueFactory,
-                        ArrayRef<StringRef>{"default_epilogue"});
+                        ArrayRef<StringRef>{"default_epilogue"},
+                        DeviceType::kCPU);
 DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, 100,
                         Aarch64GEMMLargeKScheduleFactory,
-                        ArrayRef<StringRef>{"large_k"});
+                        ArrayRef<StringRef>{"large_k"}, DeviceType::kCPU);
 DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, 110,
                         Aarch64GEMMLargeKScheduleWithEpilogueFactory,
-                        ArrayRef<StringRef>{"large_k_epilogue"});
+                        ArrayRef<StringRef>{"large_k_epilogue"},
+                        DeviceType::kCPU);
 #endif  // ENABLE_AARCH64_SCHEDUELS
+// CUDA schedules
+DISC_TRANSFORM_SCHEDULE(PatternKind::kGEMM, 1000,
+                        CUDAMMAGEMMDefaultScheduleFactory,
+                        ArrayRef<StringRef>{"cuda_mma_default"},
+                        DeviceType::kGPU);
+
 }  // namespace
 
 const char* kDefaultScheduleFactoryTag = "default";
@@ -1433,13 +1795,25 @@ PatternKind patternKindFromString(const std::string& str) {
   return PatternKind::kNone;
 }
 
+DeviceType deviceTypeFromString(const std::string& str) {
+  auto& map = getStringToDeviceTypeMap();
+  auto it = map.find(str);
+  if (it != map.end()) {
+    return it->second;
+  }
+  llvm_unreachable("unknown device type str");
+  return DeviceType::kNone;
+}
+
 PatternDescription::PatternDescription(lmhlo::FusionOp op,
                                        FusionPattern& fusionPattern,
                                        ShapeAnalysis& shapeAnalysis)
     : op_(op),
       fusionPattern_(fusionPattern),
       shapeAnalysis_(shapeAnalysis),
-      tagSet_(parsefusionTagSetFromStr(getFusionTagStr(op))) {
+      tagSet_(parsefusionTagSetFromStr(getFusionTagStr(op))),
+      deviceType_(placement_utils::isGpuMhlo(op) ? DeviceType::kGPU
+                                                 : DeviceType::kCPU) {
   // TODO(wyzero): select the pattern kind according to the `fusionPattern`.
   patternKind_ = PatternKind::kGEMM;
 }
@@ -1454,13 +1828,18 @@ const std::set<std::string>& PatternDescription::getPatternTagSet() const {
   return tagSet_;
 }
 
+DeviceType PatternDescription::getPatternDeviceType() const {
+  return deviceType_;
+}
+
 std::string PatternDescription::getTaggedPatternStr() const {
   return patternKindToString(patternKind_) + "@" + getPatternTagStr();
 }
 
 ScheduleFactory::ScheduleFactory(int64_t id, PatternKind kind,
-                                 ArrayRef<StringRef> tags)
-    : id_(id), kind_(kind) {
+                                 ArrayRef<StringRef> tags,
+                                 DeviceType deviceType)
+    : id_(id), kind_(kind), deviceType_(deviceType) {
   tagSet_.insert(Twine(id).str());
   for (auto tag : tags) {
     tagSet_.insert(tag.str());
@@ -1478,7 +1857,11 @@ bool ScheduleFactory::checkKindAndTags(PatternDescription& pattern) {
   return true;
 }
 
-bool ScheduleFactory::checkFusionPatternProperties(PatternDescription&) {
+bool ScheduleFactory::checkFusionPatternProperties(PatternDescription& pd) {
+  // Check the device.
+  if (deviceType_ != pd.getPatternDeviceType()) {
+    return false;
+  }
   return true;
 }
 
@@ -1488,7 +1871,8 @@ LogicalResult ScheduleFactory::buildGuardCondition(OpBuilder& b, Location loc,
   return failure();
 }
 
-LogicalResult ScheduleFactory::assignSchedule(PatternDescription&, ModuleOp) {
+LogicalResult ScheduleFactory::assignSchedule(PatternDescription&, ModuleOp,
+                                              DeviceInfo) {
   return failure();
 }
 
@@ -1532,7 +1916,9 @@ ScheduleFactoryRegistry::getAllCandidateScheduleFactories(
     if (it->second->accept(pd)) {
       factories.push_back(it->second.get());
       // early stop
-      if (it->second->noGuardCondition(pd)) break;
+      if (it->second->noGuardCondition(pd)) {
+        break;
+      }
     }
   }
   return factories;
@@ -1552,8 +1938,8 @@ ScheduleDispatcher::~ScheduleDispatcher() {
 LogicalResult ScheduleDispatcher::parseModuleFromFile(MLIRContext* ctx) {
   if (transformFileName_.empty() || !parsedModuleMap_.empty()) return success();
   std::string expectedFormatStr =
-      "<pattern-kind-0>:<tag-str-0>:<filename-0>;<pattern-kind-1>:<tag-str-1>:<"
-      "filename-1>";
+      "<pattern-kind-0>:<tag-str-0>:<device-str-0>:<filename-0>;"
+      "<pattern-kind-1>:<tag-str-1>:<device-str-1>:<filename-1>";
   SmallVector<StringRef> patternSettings;
   StringRef(transformFileName_)
       .split(patternSettings, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
@@ -1561,7 +1947,7 @@ LogicalResult ScheduleDispatcher::parseModuleFromFile(MLIRContext* ctx) {
   for (auto& patternSetting : patternSettings) {
     SmallVector<StringRef> items;
     patternSetting.split(items, ":", /*MaxSplit=*/-1, /*KeepEmpty=*/true);
-    if (items.size() != 3) {
+    if (items.size() != 4) {
       llvm::dbgs() << "illegal transform file setting, expected format: "
                    << expectedFormatStr << "\n";
       return failure();
     }
 
     auto& transformModule = parsedModuleMap_[kind][items[1].str()];
-    if (failed(parseTransformModuleFromFile(ctx, items[2], transformModule))) {
+    DeviceType deviceType = deviceTypeFromString(items[2].str());
+    if (deviceType == DeviceType::kNone) {
+      llvm::dbgs() << "illegal transform file setting, unknown device type: "
+                   << items[2] << "\n";
+      return failure();
+    }
+
+    if (failed(parseTransformModuleFromFile(ctx, items[3], transformModule))) {
       llvm::dbgs()
           << "illegal transform file setting, unable to load module from: "
-          << items[2] << "\n";
+          << items[3] << "\n";
       return failure();
     }
     SmallVector<StringRef> tags;
@@ -1587,7 +1980,7 @@
         kind, priority++,
         std::make_unique<ParsedFromFileScheduleFactory>(
            ScheduleFactoryRegistry::get().getNextUniqueId(), kind, tags,
-            transformModule.get()));
+            deviceType, transformModule.get()));
   }
   return success();
 }
@@ -1601,7 +1994,7 @@ LogicalResult ScheduleDispatcher::dispatch(PatternDescription& pd, ModuleOp m) {
     return failure();
   }
 
-  return factory->assignSchedule(pd, m);
+  return factory->assignSchedule(pd, m, getDeviceInfo());
 }
 
 }  // namespace disc_ral
diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h
index 4431b450124..e29a0e2d4de 100644
--- a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.h
@@ -33,6 +33,15 @@ namespace disc_ral {
 // schedules within the same category.
 enum class PatternKind : int32_t { kNone, kGEMM };
 
+enum class DeviceType { kCPU, kGPU, kNone };
+
+struct DeviceInfo {
+  int cc_major = -1;
+  int cc_minor = -1;
+  int sm_count = -1;
+  int max_threads_per_sm = -1;
+};
+
 // Converts a pattern kind to its string representation.
 std::string patternKindToString(PatternKind kind);
 
@@ -57,6 +66,8 @@ class PatternDescription {
 
   const std::set<std::string>& getPatternTagSet() const;
 
+  DeviceType getPatternDeviceType() const;
+
   // Returns the fusion op this descriptor holds.
   lmhlo::FusionOp getFusionOp() { return op_; }
 
@@ -72,6 +83,7 @@ class PatternDescription {
   ShapeAnalysis& shapeAnalysis_;
   PatternKind patternKind_;
   std::set<std::string> tagSet_;
+  DeviceType deviceType_;
 };
 
 // The name of the default schedule factory for a pattern kind.
@@ -85,7 +97,7 @@ constexpr const int kParsedFromFileScheduleFactoryStartPriority = 10000;
 class ScheduleFactory {
  public:
   explicit ScheduleFactory(int64_t id, PatternKind kind,
-                           ArrayRef<StringRef> tags);
+                           ArrayRef<StringRef> tags, DeviceType deviceType);
   virtual ~ScheduleFactory() = default;
 
   // Returns true if the factory accepts the pattern at compile time.
@@ -107,7 +119,8 @@ class ScheduleFactory {
   // Assign the transform schedule and attach it into the module op.
   // The pattern should be accepted by this factory and the guard condition
   // should be emitted before successfully.
-  virtual LogicalResult assignSchedule(PatternDescription&, ModuleOp);
+  virtual LogicalResult assignSchedule(PatternDescription&, ModuleOp,
+                                       DeviceInfo);
 
   // Returns the id this factory has.
   int64_t getId() { return id_; }
@@ -118,8 +131,12 @@ class ScheduleFactory {
   // Returns the tag set this factory has.
   const std::set<std::string>& getTagSet() { return tagSet_; }
 
+  // Returns the device type this factory corresponds to.
+  DeviceType getDeviceType() { return deviceType_; }
+
  protected:
-  // these are called by `accept`.
+  // These are called by `accept`. No need to check device type as the kind and
+  // tags already determine a unique target.
   virtual bool checkKindAndTags(PatternDescription&);
   virtual bool checkFusionPatternProperties(PatternDescription&);
 
@@ -127,6 +144,7 @@ class ScheduleFactory {
   int64_t id_;
   PatternKind kind_;
   std::set<std::string> tagSet_;
+  DeviceType deviceType_;
 };
 
 class ScheduleFactoryWithNoGuard : public ScheduleFactory {
@@ -208,12 +226,16 @@ class ScheduleDispatcher {
   // Parses schedule modules from the given files.
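+  // The transform file setting is a ';'-separated list of entries; each entry
+  // is expected to have four ':'-separated fields:
+  //   <pattern-kind>:<tag-str>:<device-str>:<filename>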
   LogicalResult parseModuleFromFile(MLIRContext* ctx);
 
+  void setDeviceInfo(const DeviceInfo& deviceInfo) { deviceInfo_ = deviceInfo; }
+  const DeviceInfo& getDeviceInfo() { return deviceInfo_; }
+
  private:
   std::string transformFileName_;
   // <pattern-kind, <tag-str, module>>
   std::unordered_map<PatternKind,
                      std::unordered_map<std::string, OwningOpRef<ModuleOp>>>
       parsedModuleMap_;
+  DeviceInfo deviceInfo_;
 };
 
 }  // namespace disc_ral
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
index 64869489490..807cd6939ff 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -1289,7 +1289,7 @@ LogicalResult lowerWithScheduleRowReduction(
       root_ops.begin(), root_ops.end(), std::back_inserter(row_reduction_roots),
       [](Operation* operation) { return isRank2RowReduction(operation); });
 
-  const int thread_per_block = getThreadPerBlock(dominant_op);
+  const int thread_per_block = getCTASize(dominant_op);
 
   Location loc = dominant_op->getLoc();
   OpBuilder b(root_ops.back());
@@ -1720,7 +1720,7 @@ LogicalResult lowerWithScheduleRowReduction(
   Value shape_h = b.create<memref::DimOp>(loc, lhs, zero);
   Value shape_w = b.create<memref::DimOp>(loc, lhs, one);
   Value num_threads =
-      b.create<arith::ConstantIndexOp>(loc, getThreadPerBlock(dominant_op));
+      b.create<arith::ConstantIndexOp>(loc, getCTASize(dominant_op));
   std::map init_values_cache;
   SmallVector row_reduction_ops;
   SmallVector> shared_mem_map_vec(vector_size);
@@ -1818,10 +1818,10 @@ LogicalResult lowerWithScheduleRowReduction(
     }
   }
 
-  Value lane_id_inbound = b.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::slt, lane_id,
-      b.create<arith::ConstantIndexOp>(
-          loc, getThreadPerBlock(dominant_op) / kWarpSize));
+  Value lane_id_inbound =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lane_id,
+                              b.create<arith::ConstantIndexOp>(
+                                  loc, getCTASize(dominant_op) / kWarpSize));
   scf::IfOp if_lane_id_inbound =
       b.create<scf::IfOp>(loc, /*resultTypes*/ root_elem_types, lane_id_inbound,
                           /*hasElseRegion*/ true);
@@ -1876,7 +1876,7 @@ LogicalResult lowerWithScheduleRowReduction(
   auto acc_iter = if_lane_id_inbound.getResults().begin();
   if (failed(emitSecondRoundShuffle(b, loc, row_reduction_ops, acc_iter,
                                     thread_id_is_zero, row_ids, vector_size,
-                                    getThreadPerBlock(dominant_op)))) {
+                                    getCTASize(dominant_op)))) {
     return failure();
   }
 
@@ -3113,7 +3113,7 @@ LogicalResult lowerWithScheduleStitch(lmhlo::FusionOp& fusion_op,
     }
   }
 
-  const int thread_per_block = getThreadPerBlock(dominant_op);
+  const int thread_per_block = getCTASize(dominant_op);
   int reduce_threads = (row_reduction_schedule == DISC_BLOCK_WISE_ROW_REDUCE)
                            ? thread_per_block
                            : kWarpSize;
@@ -3795,7 +3795,7 @@ LogicalResult lowerWithScheduleStitchV2(lmhlo::FusionOp& fusion_op,
     }
   }
 
-  const int thread_per_block = getThreadPerBlock(dominant_op);
+  const int thread_per_block = getCTASize(dominant_op);
   int reduce_threads = (row_reduction_schedule == DISC_BLOCK_WISE_ROW_REDUCE)
                            ? thread_per_block
                            : kWarpSize;
diff --git a/tao_compiler/mlir/disc/transforms/parallel_loop_collapsing.cc b/tao_compiler/mlir/disc/transforms/parallel_loop_collapsing.cc
index 7dfe48446d4..de1d4aa068c 100644
--- a/tao_compiler/mlir/disc/transforms/parallel_loop_collapsing.cc
+++ b/tao_compiler/mlir/disc/transforms/parallel_loop_collapsing.cc
@@ -47,7 +47,8 @@ struct ParallelLoopCollapsing
     if (fusion) {
       auto fusionTypeAttr =
           fusion->getAttrOfType<StringAttr>(kDiscFusionTypeAttrName);
-      if (fusionTypeAttr && fusionTypeAttr.getValue() == "kStitch") {
+      if (fusionTypeAttr && (fusionTypeAttr.getValue() == "kStitch" ||
+                             fusionTypeAttr.getValue() == "kTransform")) {
         continue;
       }
     }
diff --git a/tao_compiler/mlir/disc/transforms/parallel_loop_tiling.cc b/tao_compiler/mlir/disc/transforms/parallel_loop_tiling.cc
index be8f6df7b80..551581be5a4 100644
--- a/tao_compiler/mlir/disc/transforms/parallel_loop_tiling.cc
+++ b/tao_compiler/mlir/disc/transforms/parallel_loop_tiling.cc
@@ -69,8 +69,7 @@ struct ParallelLoopTiling
           continue;
         }
       }
-      if (auto attr =
-              fusion->getAttrOfType<IntegerAttr>(kThreadPerBlockHint)) {
+      if (auto attr = fusion->getAttrOfType<IntegerAttr>(kCTASizeHint)) {
         localTileSizes = {attr.getInt()};
       }
     }
diff --git a/tao_compiler/mlir/disc/transforms/passes.h b/tao_compiler/mlir/disc/transforms/passes.h
index 4fe82063f0d..8f2ce49d71e 100644
--- a/tao_compiler/mlir/disc/transforms/passes.h
+++ b/tao_compiler/mlir/disc/transforms/passes.h
@@ -314,6 +314,7 @@ std::unique_ptr> createDiscGPUSourceToLibPass(
 std::unique_ptr>
 createDiscTransformLegalizeToLoopPass(bool gpuEnabled = false,
                                       const std::string& filename = "",
+                                      int cc_major = 8, int cc_minor = 0,
                                       bool expensiveCheck = false);
 
 // Duplicate and fuse some computation into their fusion consumer to reduce
@@ -321,6 +322,10 @@ createDiscTransformLegalizeToLoopPass(bool gpuEnabled = false,
 std::unique_ptr>
 createDiscDuplicateComputationAfterFusionPass();
 
+// Erase dealloc ops for GPU func ops.
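+// This is used on the GPU codegen path after bufferization, where dealloc ops
+// inside a kernel are not lowered further and can simply be dropped.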
+std::unique_ptr>
+createDiscEraseBufferDeallocationPass();
+
 }  // namespace disc_ral
 }  // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/tests/gpu-only-lhlo-legalize-roots-to-loops.mlir b/tao_compiler/mlir/disc/transforms/tests/gpu-only-lhlo-legalize-roots-to-loops.mlir
index 50cb83f275d..590517538e5 100755
--- a/tao_compiler/mlir/disc/transforms/tests/gpu-only-lhlo-legalize-roots-to-loops.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/gpu-only-lhlo-legalize-roots-to-loops.mlir
@@ -161,7 +161,7 @@ func.func @multi_row_reduce(%arg_f32: memref<2048x768xf32>, %init_f32: memref
     }) {dimensions = dense<1> : tensor<1xi64>} : (memref<2048x768xf64>, memref, memref<2048xf64>) -> ()
     "lmhlo.terminator"() : () -> ()
-  }) {disc.fusion.name = "main_kRowReduction_reduce_reduce", disc.device = "gpu", disc.fusion.tag = "1b1r", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_thread_per_block_hint = 256 : i32} : () -> ()
+  }) {disc.fusion.name = "main_kRowReduction_reduce_reduce", disc.device = "gpu", disc.fusion.tag = "1b1r", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_cta_size_hint = 256 : i32} : () -> ()
   return %out_f32, %out_f64: memref<2048xf32>, memref<2048xf64>
 }
 
diff --git a/tao_compiler/mlir/disc/transforms/tests/parallel-loop-tiling-inbound-check.mlir b/tao_compiler/mlir/disc/transforms/tests/parallel-loop-tiling-inbound-check.mlir
index 137a82c91d2..9fec792af88 100644
--- a/tao_compiler/mlir/disc/transforms/tests/parallel-loop-tiling-inbound-check.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/parallel-loop-tiling-inbound-check.mlir
@@ -162,7 +162,7 @@ func.func @parallel_loop_with_hint(%pred : i1,
         memref.store %sum_elem, %result[%i0, %i1] : memref
       }
       "lmhlo.terminator"() : () -> ()
-    } ) { disc_thread_per_block_hint = 256 : i32 } : () -> ()
+    } ) { disc_cta_size_hint = 256 : i32 } : () -> ()
   } else {
     "lmhlo.fusion"() ({
       scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
        memref.store %sum_elem, %result[%i0, %i1] : memref
      }
      "lmhlo.terminator"() : () -> ()
-    } ) { disc_thread_per_block_hint = 64 : i32 } : () -> ()
+    } ) { disc_cta_size_hint = 64 : i32 } : () -> ()
  }
  return
}