diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 82533a2f9f5a..fc65ba033adf 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -387,6 +387,17 @@ def LowerCustomDatatypes():
     return _ffi_api.LowerCustomDatatypes()  # type: ignore
 
 
+def HoistWorkspaceAllocation():
+    """Hoist workspace buffer allocations into the function signature.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.HoistWorkspaceAllocation()  # type: ignore
+
+
 def MakePackedAPI():
     """Transform the PrimFuncs in the module to a packed func API.
 
diff --git a/src/tir/transforms/hoist_buffer_allocation.cc b/src/tir/transforms/hoist_buffer_allocation.cc
new file mode 100644
index 000000000000..997b2c4f8fbb
--- /dev/null
+++ b/src/tir/transforms/hoist_buffer_allocation.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tir/transforms/hoist_buffer_allocation.cc
+ * \brief Pass for hoisting buffer allocation into function signature
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <utility>
+
+#include "../../target/datatype/registry.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Mutator that moves scoped buffer declarations out of the body of the module's main
+ *  function and into its parameter list.
+ *
+ * The rewritten signature is: input_vars, hoisted buffer vars, output_vars.
+ */
+class HoistWorkspaceAlloc : public StmtExprMutator {
+ public:
+  explicit HoistWorkspaceAlloc(IRModule mod) : mod_(std::move(mod)) {}
+
+  /*!
+   * \brief Rewrite the function registered under runtime::symbol::tvm_module_main.
+   * \return The module, updated with the rewritten main function.
+   */
+  IRModule operator()() {
+    auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);
+    auto base_func = mod_->Lookup(main_func_gv);
+    auto main_func = runtime::Downcast<PrimFunc>(base_func);
+
+    // Visiting the body fills buffer_map_to_append, so it must run before the new
+    // parameter list is assembled below.
+    Stmt new_body = VisitStmt(main_func->body);
+
+    auto input_vars_optional = main_func->GetAttr<Array<Var>>("input_vars");
+    ICHECK(input_vars_optional.defined()) << "Input vars are undefined";
+    auto output_vars_optional = main_func->GetAttr<Array<Var>>("output_vars");
+    ICHECK(output_vars_optional.defined()) << "Output vars are undefined";
+
+    // 1. Insert the input vars in the new_buffer_map
+    Map<Var, Buffer> new_buffer_map;
+    for (auto var : input_vars_optional.value()) {
+      auto buffer = main_func->buffer_map.Get(var);
+      if (buffer.defined()) {
+        new_buffer_map.Set(var, buffer.value());
+      }
+    }
+
+    // 2. Construct the new params of the function and insert the new values in the new_buffer_map
+    Array<Var> new_params = Array<Var>(input_vars_optional.value());
+    for (auto it : buffer_map_to_append) {
+      new_params.push_back(it.first);
+      new_buffer_map.Set(it.first, it.second);
+    }
+    new_params.insert(new_params.end(), output_vars_optional.value().begin(),
+                      output_vars_optional.value().end());
+
+    // 3. Finish constructing the new_buffer_map by inserting the output vars.
+    for (auto var : output_vars_optional.value()) {
+      auto buffer = main_func->buffer_map.Get(var);
+      if (buffer.defined()) {
+        new_buffer_map.Set(var, buffer.value());
+      }
+    }
+
+    PrimFunc new_func = PrimFunc(new_params, new_body, main_func->ret_type, new_buffer_map,
+                                 main_func->attrs, main_func->span);
+
+    mod_->Update(main_func_gv, new_func);
+    return mod_;
+  }
+
+ private:
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    // Remove the allocate node if the storage scope is defined
+    String storage_scope = GetPtrStorageScope(op->buffer_var);
+    if (storage_scope.defined() && !storage_scope.empty()) {
+      return VisitStmt(op->body);
+    }
+
+    // Defer to the base mutator; an unqualified VisitStmt_(op) would call this override again.
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    // Remove buffer decl node if it has a valid storage scope and register the
+    // binding in the buffer_map_to_append.
+    String storage_scope = GetPtrStorageScope(op->buffer->data);
+    if (storage_scope.defined() && !storage_scope.empty()) {
+      buffer_map_to_append.Set(op->buffer->data, op->buffer);
+      return VisitStmt(op->body);
+    }
+    // Defer to the base mutator; an unqualified VisitStmt_(op) would call this override again.
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  /*! \brief The module being rewritten. */
+  IRModule mod_;
+  /*! \brief Data var -> buffer for every declaration removed from the body. */
+  Map<Var, Buffer> buffer_map_to_append;
+};
+
+namespace transform {
+
+Pass HoistWorkspaceAllocation() {
+  auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
+    return runtime::Downcast<IRModule>(tvm::tir::HoistWorkspaceAlloc(m)());
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.HoistWorkspaceAllocation", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.HoistWorkspaceAllocation")
+    .set_body_typed(HoistWorkspaceAllocation);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_hoist_alloc.py b/tests/python/unittest/test_tir_transform_hoist_alloc.py
new file mode 100644
index 000000000000..9b8c72238a7e
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_hoist_alloc.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm.script import tir as T
+
+# fmt: off
+@tvm.script.ir_module
+class SimpleGraph:
+    @T.prim_func
+    def __tvm_main__(a: T.handle, output: T.handle):
+        # function attr dict
+        T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target(
+            {"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output]})
+        a_buffer = T.match_buffer(a, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        output_buffer = T.match_buffer(output, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        # body
+        sid_0 = T.decl_buffer([140], dtype="uint8", strides=[1], scope="global.workspace", align=16)
+        tid_0: T.Ptr[T.float32, "global.workspace"] = T.address_of(sid_0[0], dtype="handle")
+
+
+@tvm.script.ir_module
+class PostHoistGraph:
+    @T.prim_func
+    def __tvm_main__(a: T.handle, sid_0_1: T.Ptr[T.uint8], output: T.handle):
+        # function attr dict
+        T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output]})
+        a_buffer = T.match_buffer(a, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        output_buffer = T.match_buffer(output, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        # body
+        sid_0 = T.match_buffer(sid_0_1, [140], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        tid_0: T.Ptr[T.float32, "global.workspace"] = T.address_of(sid_0[0], dtype="handle")
+# fmt: on
+
+
+def test_simple_graph_one_pool():
+    """A single workspace buffer is hoisted between the input and output params."""
+    tir_mod = SimpleGraph
+
+    tir_post_hoist = tvm.tir.transform.HoistWorkspaceAllocation()(tir_mod)
+
+    expected_mod = PostHoistGraph
+
+    tvm.ir.assert_structural_equal(tir_post_hoist, expected_mod)
+
+
+# fmt: off
+@tvm.script.ir_module
+class SimpleGraphMultiplePools:
+    @T.prim_func
+    def __tvm_main__(a: T.handle, b: T.handle, output: T.handle):
+        # function attr dict
+        T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target(
+            {"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a, b], "output_vars": [output]})
+        a_buffer = T.match_buffer(a, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        b_buffer = T.match_buffer(b, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        output_buffer = T.match_buffer(output, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        # body
+        sid_0 = T.decl_buffer([140], dtype="uint8", strides=[1], scope="global.workspace", align=16)
+        sid_1 = T.decl_buffer([256], dtype="uint8", strides=[1], scope="vtcm.workspace", align=16)
+        tid_0: T.Ptr[T.float32, "global.workspace"] = T.address_of(sid_0[0], dtype="handle")
+        tid_1: T.Ptr[T.float32, "vtcm.workspace"] = T.address_of(sid_1[16], dtype="handle")
+
+
+@tvm.script.ir_module
+class PostHoistGraphMultiplePools:
+    @T.prim_func
+    def __tvm_main__(a: T.handle, b: T.handle, sid_0_1: T.Ptr[T.uint8], sid_1_1: T.Ptr[T.uint8], output: T.handle):
+        # function attr dict
+        T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a, b], "output_vars": [output]})
+        a_buffer = T.match_buffer(a, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        b_buffer = T.match_buffer(b, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        output_buffer = T.match_buffer(output, [T.int64(5), T.int64(7)], dtype="float32", align=16)
+        # body
+        sid_0 = T.match_buffer(sid_0_1, [140], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        sid_1 = T.match_buffer(sid_1_1, [256], dtype="uint8", strides=[1], elem_offset=0, align=16)
+        tid_0: T.Ptr[T.float32, "global.workspace"] = T.address_of(sid_0[0], dtype="handle")
+        tid_1: T.Ptr[T.float32, "vtcm.workspace"] = T.address_of(sid_1[16], dtype="handle")
+# fmt: on
+
+
+def test_simple_graph_multiple_pools():
+    """Buffers from several workspace pools are each hoisted into the signature."""
+    tir_mod = SimpleGraphMultiplePools
+
+    tir_post_hoist = tvm.tir.transform.HoistWorkspaceAllocation()(tir_mod)
+
+    expected_mod = PostHoistGraphMultiplePools
+
+    tvm.ir.assert_structural_equal(tir_post_hoist, expected_mod)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__] + sys.argv[1:])