diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 6d04a43ce6..dbbbb0e1b5 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -6,6 +6,8 @@ on: - master pull_request: +permissions: read-all + env: PACKAGE_NAME: dpctl MODULE_NAME: dpctl @@ -20,7 +22,7 @@ jobs: matrix: python: ['3.9', '3.10', '3.11'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -28,7 +30,7 @@ jobs: run: | echo "pkgs_dirs: [~/.conda/pkgs]" >> ~/.condarc - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -58,12 +60,12 @@ jobs: $CHANNELS \ conda-recipe - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} path: /usr/share/miniconda/conda-bld/linux-64/${{ env.PACKAGE_NAME }}-*.tar.bz2 - name: Upload wheels artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Wheels Python ${{ matrix.python }} path: ${{ env.WHEELS_OUTPUT_FOLDER }}${{ env.PACKAGE_NAME }}-*.whl @@ -77,10 +79,10 @@ jobs: env: conda-bld: C:\Miniconda\conda-bld\win-64\ steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: conda-incubator/setup-miniconda@v2 + - uses: conda-incubator/setup-miniconda@v3 with: auto-activate-base: true conda-build-version: "*" @@ -88,7 +90,7 @@ jobs: python-version: ${{ matrix.python }} - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -107,12 +109,12 @@ jobs: OVERRIDE_INTEL_IPO: 1 # IPO requires more resources that GH actions VM provides run: conda build --no-test --python ${{ matrix.python }} -c intel -c conda-forge --override-channels conda-recipe - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} path: ${{ env.conda-bld }}${{ env.PACKAGE_NAME }}-*.tar.bz2 - name: Upload wheels artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Wheels Python ${{ matrix.python }} path: ${{ env.WHEELS_OUTPUT_FOLDER }}${{ env.PACKAGE_NAME }}-*.whl @@ -132,7 +134,7 @@ jobs: steps: - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - name: Add conda to system path @@ -159,7 +161,7 @@ jobs: run: | echo "pkgs_dirs: [~/.conda/pkgs]" >> ~/.condarc - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -217,10 +219,10 @@ jobs: steps: - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - - uses: conda-incubator/setup-miniconda@v2 + - uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true conda-build-version: '*' @@ -260,7 +262,7 @@ jobs: shell: pwsh run: Get-Content -Path .\lockfile - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -324,12 +326,12 @@ jobs: python: ['3.9', '3.10', '3.11'] steps: - 
name: Download conda artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - name: Download wheel artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Wheels Python ${{ matrix.python }} @@ -360,16 +362,16 @@ jobs: python: ['3.9', '3.10', '3.11'] steps: - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - name: Download wheel artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Wheels Python ${{ matrix.python }} - - uses: conda-incubator/setup-miniconda@v2 + - uses: conda-incubator/setup-miniconda@v3 with: auto-activate-base: true activate-environment: "" @@ -409,11 +411,11 @@ jobs: # Needed to be able to run conda index run: conda install conda-build python=${{ matrix.python }} - name: Checkout dpctl repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - name: Add conda to system path @@ -435,7 +437,7 @@ jobs: run: | echo "pkgs_dirs: [~/.conda/pkgs]" >> ~/.condarc - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -539,6 +541,8 @@ jobs: array-api-conformity: needs: build_linux runs-on: ${{ matrix.runner }} + permissions: + pull-requests: write strategy: matrix: @@ -550,12 +554,12 @@ jobs: CHANNELS: -c intel -c conda-forge --override-channels steps: - name: Checkout dpctl repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Cache array API tests id: cache-array-api-tests - uses: actions/cache@v3 + uses: actions/cache@v4 env: ARRAY_CACHE: 3 with: @@ -574,7 +578,7 @@ jobs: git clone --recurse-submodules https://github.com/data-apis/array-api-tests array-api-tests cd array-api-tests - name: Download artifact - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }} - name: Add conda to system path @@ -601,7 +605,7 @@ jobs: run: | echo "pkgs_dirs: [~/.conda/pkgs]" >> ~/.condarc - name: Cache conda packages - uses: actions/cache@v3 + uses: actions/cache@v4 env: CACHE_NUMBER: 3 # Increase to reset cache with: @@ -642,7 +646,7 @@ jobs: python -c "import dpctl; dpctl.lsplatform()" export ARRAY_API_TESTS_MODULE=dpctl.tensor cd /home/runner/work/array-api-tests - pytest --json-report --json-report-file=$FILE --skips-file ${GITHUB_WORKSPACE}/.github/workflows/array-api-skips.txt array_api_tests/ || true + pytest --json-report --json-report-file=$FILE --disable-deadline --skips-file ${GITHUB_WORKSPACE}/.github/workflows/array-api-skips.txt array_api_tests/ || true - name: Set Github environment variables shell: bash -l {0} run: | @@ -668,7 +672,7 @@ jobs: run: echo "::notice ${{ env.MESSAGE }}" - name: Post result to PR if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork }} - uses: mshick/add-pr-comment@v1 + uses: mshick/add-pr-comment@v2 with: message: | ${{ env.MESSAGE }} @@ -684,7 +688,7 @@ jobs: run: shell: bash -el {0} steps: - - uses: 
conda-incubator/setup-miniconda@v2 + - uses: conda-incubator/setup-miniconda@v3 with: run-post: false channel-priority: "disabled" @@ -695,7 +699,7 @@ jobs: run: conda install anaconda-client - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: IntelPython/devops-tools fetch-depth: 0 diff --git a/.github/workflows/cpp_style_checks.yml b/.github/workflows/cpp_style_checks.yml index a450bff627..facf85651b 100644 --- a/.github/workflows/cpp_style_checks.yml +++ b/.github/workflows/cpp_style_checks.yml @@ -9,19 +9,21 @@ on: push: branches: [master] +permissions: read-all + jobs: formatting-check: name: clang-format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Run clang-format style check for C/C++ programs. - uses: jidicula/clang-format-action@v3.5.1 + uses: jidicula/clang-format-action@v4.11.0 with: clang-format-version: '11' check-path: 'libsyclinterface' - name: Run clang-format style check for api headers. - uses: jidicula/clang-format-action@v3.5.1 + uses: jidicula/clang-format-action@v4.11.0 with: clang-format-version: '11' check-path: 'dpctl/apis' diff --git a/.github/workflows/generate-coverage.yaml b/.github/workflows/generate-coverage.yaml index edf03bc8f6..7ec430331f 100644 --- a/.github/workflows/generate-coverage.yaml +++ b/.github/workflows/generate-coverage.yaml @@ -4,10 +4,14 @@ on: push: branches: [master] +permissions: read-all + jobs: generate-coverage: name: Generate coverage and push to Coveralls.io - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest + permissions: + pull-requests: write env: ONEAPI_ROOT: /opt/intel/oneapi @@ -17,7 +21,7 @@ jobs: steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.11.0 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} @@ -46,14 +50,14 @@ jobs: sudo apt-get install ninja-build - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' architecture: x64 - name: Cache Gtest id: cache-gtest - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | /home/runner/work/googletest-1.13.0/install @@ -77,7 +81,7 @@ jobs: make && make install - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/generate-docs.yml b/.github/workflows/generate-docs.yml index 84bbed4622..1faa18713c 100644 --- a/.github/workflows/generate-docs.yml +++ b/.github/workflows/generate-docs.yml @@ -6,13 +6,18 @@ on: pull_request: types: [opened, synchronize, reopened, closed] +permissions: read-all + jobs: build-and-deploy: name: Build and Deploy Documentation - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.11.0 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - name: Add Intel repository @@ -41,7 +46,7 @@ jobs: sudo apt-get install ninja-build - name: Setup Python if: ${{ !github.event.pull_request || github.event.action != 'closed' }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' architecture: x64 @@ -51,7 +56,7 @@ jobs: run: | pip install numpy cython setuptools scikit-build cmake sphinx"<7.2" sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 
persist-credentials: false @@ -76,7 +81,7 @@ jobs: mv ../cmake-install/docs/docs ~/docs git clean -dfx - name: Publish docs - if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork && github.ref == 'refs/heads/master' }} + if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork && github.ref == 'refs/heads/master' && github.event.action != 'closed' }} shell: bash -l {0} run: | git remote add tokened_docs https://IntelPython:${{ secrets.GITHUB_TOKEN }}@github.com/IntelPython/dpctl.git @@ -93,7 +98,7 @@ jobs: git push tokened_docs gh-pages - name: Save built docs as an artifact if: ${{ github.event.pull_request && github.event.pull_request.head.repo.fork && github.event.action != 'closed'}} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ env.PACKAGE_NAME }} rendered documentation path: ~/docs @@ -138,7 +143,7 @@ jobs: if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork && github.event.action != 'closed' }} env: PR_NUM: ${{ github.event.number }} - uses: mshick/add-pr-comment@v1 + uses: mshick/add-pr-comment@v2 with: message: | View rendered docs @ https://intelpython.github.io/dpctl/pulls/${{ env.PR_NUM }}/index.html @@ -148,7 +153,7 @@ jobs: if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork && github.event.action == 'closed' }} env: PR_NUM: ${{ github.event.number }} - uses: mshick/add-pr-comment@v1 + uses: mshick/add-pr-comment@v2 with: message: | Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. :crossed_fingers: diff --git a/.github/workflows/openssf-scorecard.yml b/.github/workflows/openssf-scorecard.yml new file mode 100644 index 0000000000..fbd16a4f28 --- /dev/null +++ b/.github/workflows/openssf-scorecard.yml @@ -0,0 +1,73 @@ +# This workflow uses actions that are not certified by GitHub. They are provided +# by a third-party and are governed by separate terms of service, privacy +# policy, and support documentation. + +name: Scorecard supply-chain security +on: + # For Branch-Protection check. Only the default branch is supported. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection + branch_protection_rule: + # To guarantee Maintained check is occasionally updated. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained + schedule: + - cron: '28 2 * * 1' + - cron: '28 2 * * 4' + push: + branches: [ "master" ] + +# Declare default permissions as read only. +permissions: read-all + +jobs: + analysis: + name: Scorecard analysis + runs-on: ubuntu-latest + permissions: + # Needed to upload the results to code-scanning dashboard. + security-events: write + # Needed to publish results and get a badge (see publish_results below). + id-token: write + # Uncomment the permissions below if installing in a private repository. + # contents: read + # actions: read + + steps: + - name: "Checkout code" + uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0 + with: + persist-credentials: false + + - name: "Run analysis" + uses: ossf/scorecard-action@e38b1902ae4f44df626f11ba0734b14fb91f8f86 # v2.1.2 + with: + results_file: results.sarif + results_format: sarif + # (Optional) "write" PAT token. 
Uncomment the `repo_token` line below if: + # - you want to enable the Branch-Protection check on a *public* repository, or + # - you are installing Scorecard on a *private* repository + # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat. + # repo_token: ${{ secrets.SCORECARD_TOKEN }} + + # Public repositories: + # - Publish results to OpenSSF REST API for easy access by consumers + # - Allows the repository to include the Scorecard badge. + # - See https://github.com/ossf/scorecard-action#publishing-results. + # For private repositories: + # - `publish_results` will always be set to `false`, regardless + # of the value entered here. + publish_results: true + + # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF + # format to the repository Actions tab. + - name: "Upload artifact" + uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v3.1.0 + with: + name: SARIF file + path: results.sarif + retention-days: 14 + + # Upload the results to GitHub's code scanning dashboard. + - name: "Upload to code-scanning" + uses: github/codeql-action/upload-sarif@17573ee1cc1b9d061760f3a006fc4aac4f944fd5 # v2.2.4 + with: + sarif_file: results.sarif diff --git a/.github/workflows/os-llvm-sycl-build.yml b/.github/workflows/os-llvm-sycl-build.yml index 3731a3fb77..78e825c9bf 100644 --- a/.github/workflows/os-llvm-sycl-build.yml +++ b/.github/workflows/os-llvm-sycl-build.yml @@ -4,6 +4,8 @@ on: push: branches: [master] +permissions: read-all + jobs: install-compiler: name: Build with nightly build of DPC++ toolchain @@ -20,13 +22,13 @@ jobs: steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.11.0 + uses: styfle/cancel-workflow-action@0.12.1 with: access_token: ${{ github.token }} - name: Cache sycl bundle id: cache-sycl-bundle - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | /home/runner/work/sycl_bundle @@ -100,7 +102,7 @@ jobs: sudo apt-get install libtinfo5 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' architecture: x64 @@ -111,7 +113,7 @@ jobs: pip install numpy"<1.26.0" cython setuptools pytest scikit-build cmake ninja - name: Checkout repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index f7d799463d..c9925da2df 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,12 +5,14 @@ on: push: branches: [master] +permissions: read-all + jobs: pre-commit: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: '3.10' - name: Version of clang-format diff --git a/.github/workflows/python_style_checks.yml b/.github/workflows/python_style_checks.yml index 3afd5acbd9..9059e90aec 100644 --- a/.github/workflows/python_style_checks.yml +++ b/.github/workflows/python_style_checks.yml @@ -9,16 +9,18 @@ on: push: branches: [master] +permissions: read-all + # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: # The isort job sorts all imports in .py, .pyx, .pxd files isort: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: 
'3.11' - uses: jamescurtin/isort-action@master with: configuration: "--check-only" @@ -30,11 +32,11 @@ jobs: # Steps represent a sequence of tasks that will be executed as part of the job steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 # Set up a Python environment for use in actions - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' # Run black code formatter - uses: psf/black@stable @@ -47,11 +49,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/CHANGELOG.md b/CHANGELOG.md index fc5bb4db36..d4da4189ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,39 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.15.0] +## [0.16.0] - MMM. DD, 2024 + +This release reaches milestone of 100% compliance of `dpctl.tensor` functions with Python Array API 2022.12 standard for the main namespace. + +### Added + +* Added reduction functions `dpctl.tensor.min`, `dpctl.tensor.max`, `dpctl.tensor.argmin`, `dpctl.tensor.argmax`, and `dpctl.tensor.prod` per Python Array API specifications: [#1399](https://github.com/IntelPython/dpctl/pull/1399) +* Added dedicated in-place operations for binary elementwise operations and deployed them in Python operators of `dpctl.tensor.usm_ndarray` type: [#1431](https://github.com/IntelPython/dpctl/pull/1431), [#1447](https://github.com/IntelPython/dpctl/pull/1447) +* Added new elementwise functions `dpctl.tensor.cbrt`, `dpctl.tensor.rsqrt`, `dpctl.tensor.exp2`, `dpctl.tensor.copysign`, `dpctl.tensor.angle`, and `dpctl.tensor.reciprocal`: [#1443](https://github.com/IntelPython/dpctl/pull/1443), [#1474](https://github.com/IntelPython/dpctl/pull/1474) +* Added statistical functions `dpctl.tensor.mean`, `dpctl.tensor.std`, `dpctl.tensor.var` per Python Array API specifications: [#1465](https://github.com/IntelPython/dpctl/pull/1465) +* Added sorting functions `dpctl.tensor.sort` and `dpctl.tensor.argsort`, and set functions `dpctl.tensor.unique_values`, `dpctl.tensor.unique_counts`, `dpctl.tensor.unique_inverse`, `dpctl.tensor.unique_all`: [#1483](https://github.com/IntelPython/dpctl/pull/1483) +* Added linear algebra functions from the Array API namespace `dpctl.tensor.matrix_transpose`, `dpctl.tensor.matmul`, `dpctl.tensor.vecdot`, and `dpctl.tensor.tensordot`: [#1490](https://github.com/IntelPython/dpctl/pull/1490) +* Added `dpctl.tensor.clip` function: [#1444](https://github.com/IntelPython/dpctl/pull/1444), [#1505](https://github.com/IntelPython/dpctl/pull/1505) +* Added custom reduction functions `dpt.logsumexp` (reduction using binary function `dpctl.tensor.logaddexp`), `dpt.reduce_hypot` (reduction using binary function `dpctl.tensor.hypot`): [#1446](https://github.com/IntelPython/dpctl/pull/1446) +* Added inspection API to query capabilities of Python Array API specification implementation: [#1469](https://github.com/IntelPython/dpctl/pull/1469) +* Support for compilation for NVIDIA(R) sycl target with use of [CodePlay oneAPI 
plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/): [#1411](https://github.com/IntelPython/dpctl/pull/1411), [#1124](https://github.com/IntelPython/dpctl/discussions/1124) +* Added `dpctl.utils.intel_device_info` function to query additional information about Intel(R) GPU devices: [gh-1428](https://github.com/IntelPython/dpctl/pull/1428) and [gh-1445](https://github.com/IntelPython/dpctl/pull/1445) + +### Changed + +* Functions `dpctl.tensor.result_type` and `dpctl.tensor.can_cast` became device-aware: [#1488](https://github.com/IntelPython/dpctl/pull/1488), [#1473](https://github.com/IntelPython/dpctl/pull/1473) +* Implementation of method `dpctl.SyclEvent.wait_for` changed to use ``sycl::event::wait`` instead of ``sycl::event::wait_and_throw``: [gh-1436](https://github.com/IntelPython/dpctl/pull/1436) +* `dpctl.tensor.astype` was changed to support `device` keyword as per Python Array API specification: [#1511](https://github.com/IntelPython/dpctl/pull/1511) +* C++ header files in `libtensor/include/kernels` containing implementations of SYCL kernels no longer depends on "pybind11.h": [#1516](https://github.com/IntelPython/dpctl/pull/1516) + +### Fixed + +* Fixed issues with `dpctl.tensor.repeat` support for `axis` keyword: [#1427](https://github.com/IntelPython/dpctl/pull/1427), [#1433](https://github.com/IntelPython/dpctl/pull/1433) +* Fix for gh-1503 for bug `usm_ndarray.__setitem__`: [#1504](https://github.com/IntelPython/dpctl/pull/1504) +* Other bug fixes: [#1485](https://github.com/IntelPython/dpctl/pull/1485), [#1477](https://github.com/IntelPython/dpctl/pull/1477), [#1512](https://github.com/IntelPython/dpctl/pull/1512) + + +## [0.15.0] - Sep. 29, 2023 ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index 7688ff040c..eb1346a423 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,10 @@ file(GLOB _cmake_scripts ${CMAKE_SOURCE_DIR}/cmake/*.cmake) install(FILES ${_cmake_scripts} DESTINATION dpctl/resources/cmake ) +install(FILES + ${CMAKE_SOURCE_DIR}/cmake/dpctl-config.cmake + DESTINATION lib/cmake/dpctl +) if (DPCTL_GENERATE_DOCS) add_subdirectory(docs) diff --git a/README.md b/README.md index d26b4c97af..19d2eca840 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Coverage Status](https://coveralls.io/repos/github/IntelPython/dpctl/badge.svg?branch=master)](https://coveralls.io/github/IntelPython/dpctl?branch=master) ![Generate Documentation](https://github.com/IntelPython/dpctl/actions/workflows/generate-docs.yml/badge.svg?branch=master) [![Join the chat at https://matrix.to/#/#Data-Parallel-Python_community:gitter.im](https://badges.gitter.im/Join%20Chat.svg)](https://app.gitter.im/#/room/#Data-Parallel-Python_community:gitter.im) +[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/IntelPython/dpctl/badge)](https://securityscorecards.dev/viewer/?uri=github.com/IntelPython/dpctl) oneAPI logo diff --git a/cmake/FindDpctl.cmake b/cmake/dpctl-config.cmake similarity index 83% rename from cmake/FindDpctl.cmake rename to cmake/dpctl-config.cmake index 149c75bd51..fa3f136b47 100644 --- a/cmake/FindDpctl.cmake +++ b/cmake/dpctl-config.cmake @@ -6,14 +6,17 @@ # # ``Dpctl_FOUND`` # True if DPCTL was found. -# ``Dpctl_INCLUDE_DIRS`` -# The include directories needed to use Dpctl. +# ``Dpctl_INCLUDE_DIR`` +# The include directory needed to use dpctl. +# ``Dpctl_TENSOR_INCLUDE_DIR`` +# The include directory for tensor kernels implementation. # ``Dpctl_VERSION`` -# The version of DPCTL found. +# The version of dpctl found. 
# -# The module will also explicitly define one cache variable: +# The module will also explicitly define two cache variables: # # ``Dpctl_INCLUDE_DIR`` +# ``Dpctl_TENSOR_INCLUDE_DIR`` # if(NOT Dpctl_FOUND) @@ -22,7 +25,7 @@ if(NOT Dpctl_FOUND) if(Python_EXECUTABLE) execute_process(COMMAND "${Python_EXECUTABLE}" - -c "import dpctl; print(dpctl.get_include())" + -m dpctl --include-dir OUTPUT_VARIABLE _dpctl_include_dir OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index b5c6f9a7e1..c99bdb9545 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -31,7 +31,7 @@ requirements: run: - python - dpcpp-cpp-rt ={{ required_compiler_version }} - - {{ pin_compatible('numpy', min_pin='x.x', upper_bound='1.26') }} + - {{ pin_compatible('numpy', min_pin='x.x', max_pin='x') }} - level-zero # [linux] test: diff --git a/dpctl/__main__.py b/dpctl/__main__.py index 78c5f7fde0..9b51d74903 100644 --- a/dpctl/__main__.py +++ b/dpctl/__main__.py @@ -15,42 +15,57 @@ # limitations under the License. import argparse +import importlib import os import os.path import platform import sys import warnings -import dpctl - def _dpctl_dir() -> str: - abs_path = os.path.abspath(dpctl.__file__) - dpctl_dir = os.path.dirname(abs_path) - return dpctl_dir + dpctl_dir = importlib.util.find_spec("dpctl").submodule_search_locations[0] + abs_dpctl_dir = os.path.abspath(dpctl_dir) + return abs_dpctl_dir -def print_includes() -> None: +def get_include_dir() -> str: "Prints include flags for dpctl and SyclInterface library" - print("-I " + dpctl.get_include()) + return os.path.join(_dpctl_dir(), "include") -def print_tensor_includes() -> None: +def print_include_flags() -> None: "Prints include flags for dpctl and SyclInterface library" + print("-I " + get_include_dir()) + + +def get_tensor_include_dir() -> str: dpctl_dir = _dpctl_dir() libtensor_dir = os.path.join(dpctl_dir, "tensor", "libtensor", "include") + return libtensor_dir + + +def print_tensor_include_flags() -> None: + "Prints include flags for dpctl and SyclInterface library" + libtensor_dir = get_tensor_include_dir() print("-I " + libtensor_dir) def print_cmake_dir() -> None: "Prints directory with FindDpctl.cmake" dpctl_dir = _dpctl_dir() - print(os.path.join(dpctl_dir, "resources", "cmake")) + cmake_dir = os.path.join(dpctl_dir, "resources", "cmake") + print(cmake_dir) + + +def get_library_dir() -> str: + dpctl_dir = _dpctl_dir() + return dpctl_dir def print_library() -> None: "Prints linker flags for SyclInterface library" - dpctl_dir = _dpctl_dir() + dpctl_dir = get_library_dir() plt = platform.platform() ld_flags = "-L " + dpctl_dir if plt != "Windows": @@ -73,6 +88,8 @@ def _warn_if_any_set(args, li) -> None: def print_lsplatform(verbosity: int) -> None: + import dpctl + dpctl.lsplatform(verbosity=verbosity) @@ -84,11 +101,21 @@ def main() -> None: action="store_true", help="Include flags for dpctl headers.", ) + parser.add_argument( + "--include-dir", + action="store_true", + help="Path to dpctl include directory.", + ) parser.add_argument( "--tensor-includes", action="store_true", help="Include flags for dpctl libtensor headers.", ) + parser.add_argument( + "--tensor-include-dir", + action="store_true", + help="Path to dpctl libtensor include directory.", + ) parser.add_argument( "--cmakedir", action="store_true", @@ -99,6 +126,11 @@ def main() -> None: action="store_true", help="Linker flags for SyclInterface library.", ) + parser.add_argument( + "--library-dir", + action="store_true", + 
help="Path to directory containing DPCTLSyclInterface library", + ) parser.add_argument( "-f", "--full-list", @@ -139,13 +171,19 @@ def main() -> None: print_lsplatform(0) return if args.includes: - print_includes() + print_include_flags() + if args.include_dir: + print(get_include_dir()) if args.tensor_includes: - print_tensor_includes() + print_tensor_include_flags() + if args.tensor_include_dir: + print(get_tensor_include_dir()) if args.cmakedir: print_cmake_dir() if args.library: print_library() + if args.library_dir: + print(get_library_dir()) if __name__ == "__main__": diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index d2947aa772..d23142473e 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -156,6 +156,15 @@ set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} ) +set(_linalg_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp +) +set(_tensor_linalg_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp + ${_linalg_sources} +) set(_py_trgts) @@ -179,6 +188,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources} add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_linalg_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if (WIN32) set(_clang_prefix "/clang:") @@ -193,6 +207,7 @@ list(APPEND _no_fast_math_sources ${_elementwise_sources} ${_reduction_sources} ${_sorting_sources} + ${_linalg_sources} ) foreach(_src_fn ${_no_fast_math_sources}) @@ -208,7 +223,11 @@ set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") foreach(_src_fn ${_elementwise_sources}) get_source_file_property(_cmpl_options_defs ${_src_fn} COMPILE_DEFINITIONS) - set(_combined_options_defs ${_cmpl_options_defs} "${_compiler_definitions}") + if(${_cmpl_options_defs}) + set(_combined_options_defs ${_cmpl_options_defs} "${_compiler_definitions}") + else() + set(_combined_options_defs "${_compiler_definitions}") + endif() set_source_files_properties( ${_src_fn} PROPERTIES COMPILE_DEFINITIONS "${_combined_options_defs}" @@ -219,10 +238,6 @@ set(_linker_options "LINKER:${DPCTL_LDFLAGS}") foreach(python_module_name ${_py_trgts}) target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) - if(UNIX) - # this option is supported on Linux only - target_link_options(${python_module_name} PRIVATE -fsycl-link-huge-device-code) - endif() target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../include diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 81fc152e7a..77c4e23d8c 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -51,7 +51,6 @@ int16, int32, int64, - isdtype, uint8, uint16, uint32, @@ -60,7 +59,12 @@ from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack from dpctl.tensor._indexing_functions import extract, nonzero, place, put, 
take -from dpctl.tensor._linear_algebra_functions import matrix_transpose +from dpctl.tensor._linear_algebra_functions import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -183,7 +187,7 @@ ) from ._sorting import argsort, sort from ._testing import allclose -from ._type_utils import can_cast, finfo, iinfo, result_type +from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type __all__ = [ "Device", @@ -356,4 +360,7 @@ "unique_counts", "unique_inverse", "unique_values", + "matmul", + "tensordot", + "vecdot", ] diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index f2bc326e82..d95c0fa764 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -168,9 +168,9 @@ def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev): return dpt.dtype(ti.default_device_int_type(dev)) if isinstance(dtype, WeakComplexType): if st_dtype is dpt.float16 or st_dtype is dpt.float32: - return st_dtype, dpt.complex64 + return dpt.complex64 return _to_device_supported_dtype(dpt.complex128, dev) - return (_to_device_supported_dtype(dpt.float64, dev),) + return _to_device_supported_dtype(dpt.float64, dev) else: return st_dtype else: @@ -197,8 +197,6 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev): def _clip_none(x, val, out, order, _binary_fn): - if order not in ["K", "C", "F", "A"]: - order = "K" q1, x_usm_type = x.sycl_queue, x.usm_type q2, val_usm_type = _get_queue_usm_type(val) if q2 is None: @@ -391,9 +389,8 @@ def _clip_none(x, val, out, order, _binary_fn): return out -# need to handle logic for min or max being None -def clip(x, min=None, max=None, out=None, order="K"): - """clip(x, min, max, out=None, order="K") +def clip(x, /, min=None, max=None, out=None, order="K"): + """clip(x, min=None, max=None, out=None, order="K") Clips to the range [`min_i`, `max_i`] for each element `x_i` in `x`. @@ -402,14 +399,14 @@ def clip(x, min=None, max=None, out=None, order="K"): x (usm_ndarray): Array containing elements to clip. Must be compatible with `min` and `max` according to broadcasting rules. - min ({None, usm_ndarray}, optional): Array containing minimum values. + min ({None, Union[usm_ndarray, bool, int, float, complex]}, optional): + Array containing minimum values. Must be compatible with `x` and `max` according to broadcasting rules. - Only one of `min` and `max` can be `None`. - max ({None, usm_ndarray}, optional): Array containing maximum values. + max ({None, Union[usm_ndarray, bool, int, float, complex]}, optional): + Array containing maximum values. Must be compatible with `x` and `min` according to broadcasting rules. - Only one of `min` and `max` can be `None`. out ({None, usm_ndarray}, optional): Output array to populate. Array must have the correct shape and the expected data type. @@ -428,10 +425,67 @@ def clip(x, min=None, max=None, out=None, order="K"): "Expected `x` to be of dpctl.tensor.usm_ndarray type, got " f"{type(x)}" ) + if order not in ["K", "C", "F", "A"]: + order = "K" if min is None and max is None: - raise ValueError( - "only one of `min` and `max` is permitted to be `None`" + exec_q = x.sycl_queue + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + "output array must be of usm_ndarray type, got " + f"{type(out)}" + ) + + if out.shape != x.shape: + raise ValueError( + "The shape of input and output arrays are " + f"inconsistent. 
Expected output shape is {x.shape}, " + f"got {out.shape}" + ) + + if x.dtype != out.dtype: + raise ValueError( + f"Output array of type {x.dtype} is needed, " + f"got {out.dtype}" + ) + + if ( + dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) + is None + ): + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x, out): + if not ti._same_logical_tensors(x, out): + out = dpt.empty_like(out) + else: + return out + else: + if order == "K": + out = _empty_like_orderK(x, x.dtype) + else: + if order == "A": + order = "F" if x.flags.f_contiguous else "C" + out = dpt.empty_like(x, order=order) + + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x, dst=out, sycl_queue=exec_q ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + return out elif max is None: return _clip_none(x, min, out, order, tei._maximum) elif min is None: diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 81928692a6..ecf3eade35 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -26,6 +26,7 @@ import dpctl.utils from dpctl.tensor._data_types import _get_dtype from dpctl.tensor._device import normalize_queue_device +from dpctl.tensor._type_utils import _dtype_supported_by_device_impl __doc__ = ( "Implementation module for copy- and cast- operations on " @@ -121,7 +122,7 @@ def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None): output array is created. Device can be specified by a a filter selector string, an instance of :class:`dpctl.SyclDevice`, an instance of - :class:`dpctl.SyclQueue`, an instance of + :class:`dpctl.SyclQueue`, or an instance of :class:`dpctl.tensor.Device`. If the value is `None`, returned array is created on the default-selected device. Default: `None`. @@ -300,14 +301,22 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src): src.shape, src.strides, len(common_shape) ) src_same_shape = dpt.usm_ndarray( - common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides + common_shape, + dtype=src.dtype, + buffer=src, + strides=new_src_strides, + offset=src._element_offset, ) elif src.ndim == len(common_shape): new_src_strides = _broadcast_strides( src.shape, src.strides, len(common_shape) ) src_same_shape = dpt.usm_ndarray( - common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides + common_shape, + dtype=src.dtype, + buffer=src, + strides=new_src_strides, + offset=src._element_offset, ) else: # since broadcasting succeeded, src.ndim is greater because of @@ -523,7 +532,7 @@ def copy(usm_ary, order="K"): ) order = order[0].upper() if not isinstance(usm_ary, dpt.usm_ndarray): - return TypeError( + raise TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" ) copy_order = "C" @@ -556,9 +565,11 @@ def copy(usm_ary, order="K"): return R -def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): +def astype( + usm_ary, newdtype, /, order="K", casting="unsafe", *, copy=True, device=None +): """ astype(array, new_dtype, order="K", casting="unsafe", \ - copy=True) + copy=True, device=None) Returns a copy of the :class:`dpctl.tensor.usm_ndarray`, cast to a specified type. 
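A short usage sketch of the `device` keyword added to `dpctl.tensor.astype` in the hunk above; the filter-selector string "cpu", the array contents, and the availability of a working SYCL runtime with such a device are illustrative assumptions, not part of the patch:

    import dpctl.tensor as dpt

    x = dpt.arange(8, dtype="int32")            # allocated on the default-selected device
    # cast and, if needed, migrate in a single call; `device` accepts a filter
    # selector string, a dpctl.SyclDevice, a dpctl.SyclQueue, or a dpctl.tensor.Device
    y = dpt.astype(x, "float32", device="cpu")
    print(y.dtype, y.sycl_device)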
@@ -568,7 +579,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): An input array. new_dtype (dtype): The data type of the resulting array. If `None`, gives default - floating point type supported by device where `array` is allocated. + floating point type supported by device where the resulting array + will be located. order ({"C", "F", "A", "K"}, optional): Controls memory layout of the resulting array if a copy is returned. @@ -579,6 +591,14 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): By default, `astype` always returns a newly allocated array. If this keyword is set to `False`, a view of the input array may be returned when possible. + device (object): array API specification of device where the + output array is created. Device can be specified by a + a filter selector string, an instance of + :class:`dpctl.SyclDevice`, an instance of + :class:`dpctl.SyclQueue`, or an instance of + :class:`dpctl.tensor.Device`. If the value is `None`, + returned array is created on the same device as `array`. + Default: `None`. Returns: usm_ndarray: @@ -596,7 +616,25 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): ) order = order[0].upper() ary_dtype = usm_ary.dtype - target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) + if device is not None: + if not isinstance(device, dpctl.SyclQueue): + if isinstance(device, dpt.Device): + device = device.sycl_queue + else: + device = dpt.Device.create_device(device).sycl_queue + d = device.sycl_device + target_dtype = _get_dtype(newdtype, device) + if not _dtype_supported_by_device_impl( + target_dtype, d.has_aspect_fp16, d.has_aspect_fp64 + ): + raise ValueError( + f"Requested dtype `{target_dtype}` is not supported by the " + "target device" + ) + usm_ary = usm_ary.to_device(device) + else: + target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) + if not dpt.can_cast(ary_dtype, target_dtype, casting=casting): raise TypeError( f"Can not cast from {ary_dtype} to {newdtype} " diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index ba16f0f1fc..5c5c7279db 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -632,17 +632,13 @@ def asarray( usm_type=usm_type, order=order, ) - - raise NotImplementedError( - "Converting Python sequences is not implemented" - ) if copy is False: raise ValueError( f"Converting {type(obj)} to usm_ndarray requires a copy" ) # obj is a scalar, create 0d array return _asarray_from_numpy_ndarray( - np.asarray(obj), + np.asarray(obj, dtype=dtype), dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue, diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index bee557cf18..78e8714607 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -50,48 +50,6 @@ complex128 = dtype("complex128") -def isdtype(dtype_, kind): - """isdtype(dtype, kind) - - Returns a boolean indicating whether a provided `dtype` is - of a specified data type `kind`. - - See [array API](array_api) for more information. 
- - [array_api]: https://data-apis.org/array-api/latest/ - """ - - if not isinstance(dtype_, dtype): - raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype_}") - - if isinstance(kind, dtype): - return dtype_ == kind - - elif isinstance(kind, str): - if kind == "bool": - return dtype_ == dtype("bool") - elif kind == "signed integer": - return dtype_.kind == "i" - elif kind == "unsigned integer": - return dtype_.kind == "u" - elif kind == "integral": - return dtype_.kind in "iu" - elif kind == "real floating": - return dtype_.kind == "f" - elif kind == "complex floating": - return dtype_.kind == "c" - elif kind == "numeric": - return dtype_.kind in "iufc" - else: - raise ValueError(f"Unrecognized data type kind: {kind}") - - elif isinstance(kind, tuple): - return any(isdtype(dtype_, k) for k in kind) - - else: - raise TypeError(f"Unsupported data type kind: {kind}") - - def _get_dtype(inp_dt, sycl_obj, ref_type=None): """ Type inference utility to construct data type @@ -121,7 +79,6 @@ def _get_dtype(inp_dt, sycl_obj, ref_type=None): __all__ = [ "dtype", "_get_dtype", - "isdtype", "bool", "int8", "uint8", diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index fd2c58b08a..0894ac2077 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -14,7 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator + +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple + +import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_elementwise_impl as tei +import dpctl.tensor._tensor_impl as ti +import dpctl.tensor._tensor_linalg_impl as tli +from dpctl.tensor._copy_utils import _empty_like_orderK, _empty_like_pair_orderK +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.tensor._type_utils import ( + _acceptance_fn_default_binary, + _find_buf_dtype2, + _to_device_supported_dtype, +) +from dpctl.utils import ExecutionPlacementError def matrix_transpose(x): @@ -46,3 +62,921 @@ def matrix_transpose(x): ) return x.mT + + +def tensordot(x1, x2, axes=2): + """tensordot(x1, x2, axes=2) + + Returns a tensor contraction of `x1` and `x2` over specific axes. + + Args: + x1 (usm_ndarray): + first input array, expected to have numeric data type. + x2 (usm_ndarray): + second input array, expected to have numeric data type. + Corresponding contracted axes of `x1` and `x2` must be equal. + axes (Union[int, Tuple[Sequence[int], Sequence[int]]): + number of axes to contract or explicit sequences of axes for + `x1` and `x2`, respectively. If `axes` is an integer equal to `N`, + then the contraction is performed over last `N` axes of `x1` and + the first `N` axis of `x2` in order. The size of each corresponding + axis must match and must be non-negative. + * if `N` equals `0`, the result is the tensor outer product + * if `N` equals `1`, the result is the tensor dot product + * if `N` equals `2`, the result is the tensor double + contraction (default). + If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the + first sequence applies to `x1` and the second sequence applies + to `x2`. Both sequences must have equal length, and each axis + `x1_axes[i]` for `x1` must have the same size as the respective + axis `x2_axes[i]` for `x2`. Each sequence must consist of unique + non-negative integers that specify valid axes for each respective + array. 
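A small sketch of the `axes` convention described above for the `tensordot` function introduced by this patch; the shapes and values are chosen only for illustration and a working SYCL device is assumed:

    import dpctl.tensor as dpt

    a = dpt.reshape(dpt.arange(24, dtype="float32"), (2, 3, 4))
    b = dpt.reshape(dpt.arange(12, dtype="float32"), (4, 3))
    # integer form: contract the last axis of `a` with the first axis of `b`
    r1 = dpt.tensordot(a, b, axes=1)             # result shape (2, 3, 3)
    # explicit form: the same contraction spelled as two axis sequences
    r2 = dpt.tensordot(a, b, axes=((2,), (0,)))
    print(r1.shape, r2.shape)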
+ Returns: + usm_ndarray: + an array containing the tensor contraction whose shape consists of + the non-contracted axes of the first array `x1`, followed by the + non-contracted axes of the second array `x2`. The returned array + must have a data type determined by Type Promotion Rules. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # handle axes and shapes validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if isinstance(axes, int): + if axes < 0: + raise ValueError("`axes` integer is expected to be non-negative") + n_axes1 = axes + n_axes2 = axes + axes1 = normalize_axis_tuple(tuple(range(-axes, 0)), x1_nd) + axes2 = tuple(range(0, axes)) + elif isinstance(axes, tuple): + if len(axes) != 2: + raise ValueError( + "`axes` tuple is expected to contain two sequences" + ) + axes1 = tuple(axes[0]) + axes2 = tuple(axes[1]) + n_axes1 = len(axes1) + n_axes2 = len(axes2) + else: + raise TypeError("`axes` must be an integer or a tuple of sequences") + if n_axes1 != n_axes2: + raise ValueError( + "number of axes contracted must be the same for each array" + ) + if n_axes1 == 0: + arr1 = x1[..., dpt.newaxis] + arr2 = x2[dpt.newaxis, ...] + n_axes1 = 1 + n_axes2 = 1 + else: + same_shapes = True + for i in range(n_axes1): + axis1 = axes1[i] + if axis1 < 0: + raise ValueError("`axes` must be non-negative") + axis2 = axes2[i] + if axis2 < 0: + raise ValueError("`axes` must be non-negative") + same_shapes = same_shapes and (x1_shape[axis1] == x2_shape[axis2]) + if not same_shapes: + raise ValueError("shape mismatch in contracted `tensordot` axes") + axes1 = normalize_axis_tuple(axes1, x1_nd) + axes2 = normalize_axis_tuple(axes2, x2_nd) + perm1 = [i for i in range(x1_nd) if i not in axes1] + list(axes1) + perm2 = list(axes2) + [i for i in range(x2_nd) if i not in axes2] + arr1 = dpt.permute_dims(x1, perm1) + arr2 = dpt.permute_dims(x2, perm2) + arr1_outer_nd = arr1.ndim - n_axes1 + arr2_outer_nd = arr2.ndim - n_axes2 + res_shape = arr1.shape[:arr1_outer_nd] + arr2.shape[n_axes2:] + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'tensordot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." 
+ ) + + if buf1_dt is None and buf2_dt is None: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + ) + ht_dot_ev.wait() + + return out + + elif buf1_dt is None: + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + elif buf2_dt is None: + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + + return out + + +def vecdot(x1, x2, axis=-1): + """vecdot(x1, x2, axis=-1) + + Computes the (vector) dot product of two arrays. + + Args: + x1 (usm_ndarray): + first input array. + x2 (usm_ndarray): + second input array. Input arrays must have compatible + shapes along non-contract axes according to broadcasting + rules, and must have the same size along the contracted + axis. Input arrays should be of numeric type. + axis (Optional[int]): + axis over which to compute the dot product. The axis must + be an integer on the interval `[-N, N)`, where `N` is the + array rank of input arrays after broadcasting rules are + applied. If specified as a negative integer, the axis along + which dot product is performed is counted backward from + the last axes (that is `-1` refers to the last axis). By + default, dot product is computed over the last axis. + Default: `-1`. + + Returns: + usm_ndarray: + if `x1` and `x2` are both one-dimensional arrays, a + zero-dimensional array containing the dot product value + is returned; otherwise, a non-zero-dimensional array containing + the dot products and having rank `N-1`, where `N` is the rank + of the shape of input arrays after broadcasting rules are applied + to non-contracted axes. 
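A brief sketch of `vecdot` as specified above: the inputs broadcast along non-contracted axes, the dot product is taken along `axis` (the last axis by default), and that axis is removed from the result shape. The values below and the presence of a working SYCL device are assumptions for illustration:

    import dpctl.tensor as dpt

    u = dpt.ones((5, 3), dtype="float32")
    v = dpt.asarray([1.0, 2.0, 3.0], dtype="float32")   # broadcast against `u`
    w = dpt.vecdot(u, v)            # contracts the last axis, result shape (5,)
    print(w.shape, float(w[0]))     # each entry equals 1.0 + 2.0 + 3.0 = 6.0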
+ """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # axis and shape validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if x1_nd > x2_nd: + x2_shape = (1,) * (x1_nd - x2_nd) + x2_shape + x2_nd = len(x2_shape) + elif x2_nd > x1_nd: + x1_shape = (1,) * (x2_nd - x1_nd) + x1_shape + x1_nd = len(x1_shape) + axis = normalize_axis_index(operator.index(axis), x1_nd) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError( + "given axis must have the same shape for `x1` and `x2`" + ) + try: + broadcast_sh = _broadcast_shape_impl( + [ + x1_shape, + x2_shape, + ] + ) + except ValueError: + raise ValueError("mismatch in `vecdot` dimensions") + res_sh = tuple( + [broadcast_sh[i] for i in range(len(broadcast_sh)) if i != axis] + ) + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'vecdot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." 
+ ) + + ht_list = [] + deps = [] + if buf1_dt is None and buf2_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_ev = tei._conj( + src=x1, + dst=x1_tmp, + sycl_queue=exec_q, + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + x1 = x1_tmp + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf1_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_e = tei._conj( + src=x1, dst=x1_tmp, sycl_queue=exec_q + ) + ht_list.append(ht_conj_ev) + deps.append(conj_e) + x1 = x1_tmp + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if buf2.shape != broadcast_sh: + buf2 = dpt.broadcast_to(buf2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + buf2 = dpt.moveaxis(buf2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf2_dt is None: + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + if buf1.shape != broadcast_sh: + buf1 = dpt.broadcast_to(buf1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + buf1 = dpt.moveaxis(buf1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy1_ev) + deps.append(copy1_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy1_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + ht_list.append(ht_copy2_ev) + 
deps.append(copy2_ev)
+    if buf1.shape != broadcast_sh:
+        buf1 = dpt.broadcast_to(buf1, broadcast_sh)
+    if buf2.shape != broadcast_sh:
+        buf2 = dpt.broadcast_to(buf2, broadcast_sh)
+    buf1 = dpt.moveaxis(buf1, axis, -1)
+    buf2 = dpt.moveaxis(buf2, axis, -1)
+    out = dpt.empty(
+        res_sh,
+        dtype=res_dt,
+        usm_type=res_usm_type,
+        sycl_queue=exec_q,
+        order="C",
+    )
+    ht_dot_ev, _ = tli._dot(
+        x1=buf1,
+        x2=buf2,
+        batch_dims=len(x1.shape[:-1]),
+        x1_outer_dims=0,
+        x2_outer_dims=0,
+        inner_dims=1,
+        dst=out,
+        sycl_queue=exec_q,
+        depends=deps,
+    )
+    ht_list.append(ht_dot_ev)
+    dpctl.SyclEvent.wait_for(ht_list)
+
+    return out
+
+
+def matmul(x1, x2, out=None, dtype=None, order="K"):
+    """matmul(x1, x2, out=None, dtype=None, order="K")
+
+    Computes the matrix product. Implements the same semantics
+    as the built-in operator `@`.
+
+    Args:
+        x1 (usm_ndarray):
+            first input array. Expected to have numeric data type, and
+            at least one dimension. If `x1` is one-dimensional having
+            shape `(M,)`, and `x2` has more than one dimension, `x1` is
+            effectively treated as a two-dimensional array with shape
+            `(1, M)`, although the prepended dimension is removed from
+            the output array. If `x1` has shape `(..., M, K)`, the
+            innermost two dimensions form matrices on which to perform
+            matrix multiplication.
+        x2 (usm_ndarray):
+            second input array. Expected to have numeric data type, and
+            at least one dimension. If `x2` is one-dimensional having
+            shape `(N,)`, and `x1` has more than one dimension, `x2` is
+            effectively treated as a two-dimensional array with shape
+            `(N, 1)`, although the appended dimension is removed from
+            the output array. If `x2` has shape `(..., K, N)`, the
+            innermost two dimensions form matrices on which to perform
+            matrix multiplication.
+        out (Optional[usm_ndarray]):
+            the array into which the result of the matrix product is
+            written. If `None` then a new array is returned.
+        dtype (Optional[dtype]):
+            data type of the returned array. If `None`, the data type is
+            determined by the Type Promotion Rules. Inputs not already of
+            this type are cast to it when permitted by the `same_kind`
+            casting rule; otherwise `ValueError` is raised.
+            Default: `None`.
+        order (["K", "C", "F", "A"]):
+            memory layout of the output array, if `out` is `None`,
+            otherwise the `order` parameter value is not used.
+
+    Returns:
+        usm_ndarray:
+            * if both `x1` and `x2` are one-dimensional arrays with shape
+              `(N,)`, returned array is a zero-dimensional array
+              containing the inner product as its only element.
+            * if `x1` is a two-dimensional array with shape `(M, K)` and
+              `x2` is a two-dimensional array with shape `(K, N)`,
+              returned array is a two-dimensional array with shape
+              `(M, N)` and contains the conventional matrix product.
+            * if `x1` is a one-dimensional array with shape `(K,)` and
+              `x2` is an array with shape `(..., K, N)`, returned array
+              contains the conventional matrix product and has shape
+              `(..., N)`.
+            * if `x1` is an array with shape `(..., M, K)` and `x2` is a
+              one-dimensional array with shape `(K,)`, returned array has
+              shape `(..., M)` and contains the conventional matrix
+              product.
+            * if `x1` is a two-dimensional array with shape `(M, K)` and
+              `x2` is an array with shape `(..., K, N)`, returned array
+              contains the conventional matrix product for each stacked
+              matrix and has shape `(..., M, N)`.
+            * if `x1` has shape `(..., M, K)` and `x2` is a
+              two-dimensional array with shape `(K, N)`, returned array
+              contains the conventional matrix product for each stacked
+              matrix and has shape `(..., M, N)`.
+            * if both `x1` and `x2` have more than two dimensions,
+              returned array contains the conventional matrix product for
+              each stacked matrix and has shape determined by
+              broadcasting rules for `x1.shape[:-2]` and `x2.shape[:-2]`.
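+
+            For example, multiplying a two-dimensional array of shape
+            `(3, 4)` by one of shape `(4, 5)` yields an array of shape
+            `(3, 5)`. A minimal illustration (assuming a SYCL device is
+            available for the default queue):
+
+                >>> import dpctl.tensor as dpt
+                >>> a = dpt.ones((3, 4))
+                >>> b = dpt.ones((4, 5))
+                >>> dpt.matmul(a, b).shape
+                (3, 5)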
+ + The data type of the returned array is determined by the Type + Promotion Rules. If either `x1` or `x2` has a complex floating + point type, neither argument is complex conjugated or transposed. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + + x1_nd = x1.ndim + x2_nd = x2.ndim + if x1_nd == 0 or x2_nd == 0: + raise ValueError("one or more operands to `matmul` is 0 dimensional") + x1_shape = x1.shape + x2_shape = x2.shape + appended_axes = [] + if x1_nd == 1: + x1 = x1[dpt.newaxis, :] + x1_shape = x1.shape + appended_axes.append(-2) + if x2_nd == 1: + x2 = x2[:, dpt.newaxis] + x2_shape = x2.shape + appended_axes.append(-1) + if x1_shape[-1] != x2_shape[-2]: + raise ValueError("mismatch in `matmul` inner dimension") + x1_outer_sh = x1_shape[:-2] + x2_outer_sh = x2_shape[:-2] + try: + res_outer_sh = _broadcast_shape_impl( + [ + x1_outer_sh, + x2_outer_sh, + ] + ) + except ValueError: + raise ValueError("mismatch in `matmul` batching dimensions") + x1_broadcast_shape = res_outer_sh + x1_shape[-2:] + x2_broadcast_shape = res_outer_sh + x2_shape[-2:] + res_shape = res_outer_sh + x1_shape[-2:-1] + x2_shape[-1:] + + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + if dtype is None: + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise ValueError( + "function 'matmul' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, sycl_dev) + buf1_dt, buf2_dt = None, None + if x1_dtype != res_dt: + if dpt.can_cast(x1_dtype, res_dt, casting="same_kind"): + buf1_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x1` cannot be cast from {x1_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + if x2_dtype != res_dt: + if dpt.can_cast(x2_dtype, res_dt, casting="same_kind"): + buf2_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x2` cannot be cast from {x2_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {res_shape}, got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed," f"got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if ti._array_overlap(x1, out) and buf1_dt is None: + out = dpt.empty_like(out) + + if ti._array_overlap(x2, out) and buf2_dt is None: + # should not reach if out is reallocated + # after being checked against x1 + out = dpt.empty_like(out) + + if buf1_dt is None and buf2_dt is None: + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + x1, + x2, + ) + ) + else "C" + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if x1.shape != x1_broadcast_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + if x2.shape != x2_broadcast_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if x1.shape != x1_broadcast_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( 
+ res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if x2.shape != x2_broadcast_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, dot_ev = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[dot_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + if order in ["K", "A"]: + if x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + elif x1.flags.c_contiguous and x2.flags.c_contiguous: + order = "C" + else: + order = "C" if order == "A" else "K" + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 3021db1841..144215e2d6 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import numpy as np @@ -662,6 +663,48 @@ def _supported_dtype(dtypes): return True +def isdtype(dtype, kind): + """isdtype(dtype, kind) + + Returns a boolean indicating whether a provided `dtype` is + of a specified data type `kind`. + + See [array API](array_api) for more information. 
+ + [array_api]: https://data-apis.org/array-api/latest/ + """ + + if not isinstance(dtype, np.dtype): + raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype}") + + if isinstance(kind, np.dtype): + return dtype == kind + + elif isinstance(kind, str): + if kind == "bool": + return dtype == np.dtype("bool") + elif kind == "signed integer": + return dtype.kind == "i" + elif kind == "unsigned integer": + return dtype.kind == "u" + elif kind == "integral": + return dtype.kind in "iu" + elif kind == "real floating": + return dtype.kind == "f" + elif kind == "complex floating": + return dtype.kind == "c" + elif kind == "numeric": + return dtype.kind in "iufc" + else: + raise ValueError(f"Unrecognized data type kind: {kind}") + + elif isinstance(kind, tuple): + return any(isdtype(dtype, k) for k in kind) + + else: + raise TypeError(f"Unsupported data type kind: {kind}") + + __all__ = [ "_find_buf_dtype", "_find_buf_dtype2", @@ -676,6 +719,7 @@ def _supported_dtype(dtypes): "can_cast", "finfo", "iinfo", + "isdtype", "result_type", "WeakBooleanType", "WeakIntegralType", diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 284de1cbe1..67e144f798 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -907,15 +907,15 @@ cdef class usm_ndarray: def __abs__(self): return dpctl.tensor.abs(self) - def __add__(first, other): + def __add__(self, other): """ Implementation for operator.add """ - return dpctl.tensor.add(first, other) + return dpctl.tensor.add(self, other) - def __and__(first, other): + def __and__(self, other): "Implementation for operator.and" - return dpctl.tensor.bitwise_and(first, other) + return dpctl.tensor.bitwise_and(self, other) def __dlpack__(self, stream=None): """ @@ -963,8 +963,8 @@ cdef class usm_ndarray: def __eq__(self, other): return dpctl.tensor.equal(self, other) - def __floordiv__(first, other): - return dpctl.tensor.floor_divide(first, other) + def __floordiv__(self, other): + return dpctl.tensor.floor_divide(self, other) def __ge__(self, other): return dpctl.tensor.greater_equal(self, other) @@ -984,21 +984,20 @@ cdef class usm_ndarray: else: raise TypeError("len() of unsized object") - def __lshift__(first, other): - "See comment in __add__" - return dpctl.tensor.bitwise_left_shift(first, other) + def __lshift__(self, other): + return dpctl.tensor.bitwise_left_shift(self, other) def __lt__(self, other): return dpctl.tensor.less(self, other) - def __matmul__(first, other): - return NotImplemented + def __matmul__(self, other): + return dpctl.tensor.matmul(self, other) - def __mod__(first, other): - return dpctl.tensor.remainder(first, other) + def __mod__(self, other): + return dpctl.tensor.remainder(self, other) - def __mul__(first, other): - return dpctl.tensor.multiply(first, other) + def __mul__(self, other): + return dpctl.tensor.multiply(self, other) def __ne__(self, other): return dpctl.tensor.not_equal(self, other) @@ -1006,20 +1005,17 @@ cdef class usm_ndarray: def __neg__(self): return dpctl.tensor.negative(self) - def __or__(first, other): - return dpctl.tensor.bitwise_or(first, other) + def __or__(self, other): + return dpctl.tensor.bitwise_or(self, other) def __pos__(self): return dpctl.tensor.positive(self) - def __pow__(first, other, mod): - if mod is None: - return dpctl.tensor.pow(first, other) - else: - return NotImplemented + def __pow__(self, other): + return dpctl.tensor.pow(self, other) - def __rshift__(first, other): - return dpctl.tensor.bitwise_right_shift(first, other) + def __rshift__(self, 
other): + return dpctl.tensor.bitwise_right_shift(self, other) def __setitem__(self, key, rhs): cdef tuple _meta @@ -1109,14 +1105,14 @@ cdef class usm_ndarray: return - def __sub__(first, other): - return dpctl.tensor.subtract(first, other) + def __sub__(self, other): + return dpctl.tensor.subtract(self, other) - def __truediv__(first, other): - return dpctl.tensor.divide(first, other) + def __truediv__(self, other): + return dpctl.tensor.divide(self, other) - def __xor__(first, other): - return dpctl.tensor.bitwise_xor(first, other) + def __xor__(self, other): + return dpctl.tensor.bitwise_xor(self, other) def __radd__(self, other): return dpctl.tensor.add(other, self) @@ -1131,7 +1127,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(other, self) def __rmatmul__(self, other): - return NotImplemented + return dpctl.tensor.matmul(other, self) def __rmod__(self, other): return dpctl.tensor.remainder(other, self) @@ -1170,11 +1166,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(self, other, out=self) def __imatmul__(self, other): - res = self.__matmul__(other) - if res is NotImplemented: - return res - self.__setitem__(Ellipsis, res) - return self + return dpctl.tensor.matmul(self, other, out=self) def __imod__(self, other): return dpctl.tensor.remainder(self, other, out=self) diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index 491fb12126..ed06d9a774 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -26,13 +26,13 @@ #include #include #include -#include #include #include #include +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace dpctl { @@ -43,8 +43,6 @@ namespace kernels namespace accumulators { -namespace py = pybind11; - using namespace dpctl::tensor::offset_utils; template T ceiling_quotient(T n, T m) @@ -437,7 +435,7 @@ typedef size_t (*accumulate_strided_impl_fn_ptr_t)( size_t, const char *, int, - const py::ssize_t *, + const ssize_t *, char *, std::vector &, const std::vector &); @@ -447,7 +445,7 @@ size_t accumulate_strided_impl(sycl::queue &q, size_t n_elems, const char *mask, int nd, - const py::ssize_t *shape_strides, + const ssize_t *shape_strides, char *cumsum, std::vector &host_tasks, const std::vector &depends = {}) diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp index 522baadc6d..46468de2e0 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp @@ -25,13 +25,13 @@ #pragma once #include #include -#include #include #include #include +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" namespace dpctl { @@ -42,8 +42,6 @@ namespace kernels namespace indexing { -namespace py = pybind11; - using namespace dpctl::tensor::offset_utils; template (orthog_i)); + orthog_src_dst_indexer(static_cast(orthog_i)); size_t total_src_offset = masked_src_indexer(masked_i) + orthog_offsets.get_first_offset(); @@ -161,7 +159,7 @@ struct MaskedPlaceStridedFunctor // + 1 : 1) if (mask_set) { auto orthog_offsets = - orthog_dst_rhs_indexer(static_cast(orthog_i)); + 
orthog_dst_rhs_indexer(static_cast(orthog_i)); size_t total_dst_offset = masked_dst_indexer(masked_i) + orthog_offsets.get_first_offset(); @@ -199,28 +197,28 @@ class masked_extract_all_slices_strided_impl_krn; typedef sycl::event (*masked_extract_all_slices_strided_impl_fn_ptr_t)( sycl::queue &, - py::ssize_t, + ssize_t, const char *, const char *, char *, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, const std::vector &); template sycl::event masked_extract_all_slices_strided_impl( sycl::queue &exec_q, - py::ssize_t iteration_size, + ssize_t iteration_size, const char *src_p, const char *cumsum_p, char *dst_p, int nd, - const py::ssize_t + const ssize_t *packed_src_shape_strides, // [src_shape, src_strides], length 2*nd - py::ssize_t dst_size, // dst is 1D - py::ssize_t dst_stride, + ssize_t dst_size, // dst is 1D + ssize_t dst_stride, const std::vector &depends = {}) { // using MaskedExtractStridedFunctor; @@ -230,7 +228,7 @@ sycl::event masked_extract_all_slices_strided_impl( TwoZeroOffsets_Indexer orthog_src_dst_indexer{}; - /* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const + /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const * *_packed_shape_strides) */ StridedIndexer masked_src_indexer(nd, 0, packed_src_shape_strides); Strided1DIndexer masked_dst_indexer(0, dst_size, dst_stride); @@ -254,19 +252,19 @@ sycl::event masked_extract_all_slices_strided_impl( typedef sycl::event (*masked_extract_some_slices_strided_impl_fn_ptr_t)( sycl::queue &, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, const char *, const char *, char *, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, const std::vector &); template sycl::event masked_extract_some_slices_strided_impl( sycl::queue &exec_q, - py::ssize_t orthog_nelems, - py::ssize_t masked_nelems, + ssize_t orthog_nelems, + ssize_t masked_nelems, const char *src_p, const char *cumsum_p, char *dst_p, int orthog_nd, - const py::ssize_t + const ssize_t *packed_ortho_src_dst_shape_strides, // [ortho_shape, ortho_src_strides, // ortho_dst_strides], length // 3*ortho_nd - py::ssize_t ortho_src_offset, - py::ssize_t ortho_dst_offset, + ssize_t ortho_src_offset, + ssize_t ortho_dst_offset, int masked_nd, - const py::ssize_t *packed_masked_src_shape_strides, // [masked_src_shape, - // masked_src_strides], - // length 2*masked_nd - py::ssize_t masked_dst_size, // mask_dst is 1D - py::ssize_t masked_dst_stride, + const ssize_t *packed_masked_src_shape_strides, // [masked_src_shape, + // masked_src_strides], + // length 2*masked_nd + ssize_t masked_dst_size, // mask_dst is 1D + ssize_t masked_dst_stride, const std::vector &depends = {}) { // using MaskedExtractStridedFunctor; @@ -381,33 +379,33 @@ class masked_place_all_slices_strided_impl_krn; typedef sycl::event (*masked_place_all_slices_strided_impl_fn_ptr_t)( sycl::queue &, - py::ssize_t, + ssize_t, char *, const char *, const char *, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, const std::vector &); template sycl::event masked_place_all_slices_strided_impl( sycl::queue &exec_q, - py::ssize_t iteration_size, + ssize_t iteration_size, char *dst_p, const char *cumsum_p, const char *rhs_p, int nd, - const py::ssize_t + const ssize_t *packed_dst_shape_strides, // [dst_shape, dst_strides], length 2*nd - py::ssize_t rhs_size, // rhs is 
1D - py::ssize_t rhs_stride, + ssize_t rhs_size, // rhs is 1D + ssize_t rhs_stride, const std::vector &depends = {}) { TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{}; - /* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const + /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const * *_packed_shape_strides) */ StridedIndexer masked_dst_indexer(nd, 0, packed_dst_shape_strides); Strided1DCyclicIndexer masked_rhs_indexer(0, rhs_size, rhs_stride); @@ -431,19 +429,19 @@ sycl::event masked_place_all_slices_strided_impl( typedef sycl::event (*masked_place_some_slices_strided_impl_fn_ptr_t)( sycl::queue &, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, char *, const char *, const char *, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, int, - py::ssize_t const *, - py::ssize_t, - py::ssize_t, + ssize_t const *, + ssize_t, + ssize_t, const std::vector &); template sycl::event masked_place_some_slices_strided_impl( sycl::queue &exec_q, - py::ssize_t orthog_nelems, - py::ssize_t masked_nelems, + ssize_t orthog_nelems, + ssize_t masked_nelems, char *dst_p, const char *cumsum_p, const char *rhs_p, int orthog_nd, - const py::ssize_t + const ssize_t *packed_ortho_dst_rhs_shape_strides, // [ortho_shape, ortho_dst_strides, // ortho_rhs_strides], length // 3*ortho_nd - py::ssize_t ortho_dst_offset, - py::ssize_t ortho_rhs_offset, + ssize_t ortho_dst_offset, + ssize_t ortho_rhs_offset, int masked_nd, - const py::ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape, - // masked_dst_strides], - // length 2*masked_nd - py::ssize_t masked_rhs_size, // mask_dst is 1D - py::ssize_t masked_rhs_stride, + const ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape, + // masked_dst_strides], + // length 2*masked_nd + ssize_t masked_rhs_size, // mask_dst is 1D + ssize_t masked_rhs_stride, const std::vector &depends = {}) { TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{ orthog_nd, ortho_dst_offset, ortho_rhs_offset, packed_ortho_dst_rhs_shape_strides}; - /* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const + /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const * *_packed_shape_strides) */ StridedIndexer masked_dst_indexer{masked_nd, 0, packed_masked_dst_shape_strides}; @@ -550,22 +548,22 @@ template class non_zero_indexes_krn; typedef sycl::event (*non_zero_indexes_fn_ptr_t)( sycl::queue &, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, int, const char *, char *, - const py::ssize_t *, + const ssize_t *, std::vector const &); template sycl::event non_zero_indexes_impl(sycl::queue &exec_q, - py::ssize_t iter_size, - py::ssize_t nz_elems, + ssize_t iter_size, + ssize_t nz_elems, int nd, const char *cumsum_cp, char *indexes_cp, - const py::ssize_t *mask_shape, + const ssize_t *mask_shape, std::vector const &depends) { const indT1 *cumsum_data = reinterpret_cast(cumsum_cp); @@ -582,11 +580,11 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q, auto cs_prev_val = (i > 0) ? 
cumsum_data[i - 1] : indT1(0); bool cond = (cs_curr_val == cs_prev_val); - py::ssize_t i_ = static_cast(i); + ssize_t i_ = static_cast(i); for (int dim = nd; --dim > 0;) { auto sd = mask_shape[dim]; - py::ssize_t q = i_ / sd; - py::ssize_t r = (i_ - q * sd); + ssize_t q = i_ / sd; + ssize_t r = (i_ - q * sd); if (cond) { indexes_data[cs_curr_val + dim * nz_elems] = static_cast(r); diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp index 877680c8bf..ee64bd2e44 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -31,15 +31,12 @@ #include #include -#include "pybind11/pybind11.h" - +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -namespace py = pybind11; - namespace dpctl { namespace tensor @@ -179,16 +176,16 @@ struct SequentialBooleanReduction { auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); - const py::ssize_t &inp_iter_offset = + const ssize_t &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const py::ssize_t &out_iter_offset = + const ssize_t &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); outT red_val(identity_); for (size_t m = 0; m < reduction_max_gid_; ++m) { - py::ssize_t inp_reduction_offset = - static_cast(inp_reduced_dims_indexer_(m)); - py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + ssize_t inp_reduction_offset = + static_cast(inp_reduced_dims_indexer_(m)); + ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; // must convert to boolean first to handle nans using dpctl::tensor::type_utils::convert_impl; @@ -249,9 +246,9 @@ typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)( size_t, const char *, char *, - py::ssize_t, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template @@ -269,9 +266,9 @@ boolean_reduction_axis1_contig_impl(sycl::queue &exec_q, size_t reduction_nelems, const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t red_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t red_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -298,8 +295,8 @@ boolean_reduction_axis1_contig_impl(sycl::queue &exec_q, using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -425,9 +422,9 @@ struct StridedBooleanReduction const size_t wg_size = it.get_local_range(0); auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id); - const py::ssize_t &inp_iter_offset = + const ssize_t &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const py::ssize_t &out_iter_offset = + const ssize_t &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); outT local_red_val(identity_); @@ -438,9 +435,9 @@ struct StridedBooleanReduction for (size_t arg_reduce_gid = arg_reduce_gid0; arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size) { - py::ssize_t inp_reduction_offset = static_cast( - inp_reduced_dims_indexer_(arg_reduce_gid)); - py::ssize_t 
inp_offset = inp_iter_offset + inp_reduction_offset; + ssize_t inp_reduction_offset = + static_cast(inp_reduced_dims_indexer_(arg_reduce_gid)); + ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; // must convert to boolean first to handle nans using dpctl::tensor::type_utils::convert_impl; @@ -470,9 +467,9 @@ boolean_reduction_axis0_contig_impl(sycl::queue &exec_q, size_t reduction_nelems, const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t red_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t red_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -507,8 +504,8 @@ boolean_reduction_axis0_contig_impl(sycl::queue &exec_q, InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, result_indexer}; ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; constexpr size_t preferred_reductions_per_wi = 4; size_t reductions_per_wi = @@ -582,12 +579,12 @@ typedef sycl::event (*boolean_reduction_strided_impl_fn_ptr)( const char *, char *, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, int, - const py::ssize_t *, - py::ssize_t, + const ssize_t *, + ssize_t, const std::vector &); template @@ -598,12 +595,12 @@ boolean_reduction_strided_impl(sycl::queue &exec_q, const char *arg_cp, char *res_cp, int iter_nd, - const py::ssize_t *iter_shape_and_strides, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, int red_nd, - const py::ssize_t *reduction_shape_stride, - py::ssize_t reduction_arg_offset, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp); @@ -647,8 +644,8 @@ boolean_reduction_strided_impl(sycl::queue &exec_q, using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; - const py::ssize_t *const &res_shape = iter_shape_and_strides; - const py::ssize_t *const &res_strides = + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = iter_shape_and_strides + 2 * iter_nd; IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); diff --git a/dpctl/tensor/libtensor/include/kernels/clip.hpp b/dpctl/tensor/libtensor/include/kernels/clip.hpp index aff1acb071..6d9bae6ed5 100644 --- a/dpctl/tensor/libtensor/include/kernels/clip.hpp +++ b/dpctl/tensor/libtensor/include/kernels/clip.hpp @@ -23,19 +23,16 @@ //===----------------------------------------------------------------------===// #pragma once -#include "pybind11/numpy.h" -#include "pybind11/stl.h" -#include #include #include #include -#include +#include #include +#include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" namespace dpctl @@ -47,9 +44,6 @@ namespace kernels namespace clip { -namespace py = pybind11; -namespace td_ns = dpctl::tensor::type_dispatch; - using namespace dpctl::tensor::offset_utils; using dpctl::tensor::kernels::alignment_utils:: @@ -257,7 +251,7 @@ template class ClipStridedFunctor void operator()(sycl::id<1> id) const { size_t gid = id[0]; - auto offsets = indexer(static_cast(gid)); + auto offsets = indexer(static_cast(gid)); 
dst_p[offsets.get_fourth_offset()] = clip( x_p[offsets.get_first_offset()], min_p[offsets.get_second_offset()], max_p[offsets.get_third_offset()]); @@ -274,11 +268,11 @@ typedef sycl::event (*clip_strided_impl_fn_ptr_t)( const char *, const char *, char *, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template @@ -289,11 +283,11 @@ sycl::event clip_strided_impl(sycl::queue &q, const char *min_cp, const char *max_cp, char *dst_cp, - const py::ssize_t *shape_strides, - py::ssize_t x_offset, - py::ssize_t min_offset, - py::ssize_t max_offset, - py::ssize_t dst_offset, + const ssize_t *shape_strides, + ssize_t x_offset, + ssize_t min_offset, + ssize_t max_offset, + ssize_t dst_offset, const std::vector &depends) { const T *x_tp = reinterpret_cast(x_cp); diff --git a/dpctl/tensor/libtensor/include/kernels/constructors.hpp b/dpctl/tensor/libtensor/include/kernels/constructors.hpp index c28033d23d..4cab7c213c 100644 --- a/dpctl/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpctl/tensor/libtensor/include/kernels/constructors.hpp @@ -24,11 +24,11 @@ //===----------------------------------------------------------------------===// #pragma once +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" #include "utils/strided_iters.hpp" #include "utils/type_utils.hpp" #include -#include #include namespace dpctl @@ -48,37 +48,8 @@ template class linear_sequence_step_kernel; template class linear_sequence_affine_kernel; template class eye_kernel; -namespace py = pybind11; using namespace dpctl::tensor::offset_utils; -/* =========== Unboxing Python scalar =============== */ - -/*! - * @brief Cast pybind11 class managing Python object to specified type `T`. - * @defgroup CtorKernels - */ -template T unbox_py_scalar(const py::object &o) -{ - return py::cast(o); -} - -template <> inline sycl::half unbox_py_scalar(const py::object &o) -{ - float tmp = py::cast(o); - return static_cast(tmp); -} - -// Constructor to populate tensor with linear sequence defined by -// start and step data - -typedef sycl::event (*lin_space_step_fn_ptr_t)( - sycl::queue &, - size_t, // num_elements - const py::object &start, - const py::object &step, - char *, // dst_data_ptr - const std::vector &); - template class LinearSequenceStepFunctor { private: @@ -142,74 +113,9 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q, return lin_space_step_event; } -/*! - * @brief Function to submit kernel to populate given contiguous memory - * allocation with linear sequence specified by starting value and increment - * given as Python objects. - * - * @param q Sycl queue to which the kernel is submitted - * @param nelems Length of the sequence - * @param start Starting value of the sequence as Python object. Must be - * convertible to array element data type `Ty`. - * @param step Increment of the sequence as Python object. Must be convertible - * to array element data type `Ty`. - * @param array_data Kernel accessible USM pointer to the start of array to be - * populated. - * @param depends List of events to wait for before starting computations, if - * any. - * - * @return Event to wait on to ensure that computation completes. 
- * @defgroup CtorKernels - */ -template -sycl::event lin_space_step_impl(sycl::queue &exec_q, - size_t nelems, - const py::object &start, - const py::object &step, - char *array_data, - const std::vector &depends) -{ - Ty start_v; - Ty step_v; - try { - start_v = unbox_py_scalar(start); - step_v = unbox_py_scalar(step); - } catch (const py::error_already_set &e) { - throw; - } - - auto lin_space_step_event = lin_space_step_impl( - exec_q, nelems, start_v, step_v, array_data, depends); - - return lin_space_step_event; -} - -/*! - * @brief Factor to get function pointer of type `fnT` for array with elements - * of type `Ty`. - * @defgroup CtorKernels - */ -template struct LinSpaceStepFactory -{ - fnT get() - { - fnT f = lin_space_step_impl; - return f; - } -}; - // Constructor to populate tensor with linear sequence defined by // start and and data -typedef sycl::event (*lin_space_affine_fn_ptr_t)( - sycl::queue &, - size_t, // num_elements - const py::object &start, - const py::object &end, - bool include_endpoint, - char *, // dst_data_ptr - const std::vector &); - template class LinearSequenceAffineFunctor { private: @@ -312,70 +218,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q, return lin_space_affine_event; } -/*! - * @brief Function to submit kernel to populate given contiguous memory - * allocation with linear sequence specified by starting and end values given - * as Python objects. - * - * @param exec_q Sycl queue to which kernel is submitted for execution. - * @param nelems Length of the sequence - * @param start Stating value of the sequence as Python object. Must be - * convertible to array data element type `Ty`. - * @param end End-value of the sequence as Python object. Must be convertible - * to array data element type `Ty`. - * @param include_endpoint Whether the end-value is included in the sequence - * @param array_data Kernel accessible USM pointer to the start of array to be - * populated. - * @param depends List of events to wait for before starting computations, if - * any. - * - * @return Event to wait on to ensure that computation completes. - * @defgroup CtorKernels - */ -template -sycl::event lin_space_affine_impl(sycl::queue &exec_q, - size_t nelems, - const py::object &start, - const py::object &end, - bool include_endpoint, - char *array_data, - const std::vector &depends) -{ - Ty start_v, end_v; - try { - start_v = unbox_py_scalar(start); - end_v = unbox_py_scalar(end); - } catch (const py::error_already_set &e) { - throw; - } - - auto lin_space_affine_event = lin_space_affine_impl( - exec_q, nelems, start_v, end_v, include_endpoint, array_data, depends); - - return lin_space_affine_event; -} - -/*! - * @brief Factory to get function pointer of type `fnT` for array data type - * `Ty`. - */ -template struct LinSpaceAffineFactory -{ - fnT get() - { - fnT f = lin_space_affine_impl; - return f; - } -}; - /* ================ Full ================== */ -typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &, - size_t, - const py::object &, - char *, - const std::vector &); - /*! * @brief Function to submit kernel to fill given contiguous memory allocation * with specified value. @@ -408,58 +252,13 @@ sycl::event full_contig_impl(sycl::queue &q, return fill_ev; } -/*! - * @brief Function to submit kernel to fill given contiguous memory allocation - * with specified value. - * - * @param exec_q Sycl queue to which kernel is submitted for execution. 
- * @param nelems Length of the sequence - * @param py_value Python object representing the value to fill the array with. - * Must be convertible to `dstTy`. - * @param dst_p Kernel accessible USM pointer to the start of array to be - * populated. - * @param depends List of events to wait for before starting computations, if - * any. - * - * @return Event to wait on to ensure that computation completes. - * @defgroup CtorKernels - */ -template -sycl::event full_contig_impl(sycl::queue &exec_q, - size_t nelems, - const py::object &py_value, - char *dst_p, - const std::vector &depends) -{ - dstTy fill_v; - try { - fill_v = unbox_py_scalar(py_value); - } catch (const py::error_already_set &e) { - throw; - } - - sycl::event fill_ev = - full_contig_impl(exec_q, nelems, fill_v, dst_p, depends); - - return fill_ev; -} - -template struct FullContigFactory -{ - fnT get() - { - fnT f = full_contig_impl; - return f; - } -}; - /* ================ Eye ================== */ typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &, size_t nelems, // num_elements - py::ssize_t start, - py::ssize_t end, - py::ssize_t step, + ssize_t start, + ssize_t end, + ssize_t step, char *, // dst_data_ptr const std::vector &); @@ -467,15 +266,15 @@ template class EyeFunctor { private: Ty *p = nullptr; - py::ssize_t start_v; - py::ssize_t end_v; - py::ssize_t step_v; + ssize_t start_v; + ssize_t end_v; + ssize_t step_v; public: EyeFunctor(char *dst_p, - const py::ssize_t v0, - const py::ssize_t v1, - const py::ssize_t dv) + const ssize_t v0, + const ssize_t v1, + const ssize_t dv) : p(reinterpret_cast(dst_p)), start_v(v0), end_v(v1), step_v(dv) { } @@ -483,7 +282,7 @@ template class EyeFunctor void operator()(sycl::id<1> wiid) const { Ty set_v = 0; - py::ssize_t i = static_cast(wiid.get(0)); + ssize_t i = static_cast(wiid.get(0)); if (i >= start_v and i <= end_v) { if ((i - start_v) % step_v == 0) { set_v = 1; @@ -511,9 +310,9 @@ template class EyeFunctor template sycl::event eye_impl(sycl::queue &exec_q, size_t nelems, - const py::ssize_t start, - const py::ssize_t end, - const py::ssize_t step, + const ssize_t start, + const ssize_t end, + const ssize_t step, char *array_data, const std::vector &depends) { @@ -545,13 +344,13 @@ template struct EyeFactory // define function type typedef sycl::event (*tri_fn_ptr_t)(sycl::queue &, - py::ssize_t, // inner_range //py::ssize_t - py::ssize_t, // outer_range - char *, // src_data_ptr - char *, // dst_data_ptr - py::ssize_t, // nd - py::ssize_t *, // shape_and_strides - py::ssize_t, // k + ssize_t, // inner_range //ssize_t + ssize_t, // outer_range + char *, // src_data_ptr + char *, // dst_data_ptr + ssize_t, // nd + ssize_t *, // shape_and_strides + ssize_t, // k const std::vector &, const std::vector &); @@ -580,21 +379,21 @@ typedef sycl::event (*tri_fn_ptr_t)(sycl::queue &, template class tri_kernel; template sycl::event tri_impl(sycl::queue &exec_q, - py::ssize_t inner_range, - py::ssize_t outer_range, + ssize_t inner_range, + ssize_t outer_range, char *src_p, char *dst_p, - py::ssize_t nd, - py::ssize_t *shape_and_strides, - py::ssize_t k, + ssize_t nd, + ssize_t *shape_and_strides, + ssize_t k, const std::vector &depends, const std::vector &additional_depends) { constexpr int d2 = 2; - py::ssize_t src_s = nd; - py::ssize_t dst_s = 2 * nd; - py::ssize_t nd_1 = nd - 1; - py::ssize_t nd_2 = nd - 2; + ssize_t src_s = nd; + ssize_t dst_s = 2 * nd; + ssize_t nd_1 = nd - 1; + ssize_t nd_2 = nd - 2; Ty *src = reinterpret_cast(src_p); Ty *dst = reinterpret_cast(dst_p); @@ -606,18 
+405,18 @@ sycl::event tri_impl(sycl::queue &exec_q, cgh.parallel_for>( sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) { - py::ssize_t outer_gid = idx[0] / inner_range; - py::ssize_t inner_gid = idx[0] - inner_range * outer_gid; + ssize_t outer_gid = idx[0] / inner_range; + ssize_t inner_gid = idx[0] - inner_range * outer_gid; - py::ssize_t src_inner_offset = 0, dst_inner_offset = 0; + ssize_t src_inner_offset = 0, dst_inner_offset = 0; bool to_copy(true); { using dpctl::tensor::strides::CIndexer_array; - CIndexer_array indexer_i( + CIndexer_array indexer_i( {shape_and_strides[nd_2], shape_and_strides[nd_1]}); indexer_i.set(inner_gid); - const std::array &inner = indexer_i.get(); + const std::array &inner = indexer_i.get(); src_inner_offset = inner[0] * shape_and_strides[src_s + nd_2] + inner[1] * shape_and_strides[src_s + nd_1]; @@ -631,11 +430,11 @@ sycl::event tri_impl(sycl::queue &exec_q, to_copy = (inner[0] + k <= inner[1]); } - py::ssize_t src_offset = 0; - py::ssize_t dst_offset = 0; + ssize_t src_offset = 0; + ssize_t dst_offset = 0; { using dpctl::tensor::strides::CIndexer_vector; - CIndexer_vector outer(nd - d2); + CIndexer_vector outer(nd - d2); outer.get_displacement( outer_gid, shape_and_strides, shape_and_strides + src_s, shape_and_strides + dst_s, src_offset, dst_offset); diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index ef24b58ef2..9bf86e560b 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -25,10 +25,10 @@ #pragma once #include #include -#include #include #include +#include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/offset_utils.hpp" #include "utils/type_utils.hpp" @@ -42,7 +42,6 @@ namespace kernels namespace copy_and_cast { -namespace py = pybind11; using namespace dpctl::tensor::offset_utils; using dpctl::tensor::kernels::alignment_utils:: @@ -89,9 +88,9 @@ class GenericCopyFunctor void operator()(sycl::id<1> wiid) const { - const auto &offsets = indexer_(static_cast(wiid.get(0))); - const py::ssize_t &src_offset = offsets.get_first_offset(); - const py::ssize_t &dst_offset = offsets.get_second_offset(); + const auto &offsets = indexer_(static_cast(wiid.get(0))); + const ssize_t &src_offset = offsets.get_first_offset(); + const ssize_t &dst_offset = offsets.get_second_offset(); CastFnT fn{}; dst_[dst_offset] = fn(src_[src_offset]); @@ -109,11 +108,11 @@ typedef sycl::event (*copy_and_cast_generic_fn_ptr_t)( sycl::queue &, size_t, int, - const py::ssize_t *, + const ssize_t *, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &, const std::vector &); @@ -155,11 +154,11 @@ sycl::event copy_and_cast_generic_impl(sycl::queue &q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *src_p, - py::ssize_t src_offset, + ssize_t src_offset, char *dst_p, - py::ssize_t dst_offset, + ssize_t dst_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -389,13 +388,13 @@ template struct CopyAndCastContigFactory typedef sycl::event (*copy_and_cast_1d_fn_ptr_t)( sycl::queue &, size_t, - const std::array, - const std::array, - const std::array, + const std::array, + const std::array, + const std::array, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); /*! 
@@ -405,13 +404,13 @@ typedef sycl::event (*copy_and_cast_1d_fn_ptr_t)( typedef sycl::event (*copy_and_cast_2d_fn_ptr_t)( sycl::queue &, size_t, - const std::array, - const std::array, - const std::array, + const std::array, + const std::array, + const std::array, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); /*! @@ -447,13 +446,13 @@ template sycl::event copy_and_cast_nd_specialized_impl(sycl::queue &q, size_t nelems, - const std::array shape, - const std::array src_strides, - const std::array dst_strides, + const std::array shape, + const std::array src_strides, + const std::array dst_strides, const char *src_p, - py::ssize_t src_offset, + ssize_t src_offset, char *dst_p, - py::ssize_t dst_offset, + ssize_t dst_offset, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -528,9 +527,9 @@ class GenericCopyFromHostFunctor void operator()(sycl::id<1> wiid) const { - const auto &offsets = indexer_(static_cast(wiid.get(0))); - const py::ssize_t &src_offset = offsets.get_first_offset(); - const py::ssize_t &dst_offset = offsets.get_second_offset(); + const auto &offsets = indexer_(static_cast(wiid.get(0))); + const ssize_t &src_offset = offsets.get_first_offset(); + const ssize_t &dst_offset = offsets.get_second_offset(); CastFnT fn{}; dst_[dst_offset] = fn(src_acc_[src_offset]); @@ -541,13 +540,13 @@ typedef void (*copy_and_cast_from_host_blocking_fn_ptr_t)( sycl::queue &, size_t, int, - py::ssize_t *, + ssize_t *, const char *, - py::ssize_t, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &, const std::vector &); @@ -594,17 +593,17 @@ void copy_and_cast_from_host_impl( sycl::queue &q, size_t nelems, int nd, - py::ssize_t *shape_and_strides, + ssize_t *shape_and_strides, const char *host_src_p, - py::ssize_t src_offset, - py::ssize_t src_min_nelem_offset, - py::ssize_t src_max_nelem_offset, + ssize_t src_offset, + ssize_t src_min_nelem_offset, + ssize_t src_max_nelem_offset, char *dst_p, - py::ssize_t dst_offset, + ssize_t dst_offset, const std::vector &depends, const std::vector &additional_depends) { - py::ssize_t nelems_range = src_max_nelem_offset - src_min_nelem_offset + 1; + ssize_t nelems_range = src_max_nelem_offset - src_min_nelem_offset + 1; dpctl::tensor::type_utils::validate_type_for_device(q); dpctl::tensor::type_utils::validate_type_for_device(q); @@ -621,7 +620,7 @@ void copy_and_cast_from_host_impl( TwoOffsets_StridedIndexer indexer{ nd, src_offset - src_min_nelem_offset, dst_offset, - const_cast(shape_and_strides)}; + const_cast(shape_and_strides)}; dstTy *dst_tp = reinterpret_cast(dst_p); @@ -683,8 +682,8 @@ class GenericCopyForReshapeFunctor void operator()(sycl::id<1> wiid) const { - const py::ssize_t src_offset = src_indexer_(wiid.get(0)); - const py::ssize_t dst_offset = dst_indexer_(wiid.get(0)); + const ssize_t src_offset = src_indexer_(wiid.get(0)); + const ssize_t dst_offset = dst_indexer_(wiid.get(0)); dst_p[dst_offset] = src_p[src_offset]; } @@ -693,12 +692,12 @@ class GenericCopyForReshapeFunctor // define function type typedef sycl::event (*copy_for_reshape_fn_ptr_t)( sycl::queue &, - size_t, // num_elements - int, // src_nd - int, // dst_nd - py::ssize_t *, // packed shapes and strides - const char *, // src_data_ptr - char *, // dst_data_ptr + size_t, // num_elements + int, // src_nd + int, // dst_nd + ssize_t *, // packed shapes and strides + const char *, // src_data_ptr + char *, // dst_data_ptr const 
std::vector &); /*! @@ -728,7 +727,7 @@ copy_for_reshape_generic_impl(sycl::queue &q, size_t nelems, int src_nd, int dst_nd, - py::ssize_t *packed_shapes_and_strides, + ssize_t *packed_shapes_and_strides, const char *src_p, char *dst_p, const std::vector &depends) @@ -742,12 +741,11 @@ copy_for_reshape_generic_impl(sycl::queue &q, // USM array of size 2*(src_nd + dst_nd) // [ src_shape; src_strides; dst_shape; dst_strides ] - const py::ssize_t *src_shape_and_strides = - const_cast(packed_shapes_and_strides); + const ssize_t *src_shape_and_strides = + const_cast(packed_shapes_and_strides); - const py::ssize_t *dst_shape_and_strides = - const_cast(packed_shapes_and_strides + - (2 * src_nd)); + const ssize_t *dst_shape_and_strides = const_cast( + packed_shapes_and_strides + (2 * src_nd)); StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides}; StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides}; @@ -820,35 +818,34 @@ template struct CompositionIndexer struct RolledNDIndexer { RolledNDIndexer(int nd, - const py::ssize_t *shape, - const py::ssize_t *strides, - const py::ssize_t *ndshifts, - py::ssize_t starting_offset) + const ssize_t *shape, + const ssize_t *strides, + const ssize_t *ndshifts, + ssize_t starting_offset) : nd_(nd), shape_(shape), strides_(strides), ndshifts_(ndshifts), starting_offset_(starting_offset) { } - py::ssize_t operator()(size_t gid) const + ssize_t operator()(size_t gid) const { return compute_offset(gid); } private: int nd_ = -1; - const py::ssize_t *shape_ = nullptr; - const py::ssize_t *strides_ = nullptr; - const py::ssize_t *ndshifts_ = nullptr; - py::ssize_t starting_offset_ = 0; + const ssize_t *shape_ = nullptr; + const ssize_t *strides_ = nullptr; + const ssize_t *ndshifts_ = nullptr; + ssize_t starting_offset_ = 0; - py::ssize_t compute_offset(py::ssize_t gid) const + ssize_t compute_offset(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd_); - py::ssize_t relative_offset_(0); - _ind.get_left_rolled_displacement( + ssize_t relative_offset_(0); + _ind.get_left_rolled_displacement( gid, shape_, // shape ptr strides_, // strides ptr @@ -884,8 +881,8 @@ class StridedCopyForRollFunctor { const size_t gid = wiid.get(0); - const py::ssize_t src_offset = src_indexer_(gid); - const py::ssize_t dst_offset = dst_indexer_(gid); + const ssize_t src_offset = src_indexer_(gid); + const ssize_t dst_offset = dst_indexer_(gid); dst_p[dst_offset] = src_p[src_offset]; } @@ -894,14 +891,14 @@ class StridedCopyForRollFunctor // define function type typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)( sycl::queue &, - size_t, // shift - size_t, // num_elements - int, // common_nd - const py::ssize_t *, // packed shapes and strides - const char *, // src_data_ptr - py::ssize_t, // src_offset - char *, // dst_data_ptr - py::ssize_t, // dst_offset + size_t, // shift + size_t, // num_elements + int, // common_nd + const ssize_t *, // packed shapes and strides + const char *, // src_data_ptr + ssize_t, // src_offset + char *, // dst_data_ptr + ssize_t, // dst_offset const std::vector &); /*! 
@@ -929,17 +926,16 @@ typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)( * @ingroup CopyAndCastKernels */ template -sycl::event -copy_for_roll_strided_impl(sycl::queue &q, - size_t shift, - size_t nelems, - int nd, - const py::ssize_t *packed_shapes_and_strides, - const char *src_p, - py::ssize_t src_offset, - char *dst_p, - py::ssize_t dst_offset, - const std::vector &depends) +sycl::event copy_for_roll_strided_impl(sycl::queue &q, + size_t shift, + size_t nelems, + int nd, + const ssize_t *packed_shapes_and_strides, + const char *src_p, + ssize_t src_offset, + char *dst_p, + ssize_t dst_offset, + const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -985,9 +981,9 @@ typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)( size_t, // shift size_t, // num_elements const char *, // src_data_ptr - py::ssize_t, // src_offset + ssize_t, // src_offset char *, // dst_data_ptr - py::ssize_t, // dst_offset + ssize_t, // dst_offset const std::vector &); template class copy_for_roll_contig_kernel; @@ -1018,9 +1014,9 @@ sycl::event copy_for_roll_contig_impl(sycl::queue &q, size_t shift, size_t nelems, const char *src_p, - py::ssize_t src_offset, + ssize_t src_offset, char *dst_p, - py::ssize_t dst_offset, + ssize_t dst_offset, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -1085,13 +1081,13 @@ class copy_for_roll_ndshift_strided_kernel; // define function type typedef sycl::event (*copy_for_roll_ndshift_strided_fn_ptr_t)( sycl::queue &, - size_t, // num_elements - int, // common_nd - const py::ssize_t *, // packed shape, strides, shifts - const char *, // src_data_ptr - py::ssize_t, // src_offset - char *, // dst_data_ptr - py::ssize_t, // dst_offset + size_t, // num_elements + int, // common_nd + const ssize_t *, // packed shape, strides, shifts + const char *, // src_data_ptr + ssize_t, // src_offset + char *, // dst_data_ptr + ssize_t, // dst_offset const std::vector &); template @@ -1099,11 +1095,11 @@ sycl::event copy_for_roll_ndshift_strided_impl( sycl::queue &q, size_t nelems, int nd, - const py::ssize_t *packed_shapes_and_strides_and_shifts, + const ssize_t *packed_shapes_and_strides_and_shifts, const char *src_p, - py::ssize_t src_offset, + ssize_t src_offset, char *dst_p, - py::ssize_t dst_offset, + ssize_t dst_offset, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -1115,12 +1111,12 @@ sycl::event copy_for_roll_ndshift_strided_impl( // USM array of size 4 * nd // [ common_shape; src_strides; dst_strides; shifts ] - const py::ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts; - const py::ssize_t *src_strides_ptr = + const ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts; + const ssize_t *src_strides_ptr = packed_shapes_and_strides_and_shifts + nd; - const py::ssize_t *dst_strides_ptr = + const ssize_t *dst_strides_ptr = packed_shapes_and_strides_and_shifts + 2 * nd; - const py::ssize_t *shifts_ptr = + const ssize_t *shifts_ptr = packed_shapes_and_strides_and_shifts + 3 * nd; RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr, diff --git a/dpctl/tensor/libtensor/include/kernels/dpctl_tensor_types.hpp b/dpctl/tensor/libtensor/include/kernels/dpctl_tensor_types.hpp new file mode 100644 index 0000000000..c88d838abf --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/dpctl_tensor_types.hpp @@ -0,0 +1,37 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// 
Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +namespace dpctl +{ +namespace tensor +{ + +typedef std::ptrdiff_t ssize_t; + +} +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 9e13648163..591f9cb24f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace abs { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -214,11 +213,11 @@ template sycl::event abs_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index cf6875c341..236999404e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace acos { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -219,11 +218,11 @@ sycl::event acos_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp 
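(Illustrative aside, not part of the patch.) From this point on the patch applies the same treatment to each elementwise-function header (abs, acos, acosh, add, ...): swap `utils/type_dispatch.hpp` for `utils/type_dispatch_building.hpp`, include the new `kernels/dpctl_tensor_types.hpp`, drop the pybind11 include and the `py` namespace alias, and switch the `*_strided_impl` signatures to `ssize_t` offsets plus a packed `shape_and_strides` array. Assuming the packing is `[ common_shape; arg_strides; res_strides ]` (by analogy with the layout comments in the copy_for_roll hunks above), a rough sketch of how a flat work-item id is turned into element offsets, which is what the StridedIndexer-style helpers do inside these kernels:

```cpp
// Sketch only: unravel a flat index against a packed
// [ common_shape; arg_strides; res_strides ] array (assumed layout, 3*nd
// entries) and accumulate element offsets, similar in spirit to what the
// StridedIndexer-based functors do inside the *_strided_impl kernels.
// Names here are illustrative, not dpctl API.
#include <cstddef>
#include <iostream>

using index_t = std::ptrdiff_t; // stand-in for dpctl::tensor::ssize_t

void compute_offsets(int nd,
                     const index_t *shape_and_strides, // 3*nd entries
                     index_t gid,
                     index_t &arg_offset,
                     index_t &res_offset)
{
    const index_t *shape = shape_and_strides;
    const index_t *arg_strides = shape_and_strides + nd;
    const index_t *res_strides = shape_and_strides + 2 * nd;
    arg_offset = 0;
    res_offset = 0;
    // Unravel gid in C (row-major) order, innermost dimension first.
    for (int d = nd - 1; d >= 0; --d) {
        const index_t idx = gid % shape[d];
        gid /= shape[d];
        arg_offset += idx * arg_strides[d];
        res_offset += idx * res_strides[d];
    }
}

int main()
{
    // 2 x 3 logical shape; C-contiguous input, transposed (F-order) output.
    const index_t packed[] = {2, 3,  // shape
                              3, 1,  // arg strides (in elements)
                              1, 2}; // res strides (in elements)
    index_t arg_off = 0, res_off = 0;
    compute_offsets(2, packed, /*gid=*/4, arg_off, res_off);
    std::cout << "arg offset: " << arg_off
              << ", res offset: " << res_off << '\n'; // arg 4, res 3
    return 0;
}
```

When both operands are C-contiguous the computed offsets reduce to `gid` itself, which is essentially why the separate `*_contig_impl` entry points can dispense with an indexer.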
index a6ffa805d7..76d28ae92b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace acosh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -241,11 +240,11 @@ sycl::event acosh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index aae69d98ea..77bb3c4d67 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -31,12 +31,12 @@ #include "sycl_complex.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace add { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -218,11 +217,11 @@ template sycl::event add_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -264,13 +263,13 @@ template sycl::event add_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -314,12 +313,12 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl( size_t n0, size_t n1, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, // res[i,j] = mat[i,j] + vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl< @@ -363,12 +362,12 @@ sycl::event add_contig_row_contig_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = mat[i,j] + vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return add_contig_matrix_contig_row_broadcast_impl( @@ -456,9 +455,9 @@ sycl::event add_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -490,11 +489,11 @@ sycl::event add_inplace_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -538,9 +537,9 @@ sycl::event add_inplace_row_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_row_matrix_broadcast_impl< diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 2759974b93..75512d80b8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace angle { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -151,11 +150,11 @@ sycl::event angle_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index dc5f2c2b18..0e27841d1e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include 
"kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace asin { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -243,11 +242,11 @@ sycl::event asin_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 6d712165a9..b774de27da 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace asinh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -217,11 +216,11 @@ sycl::event asinh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 93c9a6696d..c71498c196 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace atan { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -219,11 +218,11 @@ sycl::event atan_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp index ac8c0483c4..012eaa7ce4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp @@ -30,11 +30,11 @@ #include #include 
"utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace atan2 { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -114,11 +113,11 @@ template sycl::event atan2_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -163,13 +162,13 @@ sycl::event atan2_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 4a26cd92b4..d227047c51 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace atanh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -212,11 +211,11 @@ sycl::event atanh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp index e4da56cd9e..2e3647ec9c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -29,12 +29,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace bitwise_and { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -173,11 +172,11 @@ sycl::event bitwise_and_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char 
*arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -224,13 +223,13 @@ sycl::event bitwise_and_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -328,9 +327,9 @@ sycl::event bitwise_and_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -366,11 +365,11 @@ sycl::event bitwise_and_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp index cc629594b9..434089a3f0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp @@ -31,10 +31,10 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -46,7 +46,6 @@ namespace kernels namespace bitwise_invert { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -178,11 +177,11 @@ sycl::event bitwise_invert_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp index 58ef64e16a..3748034098 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -30,12 +30,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace bitwise_left_shift { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -182,11 
+181,11 @@ sycl::event bitwise_left_shift_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -235,13 +234,13 @@ sycl::event bitwise_left_shift_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -345,9 +344,9 @@ sycl::event bitwise_left_shift_inplace_contig_impl( sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -383,11 +382,11 @@ sycl::event bitwise_left_shift_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp index afd24216b2..b4738f7d5a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -29,12 +29,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace bitwise_or { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -171,11 +170,11 @@ template sycl::event bitwise_or_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -221,13 +220,13 @@ sycl::event bitwise_or_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -324,9 +323,9 @@ sycl::event bitwise_or_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return 
elementwise_common::binary_inplace_contig_impl< @@ -362,11 +361,11 @@ sycl::event bitwise_or_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp index f1989b8f64..c336d949b6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -30,12 +30,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace bitwise_right_shift { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -184,11 +183,11 @@ sycl::event bitwise_right_shift_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -237,13 +236,13 @@ sycl::event bitwise_right_shift_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -349,9 +348,9 @@ sycl::event bitwise_right_shift_inplace_contig_impl( sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -387,11 +386,11 @@ sycl::event bitwise_right_shift_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp index 7b777528c2..66d1119d79 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -29,12 +29,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include 
"kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace bitwise_xor { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -173,11 +172,11 @@ sycl::event bitwise_xor_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -224,13 +223,13 @@ sycl::event bitwise_xor_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -328,9 +327,9 @@ sycl::event bitwise_xor_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -366,11 +365,11 @@ sycl::event bitwise_xor_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp index 21b8f79b81..a51d778490 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace cbrt { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; template struct CbrtFunctor @@ -142,11 +141,11 @@ sycl::event cbrt_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 3672de4d9c..b7b45c4877 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -31,10 +31,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include 
"utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace ceil { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -160,11 +159,11 @@ sycl::event ceil_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index 1794dbf721..a3e8185276 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -25,11 +25,11 @@ #pragma once #include #include -#include #include #include #include "kernels/alignment.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" namespace dpctl @@ -41,8 +41,6 @@ namespace kernels namespace elementwise_common { -namespace py = pybind11; - using dpctl::tensor::kernels::alignment_utils:: disabled_sg_loadstore_wrapper_krn; using dpctl::tensor::kernels::alignment_utils::is_aligned; @@ -264,8 +262,8 @@ struct UnaryStridedFunctor void operator()(sycl::id<1> wid) const { const auto &offsets_ = inp_out_indexer_(wid.get(0)); - const py::ssize_t &inp_offset = offsets_.get_first_offset(); - const py::ssize_t &res_offset = offsets_.get_second_offset(); + const ssize_t &inp_offset = offsets_.get_first_offset(); + const ssize_t &res_offset = offsets_.get_second_offset(); UnaryOpT op{}; @@ -342,11 +340,11 @@ sycl::event unary_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -533,7 +531,7 @@ struct BinaryStridedFunctor void operator()(sycl::id<1> wid) const { const auto &three_offsets_ = - three_offsets_indexer_(static_cast(wid.get(0))); + three_offsets_indexer_(static_cast(wid.get(0))); const auto &inp1_offset = three_offsets_.get_first_offset(); const auto &inp2_offset = three_offsets_.get_second_offset(); @@ -685,11 +683,11 @@ typedef sycl::event (*unary_strided_impl_fn_ptr_t)( sycl::queue &, size_t, int, - const py::ssize_t *, + const ssize_t *, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &, const std::vector &); @@ -697,24 +695,24 @@ typedef sycl::event (*binary_contig_impl_fn_ptr_t)( sycl::queue &, size_t, const char *, - py::ssize_t, + ssize_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); typedef sycl::event (*binary_strided_impl_fn_ptr_t)( sycl::queue &, size_t, int, - const py::ssize_t *, + const ssize_t *, const char *, - py::ssize_t, + ssize_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &, const std::vector &); @@ -724,11 +722,11 @@ typedef sycl::event (*binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( size_t, size_t, const char *, - py::ssize_t, + ssize_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, 
const std::vector &); typedef sycl::event (*binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( @@ -737,11 +735,11 @@ typedef sycl::event (*binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( size_t, size_t, const char *, - py::ssize_t, + ssize_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); template &depends = {}) { sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -831,13 +829,13 @@ sycl::event binary_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -877,12 +875,12 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl( size_t n0, size_t n1, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = op(mat[i,j], vec[j]) - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { const argT1 *mat = reinterpret_cast(mat_p) + mat_offset; @@ -955,12 +953,12 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, // res[i,j] = op(vec[j], mat[i,j]) - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { const argT1 *vec = reinterpret_cast(vec_p) + vec_offset; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp index deaef5522f..86dc2ec60a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp @@ -26,10 +26,10 @@ #pragma once #include #include -#include #include #include "kernels/alignment.hpp" +#include "kernels/dpctl_tensor_types.hpp" namespace dpctl { @@ -190,7 +190,7 @@ struct BinaryInplaceStridedFunctor void operator()(sycl::id<1> wid) const { const auto &two_offsets_ = - two_offsets_indexer_(static_cast(wid.get(0))); + two_offsets_indexer_(static_cast(wid.get(0))); const auto &inp_offset = two_offsets_.get_first_offset(); const auto &lhs_offset = two_offsets_.get_second_offset(); @@ -261,20 +261,20 @@ typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)( sycl::queue &, size_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)( sycl::queue &, size_t, int, - const py::ssize_t *, + const ssize_t *, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &, const std::vector &); @@ -284,9 +284,9 @@ typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)( size_t, size_t, const char *, - py::ssize_t, + ssize_t, char *, - py::ssize_t, + ssize_t, const std::vector &); template &depends = {}) { sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -360,11 +360,11 @@ sycl::event binary_inplace_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *rhs_p, - py::ssize_t rhs_offset, + ssize_t rhs_offset, char *lhs_p, - py::ssize_t lhs_offset, + ssize_t lhs_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -399,9 +399,9 @@ sycl::event binary_inplace_row_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const std::vector &depends = {}) { const argT *vec = reinterpret_cast(vec_p) + vec_offset; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index b8ebf34a23..24f00a1043 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace conj { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -171,11 +170,11 @@ sycl::event conj_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t 
*shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp index 00e926d3d8..77cd962f0a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp @@ -30,11 +30,11 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace copysign { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -127,11 +126,11 @@ template sycl::event copysign_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -176,13 +175,13 @@ sycl::event copysign_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index e804d5f3df..ab1b55f3cd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace cos { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -244,11 +243,11 @@ template sycl::event cos_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 16406d5547..80ac46cdf2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include 
"sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace cosh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -233,11 +232,11 @@ sycl::event cosh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index c354b612a7..146d20a0d7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -31,11 +31,11 @@ #include "sycl_complex.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace equal { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -179,11 +178,11 @@ template sycl::event equal_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -227,13 +226,13 @@ sycl::event equal_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 66e84b69bf..99c78c19c1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace exp { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -201,11 +200,11 @@ template sycl::event exp_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + 
ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index ae590fc2bf..d63693ff12 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace exp2 { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -204,11 +203,11 @@ sycl::event exp2_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 518aacfe6b..abcb51f8d3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace expm1 { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -214,11 +213,11 @@ sycl::event expm1_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index a8f810d9ac..a55fcfc565 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -31,10 +31,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace floor { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -160,11 +159,11 @@ sycl::event floor_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t 
res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index d09f376d04..2395d9180a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -30,12 +30,12 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace floor_divide { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -216,11 +215,11 @@ sycl::event floor_divide_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -267,13 +266,13 @@ sycl::event floor_divide_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -410,9 +409,9 @@ sycl::event floor_divide_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -448,11 +447,11 @@ sycl::event floor_divide_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index 2ba942fb32..98bd76248a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -32,11 +32,11 @@ #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace greater { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -176,11 +175,11 @@ template sycl::event greater_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + 
ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -224,13 +223,13 @@ sycl::event greater_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index 48503c608d..afa7a1bed5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -32,11 +32,11 @@ #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace greater_equal { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -178,11 +177,11 @@ sycl::event greater_equal_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -230,13 +229,13 @@ sycl::event greater_equal_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp index 64d8a8f059..e7cccaf211 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp @@ -30,11 +30,11 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace hypot { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -129,11 +128,11 @@ template sycl::event hypot_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return 
elementwise_common::binary_contig_impl< @@ -178,13 +177,13 @@ sycl::event hypot_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 03ed7bad78..47fcd5b6b4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace imag { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -163,11 +162,11 @@ sycl::event imag_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index b0dab6249c..17ae5cf43b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -31,9 +31,8 @@ #include #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -44,7 +43,6 @@ namespace kernels namespace isfinite { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -161,11 +159,11 @@ sycl::event isfinite_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 8c805b7934..7a3c24a553 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,10 +30,10 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -44,7 +44,6 @@ namespace kernels namespace isinf { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -158,11 
+157,11 @@ sycl::event isinf_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index 8b10ce2295..1a20e38036 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,10 +29,10 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -43,7 +43,6 @@ namespace kernels namespace isnan { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -156,11 +155,11 @@ sycl::event isnan_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 1827ca3185..7e9634dba0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -29,13 +29,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace less { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -175,11 +174,11 @@ template sycl::event less_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -222,13 +221,13 @@ sycl::event less_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 0b6d06fff3..4964715da3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ 
-32,11 +32,10 @@ #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -47,7 +46,6 @@ namespace kernels namespace less_equal { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -176,11 +174,11 @@ template sycl::event less_equal_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -225,13 +223,13 @@ sycl::event less_equal_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index 48ec92a257..d8a7c0350b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace log { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -159,11 +158,11 @@ template sycl::event log_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index 6f3c9d1925..ab53ec5b73 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace log10 { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -179,11 +178,11 @@ sycl::event log10_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + 
ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index 3b417b46b9..af36ecda79 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace log1p { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -180,11 +179,11 @@ sycl::event log1p_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 079c5bf94b..1c1d274b47 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace log2 { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -179,11 +178,11 @@ sycl::event log2_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index 9fb3759779..aee0ae6b7f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -31,13 +31,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace logaddexp { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -145,11 +144,11 @@ template sycl::event logaddexp_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t 
arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -195,13 +194,13 @@ sycl::event logaddexp_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp index 135c264751..60e0f133c2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp @@ -30,12 +30,12 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace logical_and { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -171,11 +170,11 @@ sycl::event logical_and_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -222,13 +221,13 @@ sycl::event logical_and_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp index 8a820c5172..959b5aab01 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp @@ -30,10 +30,10 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -44,7 +44,6 @@ namespace kernels namespace logical_not { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -134,11 +133,11 @@ sycl::event logical_not_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector 
&additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp index a444ced41f..f3ca6cd4d5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp @@ -30,12 +30,12 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace logical_or { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -169,11 +168,11 @@ template sycl::event logical_or_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -219,13 +218,13 @@ sycl::event logical_or_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp index a3d175f413..0ee26837be 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp @@ -30,12 +30,12 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace logical_xor { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -172,11 +171,11 @@ sycl::event logical_xor_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -223,13 +222,13 @@ sycl::event logical_xor_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp 
index 22a63882a9..da32fb6f7b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -29,13 +29,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace maximum { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -193,11 +192,11 @@ template sycl::event maximum_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -242,13 +241,13 @@ sycl::event maximum_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index a11a36aeee..c6e5e841c2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -29,13 +29,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace minimum { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -193,11 +192,11 @@ template sycl::event minimum_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -242,13 +241,13 @@ sycl::event minimum_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index 7aeea6fdc7..7e7dd13c1c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -30,14 +30,14 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace multiply { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -198,11 +197,11 @@ template sycl::event multiply_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -248,13 +247,13 @@ sycl::event multiply_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -301,12 +300,12 @@ sycl::event multiply_contig_matrix_contig_row_broadcast_impl( size_t n0, size_t n1, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = mat[i,j] * vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl< @@ -351,12 +350,12 @@ sycl::event multiply_contig_row_contig_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, // res[i,j] = mat[i,j] * vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return multiply_contig_matrix_contig_row_broadcast_impl( @@ -446,9 +445,9 @@ sycl::event multiply_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -484,11 +483,11 @@ sycl::event multiply_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -535,9 +534,9 @@ sycl::event multiply_inplace_row_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_row_matrix_broadcast_impl< diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp index b67e74438f..2a51c0bbb4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace negative { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -154,11 +153,11 @@ sycl::event negative_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index a5bc3a6cc6..c119289690 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -29,12 +29,12 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace not_equal { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -179,11 +178,11 @@ template sycl::event not_equal_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t 
arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -228,13 +227,13 @@ sycl::event not_equal_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp index 6136a55bce..8ee09d409a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace positive { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -169,11 +168,11 @@ sycl::event positive_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index 65214c9533..6db8d9fcb7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -30,14 +30,14 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace pow { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -251,11 +250,11 @@ template sycl::event pow_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -298,13 +297,13 @@ template sycl::event pow_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector 
&additional_depends) { @@ -458,9 +457,9 @@ sycl::event pow_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -492,11 +491,11 @@ sycl::event pow_inplace_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index 2fe1d70bd5..c4553310d4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace proj { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -194,11 +193,11 @@ sycl::event proj_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 94cdaf1496..e2418553c7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace real { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -163,11 +162,11 @@ sycl::event real_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index ecc5959fc3..7e1e27f5ea 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -32,13 +32,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" #include 
"utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" -#include namespace dpctl { @@ -49,7 +49,6 @@ namespace kernels namespace reciprocal { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -167,11 +166,11 @@ sycl::event reciprocal_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp index b25c7be91f..5b1c6cc815 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp @@ -30,13 +30,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace remainder { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -232,11 +231,11 @@ template sycl::event remainder_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -281,13 +280,13 @@ sycl::event remainder_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -434,9 +433,9 @@ sycl::event remainder_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -472,11 +471,11 @@ sycl::event remainder_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index e8221e4c25..c0340ef13a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -31,10 +31,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace round { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -171,11 +170,11 @@ sycl::event round_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp index a7f70337f8..9d4a28fe52 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp @@ -35,10 +35,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -49,7 +49,6 @@ namespace kernels namespace rsqrt { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; template struct RsqrtFunctor @@ -145,11 +144,11 @@ sycl::event rsqrt_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index 804b6260f1..2cc6887b1b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace sign { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -192,11 +191,11 @@ sycl::event sign_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp index 401db90b63..a0f474c293 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp @@ -30,10 +30,10 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -44,7 +44,6 @@ namespace kernels namespace signbit { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -151,11 +150,11 @@ sycl::event signbit_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index debdfa5fca..37b718f7b4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace sin { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -265,11 +264,11 @@ template sycl::event sin_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index cd7f998afa..0883a6dcc0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -32,10 +32,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace sinh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -235,11 +234,11 @@ sycl::event sinh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 3eea801fb2..970e215591 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp 
+++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -35,10 +35,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -49,7 +49,6 @@ namespace kernels namespace sqrt { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -343,11 +342,11 @@ sycl::event sqrt_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index 656ec47c4b..b6650b65e7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace square { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -186,11 +185,11 @@ sycl::event square_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index a24e71e9d2..544d91c02b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,13 +29,13 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -46,7 +46,6 @@ namespace kernels namespace subtract { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -180,11 +179,11 @@ template sycl::event subtract_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -229,13 +228,13 @@ sycl::event subtract_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t 
*shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -293,12 +292,12 @@ sycl::event subtract_contig_matrix_contig_row_broadcast_impl( size_t n0, size_t n1, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = mat[i,j] - vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl< @@ -346,12 +345,12 @@ sycl::event subtract_contig_row_contig_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = op(vec[j], mat[i,j]) - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_row_contig_matrix_broadcast_impl< @@ -443,9 +442,9 @@ sycl::event subtract_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -481,11 +480,11 @@ sycl::event subtract_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -532,9 +531,9 @@ sycl::event subtract_inplace_row_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_row_matrix_broadcast_impl< diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index e2d08cba0d..d944e43bb5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -33,10 +33,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace tan { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -209,11 +208,11 @@ template sycl::event tan_strided_impl(sycl::queue &exec_q, 
size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 13ea6c7eee..d0ee54fe8c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -34,10 +34,10 @@ #include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -48,7 +48,6 @@ namespace kernels namespace tanh { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -205,11 +204,11 @@ sycl::event tanh_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index e063ecef54..ab06a52229 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -29,14 +29,14 @@ #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" -#include namespace dpctl { @@ -47,7 +47,6 @@ namespace kernels namespace true_divide { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -200,11 +199,11 @@ sycl::event true_divide_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_impl< @@ -250,13 +249,13 @@ sycl::event true_divide_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg1_p, - py::ssize_t arg1_offset, + ssize_t arg1_offset, const char *arg2_p, - py::ssize_t arg2_offset, + ssize_t arg2_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -315,12 +314,12 @@ sycl::event true_divide_contig_matrix_contig_row_broadcast_impl( size_t n0, size_t n1, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const char *vec_p, // typeless pointer to (n1,) 
contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = mat[i,j] / vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl< @@ -368,12 +367,12 @@ sycl::event true_divide_contig_row_contig_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, // res[i,j] = mat[i,j] + vec[j] - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_contig_row_contig_matrix_broadcast_impl< @@ -541,9 +540,9 @@ sycl::event true_divide_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_contig_impl< @@ -579,11 +578,11 @@ sycl::event true_divide_inplace_strided_impl( sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { @@ -630,9 +629,9 @@ sycl::event true_divide_inplace_row_matrix_broadcast_impl( size_t n0, size_t n1, const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, + ssize_t vec_offset, char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, + ssize_t mat_offset, const std::vector &depends = {}) { return elementwise_common::binary_inplace_row_matrix_broadcast_impl< diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 35b0783719..b27792fda7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -31,10 +31,10 @@ #include "kernels/elementwise_functions/common.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -#include namespace dpctl { @@ -45,7 +45,6 @@ namespace kernels namespace trunc { -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -157,11 +156,11 @@ sycl::event trunc_strided_impl(sycl::queue &exec_q, size_t nelems, int nd, - const py::ssize_t *shape_and_strides, + const ssize_t *shape_and_strides, const char *arg_p, - py::ssize_t arg_offset, + ssize_t arg_offset, char *res_p, - py::ssize_t res_offset, + ssize_t res_offset, const std::vector &depends, const std::vector &additional_depends) { diff --git a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp index 769774f4dd..07463b118e 100644 --- a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp @@ -26,10 +26,10 @@ #include #include #include -#include #include #include +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" #include "utils/type_utils.hpp" @@ -42,7 +42,6 @@ namespace kernels namespace indexing { -namespace py = pybind11; using namespace dpctl::tensor::offset_utils; template (max_item, 1); - ind = std::clamp(ind, -max_item, max_item - 1); + max_item = std::max(max_item, 1); + ind = std::clamp(ind, -max_item, max_item - 1); ind = (ind < 0) ? ind + max_item : ind; return; } @@ -79,10 +78,10 @@ class ClipIndex public: ClipIndex() = default; - void operator()(py::ssize_t max_item, py::ssize_t &ind) const + void operator()(ssize_t max_item, ssize_t &ind) const { - max_item = std::max(max_item, 1); - ind = std::clamp(ind, 0, max_item - 1); + max_item = std::max(max_item, 1); + ind = std::clamp(ind, 0, max_item - 1); return; } }; @@ -101,7 +100,7 @@ class TakeFunctor char **ind_ = nullptr; int k_ = 0; size_t ind_nelems_ = 0; - const py::ssize_t *axes_shape_and_strides_ = nullptr; + const ssize_t *axes_shape_and_strides_ = nullptr; OrthogStrider orthog_strider; IndicesStrider ind_strider; AxesStrider axes_strider; @@ -112,7 +111,7 @@ class TakeFunctor char **ind_cp, int k, size_t ind_nelems, - const py::ssize_t *axes_shape_and_strides, + const ssize_t *axes_shape_and_strides, OrthogStrider orthog_strider_, IndicesStrider ind_strider_, AxesStrider axes_strider_) @@ -129,20 +128,20 @@ class TakeFunctor const T *src = reinterpret_cast(src_); T *dst = reinterpret_cast(dst_); - py::ssize_t i_orthog = id / ind_nelems_; - py::ssize_t i_along = id - (i_orthog * ind_nelems_); + ssize_t i_orthog = id / ind_nelems_; + ssize_t i_along = id - (i_orthog * ind_nelems_); auto orthog_offsets = orthog_strider(i_orthog); - py::ssize_t src_offset = orthog_offsets.get_first_offset(); - py::ssize_t dst_offset = orthog_offsets.get_second_offset(); + ssize_t src_offset = orthog_offsets.get_first_offset(); + ssize_t dst_offset = orthog_offsets.get_second_offset(); ProjectorT proj{}; for (int axis_idx = 0; axis_idx < k_; ++axis_idx) { indT *ind_data = reinterpret_cast(ind_[axis_idx]); - py::ssize_t ind_offset = ind_strider(i_along, axis_idx); - py::ssize_t i = static_cast(ind_data[ind_offset]); + ssize_t ind_offset = ind_strider(i_along, axis_idx); + ssize_t i = static_cast(ind_data[ind_offset]); proj(axes_shape_and_strides_[axis_idx], i); @@ -161,15 +160,15 @@ typedef sycl::event (*take_fn_ptr_t)(sycl::queue &, int, int, int, - const py::ssize_t *, - const py::ssize_t *, - const py::ssize_t *, + const ssize_t *, + const ssize_t *, + const ssize_t *, const char *, char *, char **, - py::ssize_t, - py::ssize_t, - const py::ssize_t *, + ssize_t, + ssize_t, + const ssize_t *, const std::vector &); template @@ -179,15 +178,15 @@ sycl::event take_impl(sycl::queue &q, int nd, int ind_nd, int k, - const py::ssize_t *orthog_shape_and_strides, - const py::ssize_t *axes_shape_and_strides, - const py::ssize_t *ind_shape_and_strides, + const ssize_t *orthog_shape_and_strides, + const ssize_t *axes_shape_and_strides, + const ssize_t *ind_shape_and_strides, const char *src_p, char *dst_p, char **ind_p, - py::ssize_t src_offset, - py::ssize_t dst_offset, - const py::ssize_t *ind_offsets, + ssize_t src_offset, + ssize_t dst_offset, + const ssize_t *ind_offsets, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -231,7 +230,7 @@ class PutFunctor char **ind_ = nullptr; int k_ = 0; size_t 
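// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch above.]
// The WrapIndex/ClipIndex projectors in integer_advanced_indexing.hpp
// normalize a user-supplied index against an axis extent before it is turned
// into a memory offset: the wrap policy clamps into [-max_item, max_item) and
// maps negatives from the end, while the clip policy saturates into
// [0, max_item). A host-side rendering of the two policies shown above:

#include <algorithm>
#include <cassert>
#include <cstddef>

namespace sketch
{
using ssize_t = std::ptrdiff_t; // stand-in for the dpctl ssize_t alias

inline ssize_t wrap_index(ssize_t max_item, ssize_t ind)
{
    max_item = std::max<ssize_t>(max_item, 1);
    ind = std::clamp<ssize_t>(ind, -max_item, max_item - 1);
    return (ind < 0) ? ind + max_item : ind;
}

inline ssize_t clip_index(ssize_t max_item, ssize_t ind)
{
    max_item = std::max<ssize_t>(max_item, 1);
    return std::clamp<ssize_t>(ind, 0, max_item - 1);
}
} // namespace sketch

int main()
{
    assert(sketch::wrap_index(5, -1) == 4); // negative index counts from the end
    assert(sketch::clip_index(5, 7) == 4);  // out-of-range index saturates
    return 0;
}
// ---------------------------------------------------------------------------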
ind_nelems_ = 0; - const py::ssize_t *axes_shape_and_strides_ = nullptr; + const ssize_t *axes_shape_and_strides_ = nullptr; OrthogStrider orthog_strider; IndicesStrider ind_strider; AxesStrider axes_strider; @@ -242,7 +241,7 @@ class PutFunctor char **ind_cp, int k, size_t ind_nelems, - const py::ssize_t *axes_shape_and_strides, + const ssize_t *axes_shape_and_strides, OrthogStrider orthog_strider_, IndicesStrider ind_strider_, AxesStrider axes_strider_) @@ -259,20 +258,20 @@ class PutFunctor T *dst = reinterpret_cast(dst_); const T *val = reinterpret_cast(val_); - py::ssize_t i_orthog = id / ind_nelems_; - py::ssize_t i_along = id - (i_orthog * ind_nelems_); + ssize_t i_orthog = id / ind_nelems_; + ssize_t i_along = id - (i_orthog * ind_nelems_); auto orthog_offsets = orthog_strider(i_orthog); - py::ssize_t dst_offset = orthog_offsets.get_first_offset(); - py::ssize_t val_offset = orthog_offsets.get_second_offset(); + ssize_t dst_offset = orthog_offsets.get_first_offset(); + ssize_t val_offset = orthog_offsets.get_second_offset(); ProjectorT proj{}; for (int axis_idx = 0; axis_idx < k_; ++axis_idx) { indT *ind_data = reinterpret_cast(ind_[axis_idx]); - py::ssize_t ind_offset = ind_strider(i_along, axis_idx); - py::ssize_t i = static_cast(ind_data[ind_offset]); + ssize_t ind_offset = ind_strider(i_along, axis_idx); + ssize_t i = static_cast(ind_data[ind_offset]); proj(axes_shape_and_strides_[axis_idx], i); @@ -291,15 +290,15 @@ typedef sycl::event (*put_fn_ptr_t)(sycl::queue &, int, int, int, - const py::ssize_t *, - const py::ssize_t *, - const py::ssize_t *, + const ssize_t *, + const ssize_t *, + const ssize_t *, char *, const char *, char **, - py::ssize_t, - py::ssize_t, - const py::ssize_t *, + ssize_t, + ssize_t, + const ssize_t *, const std::vector &); template @@ -309,15 +308,15 @@ sycl::event put_impl(sycl::queue &q, int nd, int ind_nd, int k, - const py::ssize_t *orthog_shape_and_strides, - const py::ssize_t *axes_shape_and_strides, - const py::ssize_t *ind_shape_and_strides, + const ssize_t *orthog_shape_and_strides, + const ssize_t *axes_shape_and_strides, + const ssize_t *ind_shape_and_strides, char *dst_p, const char *val_p, char **ind_p, - py::ssize_t dst_offset, - py::ssize_t val_offset, - const py::ssize_t *ind_offsets, + ssize_t dst_offset, + ssize_t val_offset, + const ssize_t *ind_offsets, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp new file mode 100644 index 0000000000..039417d6a5 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -0,0 +1,1137 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/reductions.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +template +struct SequentialDotProduct +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialDotProduct(const lhsT *lhs, + const rhsT *rhs, + outT *out, + BatchIndexerT batch_indexer, + RedIndexerT reduced_dims_indexer, + size_t reduction_size) + : lhs_(lhs), rhs_(rhs), out_(out), 
batch_indexer_(batch_indexer), + reduced_dims_indexer_(reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &batch_offsets = batch_indexer_(id[0]); + const ssize_t &lhs_batch_offset = batch_offsets.get_first_offset(); + const ssize_t &rhs_batch_offset = batch_offsets.get_second_offset(); + const ssize_t &out_batch_offset = batch_offsets.get_third_offset(); + + outT red_val(0); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + auto reduction_offsets = reduced_dims_indexer_(m); + auto lhs_reduction_offset = reduction_offsets.get_first_offset(); + auto rhs_reduction_offset = reduction_offsets.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + red_val += convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } + + out_[out_batch_offset] = red_val; + } +}; + +template +struct DotProductFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_batch_offset]); + res_ref += red_val_over_wg; + } + } +}; + +template +class dot_product_seq_krn; + 
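// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch above.]
// DotProductFunctor splits each batch's reduction across work-groups: a group
// with linear id g serves batch g % batches and reduction slice g / batches;
// each work-item accumulates up to reductions_per_wi elements strided by the
// work-group size, sycl::reduce_over_group combines the partials, and the
// group leader adds the group's result to the output through an atomic_ref.
// A host-side model of the element range a single work-item covers:

#include <algorithm>
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t wg = 8, reductions_per_wi = 4, reduction_nelems = 100;
    const std::size_t batches = 3;

    const std::size_t group_id = 7;                   // example group linear id
    const std::size_t batch_id = group_id % batches;  // which output element
    const std::size_t red_batch = group_id / batches; // which slice of the sum
    const std::size_t lid = 5;                        // example local id

    const std::size_t gid0 = lid + red_batch * wg * reductions_per_wi;
    const std::size_t gid_max =
        std::min(reduction_nelems, gid0 + reductions_per_wi * wg);

    std::cout << "batch " << batch_id << " accumulates elements:";
    for (std::size_t gid = gid0; gid < gid_max; gid += wg)
        std::cout << ' ' << gid;
    std::cout << '\n'; // batch 1 accumulates elements: 69 77 85 93
    return 0;
}
// ---------------------------------------------------------------------------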
+template class dot_product_init_krn; + +template +class dot_product_krn; + +typedef sycl::event (*dot_product_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event dot_product_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const ssize_t *batch_shape_and_strides, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = batch_shape_and_strides; + const ssize_t *const &res_strides = + batch_shape_and_strides + 3 * batch_nd; + IndexerT res_indexer(batch_nd, batch_res_offset, res_shape, + res_strides); + using InitKernelName = + class dot_product_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(batches), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 0; + }); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + constexpr size_t preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
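// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch above.]
// dot_product_impl_fn_ptr_t fixes a single, type-erased signature so that the
// templated dot_product_impl instantiations can be stored and selected at run
// time from the arrays' type ids. The dispatch table below (its shape, the
// type-id encoding, and the populate step) is an assumption used only for
// illustration, not code from this patch; only the typedef-plus-template
// pattern comes from the hunk above.

#include <cstddef>
#include <iostream>

// simplified stand-in for the kernel function-pointer signature
typedef double (*dot_fn_ptr_t)(const void *, const void *, std::size_t);

template <typename lhsT, typename rhsT>
double dot_impl(const void *lhs, const void *rhs, std::size_t n)
{
    const lhsT *l = static_cast<const lhsT *>(lhs);
    const rhsT *r = static_cast<const rhsT *>(rhs);
    double acc = 0.0;
    for (std::size_t i = 0; i < n; ++i)
        acc += static_cast<double>(l[i]) * static_cast<double>(r[i]);
    return acc;
}

int main()
{
    // hypothetical 2x2 table indexed by type ids: 0 -> float, 1 -> int
    dot_fn_ptr_t table[2][2] = {{dot_impl<float, float>, dot_impl<float, int>},
                                {dot_impl<int, float>, dot_impl<int, int>}};

    const float a[3] = {1.0f, 2.0f, 3.0f};
    const int b[3] = {4, 5, 6};
    std::cout << table[0][1](a, b, 3) << '\n'; // 1*4 + 2*5 + 3*6 = 32
    return 0;
}
// ---------------------------------------------------------------------------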
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + return dot_ev; + } +} + +typedef sycl::event (*dot_product_contig_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event +dot_product_contig_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.fill(res_tp, resTy(0), batches); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + constexpr size_t 
preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + return dot_ev; + } +} + +template +struct DotProductNoAtomicFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductNoAtomicFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + const size_t n_reduction_groups = it.get_group_range(0) / batches_; + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_batch_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template +class dot_product_tree_krn; + +template +class dot_product_reduction_over_group_temps_krn; + 
+template +sycl::event dot_product_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const ssize_t *batch_shape_and_strides, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + resTy 
*partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LhsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using RhsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + LhsIndexerT, RhsIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + LhsIndexerT lhs_indexer(batch_nd, batch_lhs_offset, + batch_shape_and_strides); + RhsIndexerT rhs_indexer(batch_nd, batch_rhs_offset, + batch_shape_and_strides, + batch_shape_and_strides + 2 * batch_nd); + ResIndexerT noop_tmp_indexer{}; + + InputOutputBatchIndexerT in_out_iter_indexer{ + lhs_indexer, rhs_indexer, noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{ + red_nd, reduction_lhs_offset, reduction_rhs_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + 
}); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{batch_nd, batch_res_offset, + /* shape */ batch_shape_and_strides, + /* strides */ batch_shape_and_strides + + 2 * batch_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +dot_product_contig_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT 
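// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch above.]
// dot_product_tree_impl allocates one USM scratch buffer sized
//   batches * (reduction_groups + second_iter_reduction_groups_)
// and ping-pongs between its two regions, shrinking the number of partial sums
// per batch by a factor of preferred_reductions_per_wi * wg each pass, until a
// single work-group per batch can finish the job and write into res. A
// host-side model of that pass schedule (the sizes are made up):

#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t wg = 256, preferred_reductions_per_wi = 8, max_wg = 2048;

    auto groups_for = [&](std::size_t n) {
        return (n + preferred_reductions_per_wi * wg - 1) /
               (preferred_reductions_per_wi * wg);
    };

    std::size_t remaining = 100'000'000; // reduction_nelems per batch
    std::size_t pass = 0;

    std::size_t g = groups_for(remaining); // first kernel: g partials per batch
    std::cout << "pass " << pass++ << ": " << g << " partial sums per batch\n";
    remaining = g;

    while (remaining > preferred_reductions_per_wi * max_wg) {
        g = groups_for(remaining);
        std::cout << "pass " << pass++ << ": " << g
                  << " partial sums per batch\n";
        remaining = g;
    }
    std::cout << "final pass reduces " << remaining << " values into res\n";
    return 0;
}
// ---------------------------------------------------------------------------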
inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + 
InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + inp_out_batch_indexer, reduction_indexer, reduction_nelems, + batches, preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + 
assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp new file mode 100644 index 0000000000..0d90917885 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -0,0 +1,6957 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/reductions.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +namespace gemm_detail +{ + +template +void scale_gemm_k_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t delta_k, + size_t &n_wi, + size_t &delta_n) +{ + constexpr size_t slm_elem_size = sizeof(T) * m_groups; + + while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >= + local_mem_size) + { + n_wi = n_wi / 2; + delta_n = delta_n / 2; + if (delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} + +template +void scale_gemm_nm_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t &wi_delta_n, + size_t &wi_delta_k, + size_t &wg_delta_n, + size_t &wg_delta_m) +{ + constexpr size_t slm_A_elem_size = sizeof(T); + constexpr size_t slm_B_elem_size = sizeof(T) * wi_delta_m; + + while ((wi_delta_n * wg_delta_n * wi_delta_k * slm_A_elem_size) + + (wi_delta_k * wg_delta_m * slm_B_elem_size) + + reserved_slm_size >= + local_mem_size) + { + wg_delta_n /= 2; + wg_delta_m /= 2; + wi_delta_k /= 2; + if (wg_delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} +} // namespace gemm_detail + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +class gemm_seq_reduction_krn; + +template +class gemm_tree_reduction_krn; + +template +sycl::event single_reduction_for_gemm(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + ssize_t res_offset, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + 
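// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch above.]
// gemm_detail::scale_gemm_k_parameters keeps halving the per-work-group tile
// (n_wi, delta_n) until the shared-local-memory footprint
//   sizeof(T) * m_groups * (n_wi + delta_n) * delta_k + reserved_slm_size
// fits into the device's local memory, giving up once delta_n reaches zero;
// scale_gemm_nm_parameters applies the same idea to (wg_delta_n, wg_delta_m,
// wi_delta_k). A host-side model of that sizing loop (numbers are made up):

#include <cstddef>
#include <iostream>
#include <stdexcept>

int main()
{
    const std::size_t local_mem_size = 64 * 1024; // e.g. 64 KiB of SLM
    const std::size_t reserved_slm_size = 512;
    const std::size_t slm_elem_size = sizeof(float) * 4; // sizeof(T) * m_groups
    const std::size_t delta_k = 64;

    std::size_t n_wi = 64, delta_n = 32;
    while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >=
           local_mem_size)
    {
        n_wi /= 2;
        delta_n /= 2;
        if (delta_n == 0)
            throw std::runtime_error("Insufficient resources");
    }
    std::cout << "n_wi=" << n_wi << " delta_n=" << delta_n << '\n'; // 32 and 16
    return 0;
}
// ---------------------------------------------------------------------------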
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + } + return red_ev; +} + +template +sycl::event +single_reduction_for_gemm_contig(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using 
ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + } + return red_ev; +} + +template +sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + ssize_t res_offset, + const ssize_t *res_shape_strides, + const std::vector &depends) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = 
dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{res_nd, static_cast(res_offset), + res_shape_strides}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + return final_reduction_ev; +} + +template +class gemm_reduction_over_group_temps_contig_krn; + +template +sycl::event +tree_reduction_for_gemm_contig(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + const std::vector &depends) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = 
dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + // n * m = iter_nelems because essentially, this process + // creates a stack of reduction_nelems 2D matrices and we reduce + // along the stack axis + InputIndexerT inp_indexer{0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + 
static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + return final_reduction_ev; +} + +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + 
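// --- Editorial sketch (illustrative only, not part of the kernel above) ----
// The work-group id decomposition in GemmFunctorThreadNM::operator() maps a
// linear group id onto (block_i, block_j, block_s) tile coordinates using
// only division and subtraction. A minimal, self-contained host-side
// restatement of that arithmetic, with hypothetical block counts:
#include <cstddef>

inline void lift_group_id(std::size_t gr_id,
                          std::size_t m_blocks,
                          std::size_t k_blocks,
                          std::size_t &block_i,
                          std::size_t &block_j,
                          std::size_t &block_s)
{
    block_i = gr_id / (m_blocks * k_blocks);
    const std::size_t block_r = gr_id - block_i * (m_blocks * k_blocks);
    block_j = block_r / k_blocks;
    block_s = block_r - block_j * k_blocks;
}

// For every gr_id in [0, n_blocks * m_blocks * k_blocks) the mapping
// round-trips: gr_id == (block_i * m_blocks + block_j) * k_blocks + block_s.
// ---------------------------------------------------------------------------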
local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +// specialization for wi_delta_m == 1 +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; 
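// --- Editorial sketch (illustrative only) -----------------------------------
// The tile loaders above clamp out-of-range elements to resT(0), the additive
// identity of the dot-product accumulation, so edge blocks run the same inner
// loop as full blocks. A scalar analogue of that zero-padding, assuming a
// hypothetical lhs row and rhs column of logical length k read through a tile
// of length tile_k >= k:
#include <cstddef>
#include <vector>

template <typename T>
T zero_padded_dot(const std::vector<T> &lhs_row, // k values
                  const std::vector<T> &rhs_col, // k values
                  std::size_t k,
                  std::size_t tile_k)
{
    T acc = T(0);
    for (std::size_t s = 0; s < tile_k; ++s) {
        // out-of-range reads contribute the identity instead of branching out
        const T a = (s < k) ? lhs_row[s] : T(0);
        const T b = (s < k) ? rhs_col[s] : T(0);
        acc += a * b;
    }
    return acc;
}
// ---------------------------------------------------------------------------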
+ size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < 
delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1(res[res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } + } + } + } +}; + +// specialization for m_groups == 1 +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + 
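// --- Editorial sketch (illustrative only) -----------------------------------
// Several k-blocks accumulate into the same result element in the ThreadK and
// ThreadNM functors, so the updates above go through sycl::atomic_ref with
// relaxed ordering, device scope, and global address space. A minimal
// restatement of that accumulation step (hypothetical names; intended to be
// called from device code):
#include <cstddef>
#include <sycl/sycl.hpp>

template <typename T>
void atomic_accumulate(T *res, std::size_t offset, T partial_sum)
{
    sycl::atomic_ref<T, sycl::memory_order::relaxed,
                     sycl::memory_scope::device,
                     sycl::access::address_space::global_space>
        aout(res[offset]);
    aout += partial_sum; // many work-items may target the same element
}
// ---------------------------------------------------------------------------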
size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_init_krn; + +template +class gemm_k_krn; + +template +class gemm_nm_krn; + +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // lhs_outer_nelems (n) + size_t, // inner_nelems (k) + size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const ssize_t *, // lhs shape and strides + int, // rhs outer nd + const ssize_t *, // rhs shape and strides + int, // res outer nd + const ssize_t *, // res shape and strides + std::vector const &); + +template +sycl::event gemm_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const ssize_t *rhs_shape_strides, + int res_outer_nd, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + using InitKernelName = class gemm_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return res_init_ev; + } + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + 
local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // n + size_t, // k + size_t, // m + std::vector const 
&); + +template +sycl::event gemm_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + + if (k == 0) { + return res_init_ev; + } + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerIndexerT lhs_indexer{}; + OuterInnerIndexerT rhs_indexer{}; + OuterInnerIndexerT res_indexer{}; + + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * 
wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? 
static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + + block_s * n * m)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + 
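// --- Editorial sketch (illustrative only) -----------------------------------
// The "NoAtomic" functors sidestep atomics by writing each k-block's partial
// result into its own n*m slab of a temporary buffer (note the block_s * n * m
// offset above); a follow-up reduction sums the slabs into the final matrix.
// A host-side analogue of that second pass, assuming k_blocks slabs laid out
// back to back:
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> sum_partial_slabs(const std::vector<T> &tmp, // k_blocks * n * m
                                 std::size_t n_times_m,
                                 std::size_t k_blocks)
{
    std::vector<T> res(n_times_m, T(0));
    for (std::size_t s = 0; s < k_blocks; ++s) {
        for (std::size_t e = 0; e < n_times_m; ++e) {
            res[e] += tmp[s * n_times_m + e]; // slab s, element e
        }
    }
    return res;
}
// ---------------------------------------------------------------------------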
const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_indexer(gl_i * c_st0 + j * c_st1 + block_s * n * m)] = + local_sum; + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + 
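// --- Editorial sketch (illustrative only) -----------------------------------
// Each ThreadK work-item accumulates a private sum over a strided slice of the
// k tile, parks it in the SLM `workspace`, and the work-item with
// local_s == 0 then folds the delta_k per-lane partials for its row. A serial
// sketch of that two-stage reduction, with hypothetical input:
#include <cstddef>
#include <vector>

template <typename T>
T two_stage_sum(const std::vector<T> &vals, std::size_t delta_k)
{
    // stage 1: one partial sum per lane, each lane striding by delta_k
    std::vector<T> workspace(delta_k, T(0));
    for (std::size_t local_s = 0; local_s < delta_k; ++local_s) {
        for (std::size_t t = local_s; t < vals.size(); t += delta_k) {
            workspace[local_s] += vals[t];
        }
    }
    // stage 2: the lane with local_s == 0 folds the per-lane partials
    T total = T(0);
    for (std::size_t t = 0; t < delta_k; ++t) {
        total += workspace[t];
    }
    return total;
}
// ---------------------------------------------------------------------------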
size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + const size_t res_offset = (block_s * n * m); + res[res_indexer(i * m + j) + res_offset] = local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[res_indexer(i * m + j + vec_id) + res_offset] = + local_sum[vec_id]; + } + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * 
n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum; + } + } +}; + +template +class gemm_tree_nm_krn; + +template +class gemm_tree_k_krn; + +template +sycl::event gemm_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + 
GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + 
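// --- Editorial sketch (illustrative only) -----------------------------------
// The cleanup submission above frees the USM temporary from a host_task that
// depends on the reduction event, so deallocation cannot run before the
// kernels still reading `tmp` have completed. A minimal self-contained sketch
// of the same allocate / compute / deferred-free pattern, with hypothetical
// names and a fill standing in for the real kernels:
#include <cstddef>
#include <stdexcept>
#include <sycl/sycl.hpp>

template <typename T>
sycl::event run_with_scratch(sycl::queue &q, std::size_t nelems)
{
    T *scratch = sycl::malloc_device<T>(nelems, q);
    if (!scratch) {
        throw std::runtime_error("Unable to allocate device memory");
    }
    sycl::event comp_ev = q.fill<T>(scratch, T(0), nelems); // stand-in kernel

    sycl::event cleanup_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev); // free only after users of scratch finish
        const sycl::context &ctx = q.get_context();
        cgh.host_task([ctx, scratch]() { sycl::free(scratch, ctx); });
    });
    return cleanup_ev;
}
// ---------------------------------------------------------------------------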
return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = 
+ dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= 
preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + 
dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template class gemm_tree_empty_krn; + +template +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_nd, 0, 
res_shapes_strides); + using InitKernelName = + class gemm_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } +} + +template +sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + 
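+                // workspace holds one partial dot product per work-item; after the
+                // SLM barrier the work-item with local_s == 0 folds the delta_k
+                // partial sums computed for its row before writing the result.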
LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + 
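+            // Too many partial products per output element for a single reduction
+            // pass: one USM allocation is split in two, the first
+            // iter_nelems * reduction_nelems elements receive the GEMM partial
+            // products and the trailing iter_nelems * reduction_groups elements are
+            // scratch for the multi-level tree reduction performed below.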
resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler 
&cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * 
wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + 
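+                // wi_delta_m == 1: each work-item produces a single result column,
+                // so the local B tile stores scalar resTy values; the else branch
+                // below packs wi_delta_m columns per SLM slot as a sycl::vec.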
using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + return gemm_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } +} + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + 
LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i 
* wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const size_t n_groups_per_batch = it.get_group_range(0) / 
batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? 
(static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_offset + res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1( + res[res_offset + res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * 
delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + ; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_offset + res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_batch_init_krn; + +template +class gemm_batch_k_krn; + +template +class gemm_batch_nm_krn; + +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // lhs outer nelems (n) + size_t, // inner nelems (k) + size_t, // rhs outer nelems (m) + int, // batching nd + const ssize_t *, // batch shape strides + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + int, // inner dims + int, // lhs outer dims + const ssize_t *, // lhs outer and inner shape and strides + int, // rhs outer dims + const ssize_t *, // rhs outer and inner shape and strides + int, // res outer dims + const ssize_t *, // res outer and inner shape and strides + const ssize_t *, // res full shape and strides + std::vector const &); + +template +sycl::event gemm_batch_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = class gemm_batch_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return res_init_ev; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + 
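+        // Kernel selection for the batched case: m < 4 uses the k-oriented functor
+        // with scalar accumulation (m_groups == 1), k > n && k > m uses the
+        // k-oriented functor vectorized over m_groups == 4 columns, and the general
+        // case uses the tiled nm-oriented functor; all three accumulate into res
+        // with atomics and therefore depend on the zero-initialization event above.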
cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = 
sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // n + size_t, // k + size_t, // m + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + std::vector const &); + +template +sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + + if (k == 0) { + return res_init_ev; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, 
GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + 
size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * 
wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto 
&three_offsets_ = batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + const size_t total_offset = + res_offset + (block_s * n * m * batch_nelems); + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; + +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[1]; + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), 
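The ThreadK functor uses the same kind of lift but with a different ordering: block_i varies fastest, then block_s, then block_j, as the comment in the functor notes. A short sketch of just that ordering (host-side, illustrative names; the m_groups-wide column handling is omitted):

#include <cstddef>

// Sketch of the ThreadK group-id lift: block_i fastest, then block_s,
// then block_j, matching the in-kernel comments above.
struct KBlockId {
    std::size_t i, j, s;
};

inline KBlockId lift_k_group_id(std::size_t gr_id,
                                std::size_t n_blocks,
                                std::size_t k_blocks)
{
    const std::size_t block_j = gr_id / (n_blocks * k_blocks);
    const std::size_t block_r = gr_id - block_j * (n_blocks * k_blocks);
    const std::size_t block_s = block_r / n_blocks;
    const std::size_t block_i = block_r - block_s * n_blocks;
    return {block_i, block_j, block_s};
}

Each work-item then covers result row i = block_i * delta_n + local_i and a narrow band of columns starting at the block_j-th vector-wide column group.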
+ local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && ((t + t_shift < k))) + ? 
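Within one work-group of the ThreadK functors the reduction over the k-block happens in two stages: each work-item accumulates a private partial sum over its strided slice of the block, parks it in the workspace accessor, and after a barrier the work-item with local_s == 0 in each row folds the delta_k entries. A compact sketch of that second stage, with a plain pointer standing in for the local accessor (illustrative, assumes the row-major workspace layout used above):

#include <cstddef>

// Sketch: stage two of the in-group reduction. Each of delta_n rows owns
// delta_k partial sums laid out contiguously in `workspace`; the row leader
// (local_s == 0) folds them into one value before writing the result.
template <typename resT>
resT fold_workspace_row(const resT *workspace, std::size_t local_i,
                        std::size_t delta_k)
{
    const std::size_t row = local_i * delta_k;
    resT local_sum = workspace[row];
    for (std::size_t t = 1; t < delta_k; ++t) {
        local_sum += workspace[row + t];
    }
    return local_sum;
}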
(static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_offset + res_indexer(i * m + j) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } +}; + +template +class gemm_batch_tree_k_krn; + +template +class gemm_batch_tree_nm_krn; + +template +sycl::event +gemm_batch_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName 
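The launch geometry in gemm_batch_tree_k_impl is a ceiling division along each dimension; the global range is the product of the per-dimension block counts, the batch size, and the work-group size. A sketch of that computation under the same assumptions (ceil_div and the struct are illustrative helpers, not part of this header):

#include <cstddef>

// Illustrative helper: ceil(a / b), as used for the block counts above.
inline std::size_t ceil_div(std::size_t a, std::size_t b)
{
    return (a + b - 1) / b;
}

// Sketch of the k-oriented launch geometry: lws work-items per group,
// one group per (batch, block_i, block_j, block_s).
struct KLaunch {
    std::size_t lws;    // delta_n * delta_k
    std::size_t groups; // batch_nelems * n_blocks * m_blocks * k_blocks
};

inline KLaunch k_launch_geometry(std::size_t batch_nelems,
                                 std::size_t n, std::size_t k, std::size_t m,
                                 std::size_t delta_n, std::size_t delta_k,
                                 std::size_t n_wi, std::size_t m_groups)
{
    const std::size_t n_blocks = ceil_div(n, delta_n);
    const std::size_t k_blocks = ceil_div(k, n_wi * delta_k);
    const std::size_t m_blocks = ceil_div(m, m_groups);
    const std::size_t lws = delta_n * delta_k;
    return {lws, batch_nelems * n_blocks * m_blocks * k_blocks};
}

The sycl::nd_range<1> built above is then (groups * lws, lws).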
= class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, 
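When k exceeds delta_k * n_wi, each output element is produced as reduction_nelems partial sums that still have to be combined. The code then chooses between a single reduction pass and a tree reduction depending on whether reduction_nelems fits in preferred_reductions_per_wi * max_wg, and sizes the device temporary accordingly. A hedged sketch of that sizing decision (the function and parameter names are illustrative):

#include <cstddef>

// Sketch of the temporary-buffer sizing for the partial-sum reduction.
// Returns the number of resTy elements passed to sycl::malloc_device.
inline std::size_t gemm_tree_tmp_elems(std::size_t iter_nelems,      // batch_nelems * n * m
                                       std::size_t reduction_nelems, // ceil(k / (delta_k * n_wi))
                                       std::size_t wg,
                                       std::size_t max_wg,
                                       std::size_t preferred_reductions_per_wi)
{
    if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
        // single reduction pass: one partial result per k-block and output
        return iter_nelems * reduction_nelems;
    }
    // tree reduction: the partial results plus first-level reduction scratch
    const std::size_t reduction_groups =
        (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
        (preferred_reductions_per_wi * wg);
    return iter_nelems * (reduction_nelems + reduction_groups);
}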
m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + StridedIndexer rhs_batch_indexer(batch_nd, rhs_batch_offset, + batch_shape_strides + + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT 
local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + 
rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if 
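The NM variant keeps one tile of A and one tile of B per work-group in local memory: (wi_delta_n * wg_delta_n) * wi_delta_k elements of A and wi_delta_k * wg_delta_m entries of B, the latter stored as vectors of width wi_delta_m. scale_gemm_nm_parameters is expected to shrink the deltas until that footprint fits under local_mem_size minus the reserved slack; a rough sketch of the byte budget being respected (illustrative only, the real scaling helper lives in gemm_detail and is not shown in this diff):

#include <cstddef>

// Sketch: local-memory footprint of one work-group of the NM kernel,
// assuming resT elements and a B tile of sycl::vec<resT, wi_delta_m>.
inline std::size_t nm_slm_bytes(std::size_t wi_delta_n, std::size_t wg_delta_n,
                                std::size_t wg_delta_m, std::size_t wi_delta_k,
                                std::size_t wi_delta_m, std::size_t sizeof_resT)
{
    const std::size_t a_tile_elems = (wi_delta_n * wg_delta_n) * wi_delta_k;
    const std::size_t b_tile_elems = wi_delta_k * wg_delta_m;
    return (a_tile_elems + b_tile_elems * wi_delta_m) * sizeof_resT;
}

// Usage sketch: parameters are acceptable when
//   nm_slm_bytes(...) <= local_mem_size - reserved_slm_size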
(!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + 
preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + 
sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +class gemm_batch_tree_empty_krn; + +template +sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = + class gemm_batch_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_batch_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { + return 
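With k == 0 there is nothing to contract, so gemm_batch_tree_impl above simply fills every result element of every batch with zero through the strided result indexer and never touches the operands. A minimal serial analogue of that branch's semantics (illustrative; the real code does this in a parallel_for):

#include <cstddef>

// Sketch: semantics of the k == 0 branch. Every result element of every
// batch is set to the additive identity; no lhs/rhs element is read.
template <typename resT, typename ResIndexerT>
void zero_fill_result(resT *res, std::size_t batch_nelems,
                      std::size_t n, std::size_t m,
                      const ResIndexerT &res_indexer)
{
    for (std::size_t flat = 0; flat < n * m * batch_nelems; ++flat) {
        res[res_indexer(flat)] = resT(0);
    }
}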
gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } +} + +template +sycl::event +gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + 
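The dispatch that closes gemm_batch_tree_impl routes to the ThreadK path when the contraction dimension dominates or m is small, and to the ThreadNM path otherwise, with complex result types always taking the narrowest vector width. The template widths (m_groups, wi_delta_m) are not visible in this diff excerpt, so the enumerators below are stand-ins used only to illustrate the shape of the decision:

#include <cstddef>

enum class GemmPath { TreeK_Narrow, TreeK_Wide, TreeNM_Narrow, TreeNM_Wide };

// Sketch of the dispatch heuristic; Narrow/Wide stands in for the
// m_groups / wi_delta_m template parameters, whose exact values are
// assumptions here.
inline GemmPath choose_gemm_path(std::size_t n, std::size_t k, std::size_t m,
                                 bool result_is_complex)
{
    if ((k > n && k > m) || m < 4) {
        if (result_is_complex)
            return GemmPath::TreeK_Narrow;
        return (m < 4) ? GemmPath::TreeK_Narrow : GemmPath::TreeK_Wide;
    }
    return result_is_complex ? GemmPath::TreeNM_Narrow : GemmPath::TreeNM_Wide;
}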
lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + 
lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + 
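In the contiguous variants the per-matrix offsets reduce to fixed strides: batch b starts at b * n * k in the lhs, b * k * m in the rhs, and b * n * m in the result or temporary, which is exactly what the three Strided1DIndexer instances above encode while the NoOpIndexers leave the intra-matrix flat index untouched. A small sketch of that offset computation (names are illustrative):

#include <cstddef>

// Sketch: batch offsets for contiguous row-major batched GEMM, matching the
// Strided1DIndexer triple used by the contig implementations.
struct BatchOffsets {
    std::size_t lhs, rhs, res;
};

inline BatchOffsets contig_batch_offsets(std::size_t b, std::size_t n,
                                         std::size_t k, std::size_t m)
{
    return {b * n * k, b * k * m, b * n * m};
}

The kernels add these offsets to the flat element index before reading or writing, so no shape/stride metadata has to be carried for the contiguous case.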
BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto 
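Every reduction path above ends with the same cleanup idiom: a host_task that frees the USM temporary is made to depend on the reduction event, and that host_task's event is what the implementation returns to the caller. A self-contained sketch of the pattern (helper name is illustrative; the SYCL calls are the ones used in the code above):

#include <sycl/sycl.hpp>

// Sketch: free a USM temporary only after the event that reads it has
// completed, and return the host_task event as the overall completion event.
inline sycl::event schedule_usm_cleanup(sycl::queue &q, sycl::event dep,
                                        void *tmp)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dep);
        const sycl::context ctx = q.get_context();
        cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); });
    });
}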
lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + 
Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using 
dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event +gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const &depends = {}) 
+{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + return gemm_batch_no_reduction_ev; + } + + if ((k > n && k > m) || m < 4) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index a8d9cf1972..50babfdbe0 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -31,14 +31,13 @@ #include #include -#include "pybind11/pybind11.h" +#include "dpctl_tensor_types.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_utils.hpp" -#include "utils/type_dispatch.hpp" +#include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" -namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace su_ns = dpctl::tensor::sycl_utils; @@ -98,17 +97,15 @@ struct SequentialReduction { auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); - const py::ssize_t &inp_iter_offset = + const ssize_t &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const py::ssize_t &out_iter_offset = + const ssize_t &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); outT red_val(identity_); for (size_t m = 0; m < reduction_max_gid_; ++m) { - const py::ssize_t inp_reduction_offset = - inp_reduced_dims_indexer_(m); - const py::ssize_t inp_offset = - inp_iter_offset + inp_reduction_offset; + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; using dpctl::tensor::type_utils::convert_impl; outT val = convert_impl(inp_[inp_offset]); @@ -334,12 +331,12 @@ typedef sycl::event (*reduction_strided_impl_fn_ptr)( const char *, char *, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, int, - const py::ssize_t *, - py::ssize_t, + const ssize_t *, + ssize_t, const std::vector &); template @@ -396,12 +393,12 @@ sycl::event reduction_over_group_with_atomics_strided_impl( const char *arg_cp, char *res_cp, int iter_nd, - const py::ssize_t *iter_shape_and_strides, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, 
int red_nd, - const py::ssize_t *reduction_shape_stride, - py::ssize_t reduction_arg_offset, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp); @@ -445,8 +442,8 @@ sycl::event reduction_over_group_with_atomics_strided_impl( using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; - const py::ssize_t *const &res_shape = iter_shape_and_strides; - const py::ssize_t *const &res_strides = + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = iter_shape_and_strides + 2 * iter_nd; IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); @@ -536,9 +533,9 @@ typedef sycl::event (*reduction_contig_impl_fn_ptr)( size_t, const char *, char *, - py::ssize_t, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); /* @brief Reduce rows in a matrix */ @@ -551,9 +548,9 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t reduction_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -579,8 +576,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -610,9 +607,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl( RowsIndexerT, NoOpIndexerT>; using ReductionIndexerT = NoOpIndexerT; - RowsIndexerT rows_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_nelems)}; + RowsIndexerT rows_indexer{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}; NoOpIndexerT result_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, result_indexer}; @@ -680,9 +676,9 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( // number of rows) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t reduction_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -709,8 +705,8 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl( InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; using KernelName = class reduction_seq_contig_krn(reduction_nelems), - /* step */ static_cast(iter_nelems)}; + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; constexpr size_t preferred_reductions_per_wi = 8; size_t reductions_per_wi = @@ -989,12 +985,12 @@ typedef sycl::event (*reduction_strided_impl_fn_ptr)( const char *, char *, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, int, - const py::ssize_t *, - py::ssize_t, + const ssize_t *, + ssize_t, const std::vector &); template @@ -1109,12 +1105,12 @@ sycl::event 
reduction_over_group_temps_strided_impl( const char *arg_cp, char *res_cp, int iter_nd, - const py::ssize_t *iter_shape_and_strides, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, int red_nd, - const py::ssize_t *reduction_shape_stride, - py::ssize_t reduction_arg_offset, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp); @@ -1127,8 +1123,8 @@ sycl::event reduction_over_group_temps_strided_impl( using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; - const py::ssize_t *const &res_shape = iter_shape_and_strides; - const py::ssize_t *const &res_strides = + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = iter_shape_and_strides + 2 * iter_nd; IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); @@ -1369,8 +1365,8 @@ sycl::event reduction_over_group_temps_strided_impl( dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -1433,8 +1429,8 @@ sycl::event reduction_over_group_temps_strided_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, /* shape */ iter_shape_and_strides, /* strides */ iter_shape_and_strides + @@ -1517,9 +1513,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t reduction_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -1552,8 +1548,8 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -1592,8 +1588,8 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -1684,9 +1680,8 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( RowsIndexerT, NoOpIndexerT>; using ReductionIndexerT = NoOpIndexerT; - RowsIndexerT rows_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_nelems)}; + RowsIndexerT rows_indexer{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}; NoOpIndexerT noop_tmp_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, noop_tmp_indexer}; @@ -1758,8 +1753,8 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( dpctl::tensor::offset_utils::NoOpIndexer; 
InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -1821,8 +1816,8 @@ sycl::event reduction_axis1_over_group_temps_contig_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -1902,9 +1897,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t reduction_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -1938,8 +1933,8 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; using KernelName = class reduction_seq_contig_krn(reduction_nelems), - /* step */ static_cast(iter_nelems)}; + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; if (iter_nelems == 1) { // increase GPU occupancy @@ -2079,8 +2074,8 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, noop_tmp_indexer}; ReductionIndexerT reduction_indexer{ - 0, /* size */ static_cast(reduction_nelems), - /* step */ static_cast(iter_nelems)}; + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; @@ -2148,8 +2143,8 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -2211,8 +2206,8 @@ sycl::event reduction_axis0_over_group_temps_contig_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -3522,18 +3517,16 @@ struct SequentialSearchReduction { auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); - const py::ssize_t &inp_iter_offset = + const ssize_t &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); - const py::ssize_t &out_iter_offset = + const ssize_t &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); argT red_val(identity_); outT idx_val(idx_identity_); for (size_t m = 0; m < reduction_max_gid_; ++m) { - const py::ssize_t inp_reduction_offset = - inp_reduced_dims_indexer_(m); - const py::ssize_t inp_offset = - inp_iter_offset + inp_reduction_offset; + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t 
inp_offset = inp_iter_offset + inp_reduction_offset; argT val = inp_[inp_offset]; if (val == red_val) { @@ -3997,12 +3990,12 @@ typedef sycl::event (*search_strided_impl_fn_ptr)( const char *, char *, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, int, - const py::ssize_t *, - py::ssize_t, + const ssize_t *, + ssize_t, const std::vector &); template &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp); @@ -4244,8 +4237,8 @@ sycl::event search_over_group_temps_strided_impl( using IndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; - const py::ssize_t *const &res_shape = iter_shape_and_strides; - const py::ssize_t *const &res_strides = + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = iter_shape_and_strides + 2 * iter_nd; IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); @@ -4506,8 +4499,8 @@ sycl::event search_over_group_temps_strided_impl( dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -4577,8 +4570,8 @@ sycl::event search_over_group_temps_strided_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{iter_nd, iter_res_offset, /* shape */ iter_shape_and_strides, /* strides */ iter_shape_and_strides + @@ -4664,9 +4657,9 @@ typedef sycl::event (*search_contig_impl_fn_ptr)( size_t, const char *, char *, - py::ssize_t, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -4717,8 +4710,8 @@ sycl::event search_axis1_over_group_temps_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -4759,8 +4752,8 @@ sycl::event search_axis1_over_group_temps_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -4865,8 +4858,8 @@ sycl::event search_axis1_over_group_temps_contig_impl( using ReductionIndexerT = NoOpIndexerT; InputOutputIterIndexerT in_out_iter_indexer{ - InputIterIndexerT{0, static_cast(iter_nelems), - static_cast(reduction_nelems)}, + InputIterIndexerT{0, static_cast(iter_nelems), + static_cast(reduction_nelems)}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{}; @@ -4943,8 +4936,8 @@ sycl::event search_axis1_over_group_temps_contig_impl( dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -5013,8 +5006,8 @@ sycl::event 
search_axis1_over_group_temps_contig_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -5103,9 +5096,9 @@ sycl::event search_axis0_over_group_temps_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t reduction_arg_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -5140,8 +5133,8 @@ sycl::event search_axis0_over_group_temps_contig_impl( InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; using KernelName = class search_seq_contig_krn(reduction_nelems), - /* step */ static_cast(iter_nelems)}; + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; if (iter_nelems == 1) { // increase GPU occupancy @@ -5295,8 +5288,8 @@ sycl::event search_axis0_over_group_temps_contig_impl( InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, result_indexer}; ReductionIndexerT reduction_indexer{ - 0, /* size */ static_cast(reduction_nelems), - /* step */ static_cast(iter_nelems)}; + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; @@ -5371,8 +5364,8 @@ sycl::event search_axis0_over_group_temps_contig_impl( dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, @@ -5441,8 +5434,8 @@ sycl::event search_axis0_over_group_temps_contig_impl( using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(remaining_reduction_nelems)}; + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; ResIndexerT res_iter_indexer{}; InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, diff --git a/dpctl/tensor/libtensor/include/kernels/repeat.hpp b/dpctl/tensor/libtensor/include/kernels/repeat.hpp index 05b57a8cda..66601329ae 100644 --- a/dpctl/tensor/libtensor/include/kernels/repeat.hpp +++ b/dpctl/tensor/libtensor/include/kernels/repeat.hpp @@ -26,10 +26,10 @@ #include #include #include -#include #include #include +#include "dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" #include "utils/type_utils.hpp" @@ -42,7 +42,6 @@ namespace kernels namespace repeat { -namespace py = pybind11; using namespace dpctl::tensor::offset_utils; template &); template @@ -138,15 +137,15 @@ repeat_by_sequence_impl(sycl::queue &q, const char *reps_cp, const char *cumsum_cp, int orthog_nd, - const py::ssize_t *orthog_src_dst_shape_and_strides, - py::ssize_t src_offset, - py::ssize_t dst_offset, - py::ssize_t src_axis_shape, - py::ssize_t src_axis_stride, - py::ssize_t dst_axis_shape, - py::ssize_t dst_axis_stride, - py::ssize_t reps_shape, - py::ssize_t reps_stride, + const 
ssize_t *orthog_src_dst_shape_and_strides, + ssize_t src_offset, + ssize_t dst_offset, + ssize_t src_axis_shape, + ssize_t src_axis_stride, + ssize_t dst_axis_shape, + ssize_t dst_axis_stride, + ssize_t reps_shape, + ssize_t reps_stride, const std::vector &depends) { sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { @@ -200,11 +199,11 @@ typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)( const char *, const char *, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template @@ -215,11 +214,11 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q, const char *reps_cp, const char *cumsum_cp, int src_nd, - const py::ssize_t *src_shape_strides, - py::ssize_t dst_shape, - py::ssize_t dst_stride, - py::ssize_t reps_shape, - py::ssize_t reps_stride, + const ssize_t *src_shape_strides, + ssize_t dst_shape, + ssize_t dst_stride, + ssize_t reps_shape, + ssize_t reps_stride, const std::vector &depends) { sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { @@ -277,7 +276,7 @@ class RepeatScalarFunctor private: const T *src = nullptr; T *dst = nullptr; - const py::ssize_t reps = 1; + const ssize_t reps = 1; size_t dst_axis_nelems = 0; OrthogIndexer orthog_strider; SrcAxisIndexer src_axis_strider; @@ -286,7 +285,7 @@ class RepeatScalarFunctor public: RepeatScalarFunctor(const T *src_, T *dst_, - const py::ssize_t reps_, + const ssize_t reps_, size_t dst_axis_nelems_, OrthogIndexer orthog_strider_, SrcAxisIndexer src_axis_strider_, @@ -319,15 +318,15 @@ typedef sycl::event (*repeat_by_scalar_fn_ptr_t)( size_t, const char *, char *, - const py::ssize_t, + const ssize_t, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template @@ -336,15 +335,15 @@ sycl::event repeat_by_scalar_impl(sycl::queue &q, size_t dst_axis_nelems, const char *src_cp, char *dst_cp, - const py::ssize_t reps, + const ssize_t reps, int orthog_nd, - const py::ssize_t *orthog_shape_and_strides, - py::ssize_t src_offset, - py::ssize_t dst_offset, - py::ssize_t src_axis_shape, - py::ssize_t src_axis_stride, - py::ssize_t dst_axis_shape, - py::ssize_t dst_axis_stride, + const ssize_t *orthog_shape_and_strides, + ssize_t src_offset, + ssize_t dst_offset, + ssize_t src_axis_shape, + ssize_t src_axis_stride, + ssize_t dst_axis_shape, + ssize_t dst_axis_stride, const std::vector &depends) { sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { @@ -388,11 +387,11 @@ typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)( size_t, const char *, char *, - const py::ssize_t, + const ssize_t, int, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, const std::vector &); template @@ -400,11 +399,11 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q, size_t dst_nelems, const char *src_cp, char *dst_cp, - const py::ssize_t reps, + const ssize_t reps, int src_nd, - const py::ssize_t *src_shape_strides, - py::ssize_t dst_shape, - py::ssize_t dst_stride, + const ssize_t *src_shape_strides, + ssize_t dst_shape, + ssize_t dst_stride, const std::vector &depends) { sycl::event repeat_ev = q.submit([&](sycl::handler &cgh) { diff --git a/dpctl/tensor/libtensor/include/kernels/sorting.hpp b/dpctl/tensor/libtensor/include/kernels/sorting.hpp index e7a024259e..2e8f4f8e91 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/sorting.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting.hpp @@ -24,8 +24,6 @@ #pragma once -#include "pybind11/pybind11.h" - #include #include #include @@ -33,6 +31,8 @@ #include #include +#include "dpctl_tensor_types.hpp" + namespace dpctl { namespace tensor @@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q, sycl::group_barrier(it.get_group()); bool data_in_temp = false; - size_t sorted_size = 1; - while (true) { - const size_t nelems_sorted_so_far = sorted_size * chunk; - if (nelems_sorted_so_far < wg_chunk_size) { - const size_t q = (lid / sorted_size); - const size_t start_1 = - sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size); - const size_t end_1 = sycl::min( - start_1 + nelems_sorted_so_far, wg_chunk_size); - const size_t end_2 = - sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size); - const size_t offset = chunk * (lid - q * sorted_size); - - if (data_in_temp) { - merge_impl(offset, scratch_space, work_space, start_1, - end_1, end_2, start_1, comp, chunk); - } - else { - merge_impl(offset, work_space, scratch_space, start_1, - end_1, end_2, start_1, comp, chunk); - } - sycl::group_barrier(it.get_group()); - - data_in_temp = !data_in_temp; - sorted_size *= 2; + size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const size_t nelems_sorted_so_far = n_chunks_merged * chunk; + const size_t q = (lid / n_chunks_merged); + const size_t start_1 = + sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size); + const size_t end_1 = + sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size); + const size_t end_2 = + sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size); + const size_t offset = chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_impl(offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, comp, chunk); + } + else { + merge_impl(offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, comp, chunk); } - else - break; + sycl::group_barrier(it.get_group()); } const auto &out_src = (data_in_temp) ? 
scratch_space : work_space; @@ -752,10 +750,10 @@ typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &, size_t, const char *, char *, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template > @@ -767,10 +765,10 @@ sycl::event stable_sort_axis1_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t sort_arg_offset, - py::ssize_t sort_res_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + @@ -839,10 +837,10 @@ sycl::event stable_argsort_axis1_contig_impl( // number of columns) const char *arg_cp, char *res_cp, - py::ssize_t iter_arg_offset, - py::ssize_t iter_res_offset, - py::ssize_t sort_arg_offset, - py::ssize_t sort_res_offset, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, const std::vector &depends) { const argTy *arg_tp = reinterpret_cast(arg_cp) + diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp index a1c0c7cfb0..415dd8a8d5 100644 --- a/dpctl/tensor/libtensor/include/kernels/where.hpp +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -23,15 +23,13 @@ //===----------------------------------------------------------------------===// #pragma once -#include "pybind11/numpy.h" -#include "pybind11/stl.h" #include #include #include -#include #include #include +#include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/offset_utils.hpp" #include "utils/type_utils.hpp" @@ -45,8 +43,6 @@ namespace kernels namespace search { -namespace py = pybind11; - using namespace dpctl::tensor::offset_utils; using dpctl::tensor::kernels::alignment_utils:: @@ -244,7 +240,7 @@ class WhereStridedFunctor void operator()(sycl::id<1> id) const { size_t gid = id[0]; - auto offsets = indexer(static_cast(gid)); + auto offsets = indexer(static_cast(gid)); using dpctl::tensor::type_utils::convert_impl; bool check = @@ -264,11 +260,11 @@ typedef sycl::event (*where_strided_impl_fn_ptr_t)( const char *, const char *, char *, - const py::ssize_t *, - py::ssize_t, - py::ssize_t, - py::ssize_t, - py::ssize_t, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, const std::vector &); template @@ -279,11 +275,11 @@ sycl::event where_strided_impl(sycl::queue &q, const char *x1_cp, const char *x2_cp, char *dst_cp, - const py::ssize_t *shape_strides, - py::ssize_t x1_offset, - py::ssize_t x2_offset, - py::ssize_t cond_offset, - py::ssize_t dst_offset, + const ssize_t *shape_strides, + ssize_t x1_offset, + ssize_t x2_offset, + ssize_t cond_offset, + ssize_t dst_offset, const std::vector &depends) { const condT *cond_tp = reinterpret_cast(cond_cp); diff --git a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp index 523620737b..c94b89e9a3 100644 --- a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp @@ -27,15 +27,13 @@ #pragma once #include -#include #include #include #include +#include "kernels/dpctl_tensor_types.hpp" #include "utils/strided_iters.hpp" -namespace py = pybind11; - namespace dpctl { namespace tensor @@ -85,7 +83,7 @@ std::vector concat(std::vector lhs, Vs &&...vs) template std::tuple 
-device_allocate_and_pack(sycl::queue q, +device_allocate_and_pack(sycl::queue &q, std::vector &host_task_events, Vs &&...vs) { @@ -137,35 +135,35 @@ struct NoOpIndexer struct StridedIndexer { StridedIndexer(int _nd, - py::ssize_t _offset, - py::ssize_t const *_packed_shape_strides) + ssize_t _offset, + ssize_t const *_packed_shape_strides) : nd(_nd), starting_offset(_offset), shape_strides(_packed_shape_strides) { } - py::ssize_t operator()(py::ssize_t gid) const + ssize_t operator()(ssize_t gid) const { return compute_offset(gid); } - py::ssize_t operator()(size_t gid) const + ssize_t operator()(size_t gid) const { - return compute_offset(static_cast(gid)); + return compute_offset(static_cast(gid)); } private: int nd; - py::ssize_t starting_offset; - py::ssize_t const *shape_strides; + ssize_t starting_offset; + ssize_t const *shape_strides; - py::ssize_t compute_offset(py::ssize_t gid) const + ssize_t compute_offset(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd); - py::ssize_t relative_offset(0); - _ind.get_displacement( + ssize_t relative_offset(0); + _ind.get_displacement( gid, shape_strides, // shape ptr shape_strides + nd, // strides ptr @@ -178,36 +176,36 @@ struct StridedIndexer struct UnpackedStridedIndexer { UnpackedStridedIndexer(int _nd, - py::ssize_t _offset, - py::ssize_t const *_shape, - py::ssize_t const *_strides) + ssize_t _offset, + ssize_t const *_shape, + ssize_t const *_strides) : nd(_nd), starting_offset(_offset), shape(_shape), strides(_strides) { } - py::ssize_t operator()(py::ssize_t gid) const + ssize_t operator()(ssize_t gid) const { return compute_offset(gid); } - py::ssize_t operator()(size_t gid) const + ssize_t operator()(size_t gid) const { - return compute_offset(static_cast(gid)); + return compute_offset(static_cast(gid)); } private: int nd; - py::ssize_t starting_offset; - py::ssize_t const *shape; - py::ssize_t const *strides; + ssize_t starting_offset; + ssize_t const *shape; + ssize_t const *strides; - py::ssize_t compute_offset(py::ssize_t gid) const + ssize_t compute_offset(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd); py::ssize_t relative_offset(0); - _ind.get_displacement( + _ind.get_displacement( gid, shape, // shape ptr strides, // strides ptr @@ -218,41 +216,39 @@ struct UnpackedStridedIndexer struct Strided1DIndexer { - Strided1DIndexer(py::ssize_t _offset, py::ssize_t _size, py::ssize_t _step) + Strided1DIndexer(ssize_t _offset, ssize_t _size, ssize_t _step) : offset(_offset), size(static_cast(_size)), step(_step) { } - py::ssize_t operator()(size_t gid) const + ssize_t operator()(size_t gid) const { // ensure 0 <= gid < size return offset + std::min(gid, size - 1) * step; } private: - py::ssize_t offset = 0; + ssize_t offset = 0; size_t size = 1; - py::ssize_t step = 1; + ssize_t step = 1; }; struct Strided1DCyclicIndexer { - Strided1DCyclicIndexer(py::ssize_t _offset, - py::ssize_t _size, - py::ssize_t _step) + Strided1DCyclicIndexer(ssize_t _offset, ssize_t _size, ssize_t _step) : offset(_offset), size(static_cast(_size)), step(_step) { } - py::ssize_t operator()(size_t gid) const + ssize_t operator()(size_t gid) const { return offset + (gid % size) * step; } private: - py::ssize_t offset = 0; + ssize_t offset = 0; size_t size = 1; - py::ssize_t step = 1; + ssize_t step = 1; }; template struct TwoOffsets @@ -281,45 +277,45 @@ template struct TwoOffsets struct TwoOffsets_StridedIndexer { TwoOffsets_StridedIndexer(int common_nd, - py::ssize_t 
first_offset_, - py::ssize_t second_offset_, - py::ssize_t const *_packed_shape_strides) + ssize_t first_offset_, + ssize_t second_offset_, + ssize_t const *_packed_shape_strides) : nd(common_nd), starting_first_offset(first_offset_), starting_second_offset(second_offset_), shape_strides(_packed_shape_strides) { } - TwoOffsets operator()(py::ssize_t gid) const + TwoOffsets operator()(ssize_t gid) const { return compute_offsets(gid); } - TwoOffsets operator()(size_t gid) const + TwoOffsets operator()(size_t gid) const { - return compute_offsets(static_cast(gid)); + return compute_offsets(static_cast(gid)); } private: int nd; - py::ssize_t starting_first_offset; - py::ssize_t starting_second_offset; - py::ssize_t const *shape_strides; + ssize_t starting_first_offset; + ssize_t starting_second_offset; + ssize_t const *shape_strides; - TwoOffsets compute_offsets(py::ssize_t gid) const + TwoOffsets compute_offsets(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd); - py::ssize_t relative_first_offset(0); - py::ssize_t relative_second_offset(0); - _ind.get_displacement( + ssize_t relative_first_offset(0); + ssize_t relative_second_offset(0); + _ind.get_displacement( gid, shape_strides, // shape ptr shape_strides + nd, // strides ptr shape_strides + 2 * nd, // strides ptr relative_first_offset, relative_second_offset); - return TwoOffsets( + return TwoOffsets( starting_first_offset + relative_first_offset, starting_second_offset + relative_second_offset); } @@ -329,9 +325,9 @@ struct TwoZeroOffsets_Indexer { TwoZeroOffsets_Indexer() {} - TwoOffsets operator()(py::ssize_t) const + TwoOffsets operator()(ssize_t) const { - return TwoOffsets(); + return TwoOffsets(); } }; @@ -389,10 +385,10 @@ template struct ThreeOffsets struct ThreeOffsets_StridedIndexer { ThreeOffsets_StridedIndexer(int common_nd, - py::ssize_t first_offset_, - py::ssize_t second_offset_, - py::ssize_t third_offset_, - py::ssize_t const *_packed_shape_strides) + ssize_t first_offset_, + ssize_t second_offset_, + ssize_t third_offset_, + ssize_t const *_packed_shape_strides) : nd(common_nd), starting_first_offset(first_offset_), starting_second_offset(second_offset_), starting_third_offset(third_offset_), @@ -400,32 +396,32 @@ struct ThreeOffsets_StridedIndexer { } - ThreeOffsets operator()(py::ssize_t gid) const + ThreeOffsets operator()(ssize_t gid) const { return compute_offsets(gid); } - ThreeOffsets operator()(size_t gid) const + ThreeOffsets operator()(size_t gid) const { - return compute_offsets(static_cast(gid)); + return compute_offsets(static_cast(gid)); } private: int nd; - py::ssize_t starting_first_offset; - py::ssize_t starting_second_offset; - py::ssize_t starting_third_offset; - py::ssize_t const *shape_strides; + ssize_t starting_first_offset; + ssize_t starting_second_offset; + ssize_t starting_third_offset; + ssize_t const *shape_strides; - ThreeOffsets compute_offsets(py::ssize_t gid) const + ThreeOffsets compute_offsets(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd); - py::ssize_t relative_first_offset(0); - py::ssize_t relative_second_offset(0); - py::ssize_t relative_third_offset(0); - _ind.get_displacement( + ssize_t relative_first_offset(0); + ssize_t relative_second_offset(0); + ssize_t relative_third_offset(0); + _ind.get_displacement( gid, shape_strides, // shape ptr shape_strides + nd, // strides ptr @@ -433,7 +429,7 @@ struct ThreeOffsets_StridedIndexer shape_strides + 3 * nd, // strides ptr relative_first_offset, 
relative_second_offset, relative_third_offset); - return ThreeOffsets( + return ThreeOffsets( starting_first_offset + relative_first_offset, starting_second_offset + relative_second_offset, starting_third_offset + relative_third_offset); @@ -450,6 +446,32 @@ struct ThreeZeroOffsets_Indexer } }; +template +struct ThreeOffsets_CombinedIndexer +{ +private: + FirstIndexerT first_indexer_; + SecondIndexerT second_indexer_; + ThirdIndexerT third_indexer_; + +public: + ThreeOffsets_CombinedIndexer(const FirstIndexerT &first_indexer, + const SecondIndexerT &second_indexer, + const ThirdIndexerT &third_indexer) + : first_indexer_(first_indexer), second_indexer_(second_indexer), + third_indexer_(third_indexer) + { + } + + ThreeOffsets operator()(ssize_t gid) const + { + return ThreeOffsets(first_indexer_(gid), second_indexer_(gid), + third_indexer_(gid)); + } +}; + template struct FourOffsets { FourOffsets() @@ -492,11 +514,11 @@ template struct FourOffsets struct FourOffsets_StridedIndexer { FourOffsets_StridedIndexer(int common_nd, - py::ssize_t first_offset_, - py::ssize_t second_offset_, - py::ssize_t third_offset_, - py::ssize_t fourth_offset_, - py::ssize_t const *_packed_shape_strides) + ssize_t first_offset_, + ssize_t second_offset_, + ssize_t third_offset_, + ssize_t fourth_offset_, + ssize_t const *_packed_shape_strides) : nd(common_nd), starting_first_offset(first_offset_), starting_second_offset(second_offset_), starting_third_offset(third_offset_), @@ -505,34 +527,34 @@ struct FourOffsets_StridedIndexer { } - FourOffsets operator()(py::ssize_t gid) const + FourOffsets operator()(ssize_t gid) const { return compute_offsets(gid); } - FourOffsets operator()(size_t gid) const + FourOffsets operator()(size_t gid) const { - return compute_offsets(static_cast(gid)); + return compute_offsets(static_cast(gid)); } private: int nd; - py::ssize_t starting_first_offset; - py::ssize_t starting_second_offset; - py::ssize_t starting_third_offset; - py::ssize_t starting_fourth_offset; - py::ssize_t const *shape_strides; + ssize_t starting_first_offset; + ssize_t starting_second_offset; + ssize_t starting_third_offset; + ssize_t starting_fourth_offset; + ssize_t const *shape_strides; - FourOffsets compute_offsets(py::ssize_t gid) const + FourOffsets compute_offsets(ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; CIndexer_vector _ind(nd); - py::ssize_t relative_first_offset(0); - py::ssize_t relative_second_offset(0); - py::ssize_t relative_third_offset(0); - py::ssize_t relative_fourth_offset(0); - _ind.get_displacement( + ssize_t relative_first_offset(0); + ssize_t relative_second_offset(0); + ssize_t relative_third_offset(0); + ssize_t relative_fourth_offset(0); + _ind.get_displacement( gid, shape_strides, // shape ptr shape_strides + nd, // strides ptr @@ -541,7 +563,7 @@ struct FourOffsets_StridedIndexer shape_strides + 4 * nd, // strides ptr relative_first_offset, relative_second_offset, relative_third_offset, relative_fourth_offset); - return FourOffsets( + return FourOffsets( starting_first_offset + relative_first_offset, starting_second_offset + relative_second_offset, starting_third_offset + relative_third_offset, @@ -553,26 +575,26 @@ struct FourZeroOffsets_Indexer { FourZeroOffsets_Indexer() {} - FourOffsets operator()(py::ssize_t) const + FourOffsets operator()(ssize_t) const { - return FourOffsets(); + return FourOffsets(); } }; struct NthStrideOffset { NthStrideOffset(int common_nd, - py::ssize_t const *_offsets, - py::ssize_t const *_packed_shape_strides) + ssize_t 
const *_offsets, + ssize_t const *_packed_shape_strides) : _ind(common_nd), nd(common_nd), offsets(_offsets), shape_strides(_packed_shape_strides) { } - size_t operator()(py::ssize_t gid, int n) const + size_t operator()(ssize_t gid, int n) const { - py::ssize_t relative_offset(0); - _ind.get_displacement( + ssize_t relative_offset(0); + _ind.get_displacement( gid, shape_strides, shape_strides + ((n + 1) * nd), relative_offset); @@ -580,29 +602,29 @@ struct NthStrideOffset } private: - dpctl::tensor::strides::CIndexer_vector _ind; + dpctl::tensor::strides::CIndexer_vector _ind; int nd; - py::ssize_t const *offsets; - py::ssize_t const *shape_strides; + ssize_t const *offsets; + ssize_t const *shape_strides; }; template struct FixedDimStridedIndexer { - FixedDimStridedIndexer(const std::array _shape, - const std::array _strides, - py::ssize_t _offset) + FixedDimStridedIndexer(const std::array _shape, + const std::array _strides, + ssize_t _offset) : _ind(_shape), strides(_strides), starting_offset(_offset) { } size_t operator()(size_t gid) const { - dpctl::tensor::strides::CIndexer_array local_indexer( + dpctl::tensor::strides::CIndexer_array local_indexer( std::move(_ind)); local_indexer.set(gid); auto mi = local_indexer.get(); - py::ssize_t relative_offset = 0; + ssize_t relative_offset = 0; #pragma unroll for (int i = 0; i < nd; ++i) { @@ -612,112 +634,110 @@ template struct FixedDimStridedIndexer } private: - dpctl::tensor::strides::CIndexer_array _ind; + dpctl::tensor::strides::CIndexer_array _ind; - const std::array strides; - py::ssize_t starting_offset; + const std::array strides; + ssize_t starting_offset; }; template struct TwoOffsets_FixedDimStridedIndexer { - TwoOffsets_FixedDimStridedIndexer( - const std::array _shape, - const std::array _strides1, - const std::array _strides2, - py::ssize_t _offset1, - py::ssize_t _offset2) + TwoOffsets_FixedDimStridedIndexer(const std::array _shape, + const std::array _strides1, + const std::array _strides2, + ssize_t _offset1, + ssize_t _offset2) : _ind(_shape), strides1(_strides1), strides2(_strides2), starting_offset1(_offset1), starting_offset2(_offset2) { } - TwoOffsets operator()(size_t gid) const + TwoOffsets operator()(size_t gid) const { - dpctl::tensor::strides::CIndexer_array local_indexer( + dpctl::tensor::strides::CIndexer_array local_indexer( std::move(_ind)); local_indexer.set(gid); auto mi = local_indexer.get(); - py::ssize_t relative_offset1 = 0; + ssize_t relative_offset1 = 0; #pragma unroll for (int i = 0; i < nd; ++i) { relative_offset1 += mi[i] * strides1[i]; } - py::ssize_t relative_offset2 = 0; + ssize_t relative_offset2 = 0; #pragma unroll for (int i = 0; i < nd; ++i) { relative_offset2 += mi[i] * strides2[i]; } - return TwoOffsets(starting_offset1 + relative_offset1, - starting_offset2 + relative_offset2); + return TwoOffsets(starting_offset1 + relative_offset1, + starting_offset2 + relative_offset2); } private: - dpctl::tensor::strides::CIndexer_array _ind; + dpctl::tensor::strides::CIndexer_array _ind; - const std::array strides1; - const std::array strides2; - py::ssize_t starting_offset1; - py::ssize_t starting_offset2; + const std::array strides1; + const std::array strides2; + ssize_t starting_offset1; + ssize_t starting_offset2; }; template struct ThreeOffsets_FixedDimStridedIndexer { - ThreeOffsets_FixedDimStridedIndexer( - const std::array _shape, - const std::array _strides1, - const std::array _strides2, - const std::array _strides3, - py::ssize_t _offset1, - py::ssize_t _offset2, - py::ssize_t _offset3) + 
ThreeOffsets_FixedDimStridedIndexer(const std::array _shape, + const std::array _strides1, + const std::array _strides2, + const std::array _strides3, + ssize_t _offset1, + ssize_t _offset2, + ssize_t _offset3) : _ind(_shape), strides1(_strides1), strides2(_strides2), strides3(_strides3), starting_offset1(_offset1), starting_offset2(_offset2), starting_offset3(_offset3) { } - ThreeOffsets operator()(size_t gid) const + ThreeOffsets operator()(size_t gid) const { - dpctl::tensor::strides::CIndexer_array local_indexer( + dpctl::tensor::strides::CIndexer_array local_indexer( std::move(_ind)); local_indexer.set(gid); auto mi = local_indexer.get(); - py::ssize_t relative_offset1 = 0; + ssize_t relative_offset1 = 0; #pragma unroll for (int i = 0; i < nd; ++i) { relative_offset1 += mi[i] * strides1[i]; } - py::ssize_t relative_offset2 = 0; + ssize_t relative_offset2 = 0; #pragma unroll for (int i = 0; i < nd; ++i) { relative_offset2 += mi[i] * strides2[i]; } - py::ssize_t relative_offset3 = 0; + ssize_t relative_offset3 = 0; #pragma unroll for (int i = 0; i < nd; ++i) { relative_offset3 += mi[i] * strides3[i]; } - return ThreeOffsets(starting_offset1 + relative_offset1, - starting_offset2 + relative_offset2, - starting_offset3 + relative_offset3); + return ThreeOffsets(starting_offset1 + relative_offset1, + starting_offset2 + relative_offset2, + starting_offset3 + relative_offset3); } private: - dpctl::tensor::strides::CIndexer_array _ind; - - const std::array strides1; - const std::array strides2; - const std::array strides3; - py::ssize_t starting_offset1; - py::ssize_t starting_offset2; - py::ssize_t starting_offset3; + dpctl::tensor::strides::CIndexer_array _ind; + + const std::array strides1; + const std::array strides2; + const std::array strides3; + ssize_t starting_offset1; + ssize_t starting_offset2; + ssize_t starting_offset3; }; } // namespace offset_utils diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index af031a963b..252192f507 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -25,8 +25,7 @@ #pragma once #include "dpctl4pybind11.hpp" -#include -#include +#include "type_dispatch_building.hpp" namespace dpctl { @@ -36,129 +35,6 @@ namespace tensor namespace type_dispatch { -enum class typenum_t : int -{ - BOOL = 0, - INT8, // 1 - UINT8, - INT16, - UINT16, - INT32, // 5 - UINT32, - INT64, - UINT64, - HALF, - FLOAT, // 10 - DOUBLE, - CFLOAT, - CDOUBLE, // 13 -}; -constexpr int num_types = 14; // number of elements in typenum_t - -template - typename factory, - int _num_types> -class DispatchTableBuilder -{ -private: - template - const std::vector row_per_dst_type() const - { - std::vector per_dstTy = { - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory{}.get(), - factory>{}.get(), - factory>{}.get()}; - assert(per_dstTy.size() == _num_types); - return per_dstTy; - } - -public: - DispatchTableBuilder() = default; - ~DispatchTableBuilder() = default; - - void populate_dispatch_table(funcPtrT table[][_num_types]) const - { - const auto map_by_dst_type = {row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type(), 
- row_per_dst_type(), - row_per_dst_type(), - row_per_dst_type>(), - row_per_dst_type>()}; - assert(map_by_dst_type.size() == _num_types); - int dst_id = 0; - for (auto &row : map_by_dst_type) { - int src_id = 0; - for (auto &fn_ptr : row) { - table[dst_id][src_id] = fn_ptr; - ++src_id; - } - ++dst_id; - } - } -}; - -template - typename factory, - int _num_types> -class DispatchVectorBuilder -{ -private: - template const funcPtrT func_per_type() const - { - funcPtrT f = factory{}.get(); - return f; - } - -public: - DispatchVectorBuilder() = default; - ~DispatchVectorBuilder() = default; - - void populate_dispatch_vector(funcPtrT vector[]) const - { - const auto fn_map_by_type = {func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type(), - func_per_type>(), - func_per_type>()}; - assert(fn_map_by_type.size() == _num_types); - int ty_id = 0; - for (auto &fn : fn_map_by_type) { - vector[ty_id] = fn; - ++ty_id; - } - } -}; - struct usm_ndarray_types { @@ -250,136 +126,6 @@ struct usm_ndarray_types } }; -/*! @brief struct to define result_type typename for Ty == ArgTy */ -template -struct TypeMapResultEntry : std::bool_constant> -{ - using result_type = ResTy; -}; - -/*! @brief struct to define result_type typename for Ty1 == ArgTy1 && Ty2 == - * ArgTy2 */ -template -struct BinaryTypeMapResultEntry - : std::bool_constant, - std::is_same>> -{ - using result_type = ResTy; -}; - -/*! @brief fall-through struct with specified result_type, usually void */ -template struct DefaultResultEntry : std::true_type -{ - using result_type = Ty; -}; - -/*! @brief Utility struct to convert C++ type into typeid integer */ -template struct GetTypeid -{ - int get() - { - if constexpr (std::is_same_v) { - return static_cast(typenum_t::BOOL); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::INT8); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::UINT8); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::INT16); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::UINT16); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::INT32); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::UINT32); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::INT64); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::UINT64); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::HALF); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::FLOAT); - } - else if constexpr (std::is_same_v) { - return static_cast(typenum_t::DOUBLE); - } - else if constexpr (std::is_same_v>) { - return static_cast(typenum_t::CFLOAT); - } - else if constexpr (std::is_same_v>) { - return static_cast(typenum_t::CDOUBLE); - } - else if constexpr (std::is_same_v) { // special token - return -1; - } - - assert(("Unsupported type T", false)); - return -2; - } -}; - -/*! @brief Class to generate vector of null function pointers */ -template struct NullPtrVector -{ - - using value_type = FunPtrT; - using const_reference = value_type const &; - - NullPtrVector() : val(nullptr) {} - - const_reference operator[](int) const - { - return val; - } - -private: - value_type val; -}; - -/*! 
@brief Class to generate table of null function pointers */ -template struct NullPtrTable -{ - using value_type = NullPtrVector; - using const_reference = value_type const &; - - NullPtrTable() : val() {} - - const_reference operator[](int) const - { - return val; - } - -private: - value_type val; -}; - -template -struct TypePairDefinedEntry : std::bool_constant && - std::is_same_v> -{ - static constexpr bool is_defined = true; -}; - -struct NotDefinedEntry : std::true_type -{ - static constexpr bool is_defined = false; -}; - } // namespace type_dispatch } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp new file mode 100644 index 0000000000..11cccdfb56 --- /dev/null +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp @@ -0,0 +1,294 @@ +//===--type_dispatch.cpp - Type-dispatch table building utils ----*-C++-*- ===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines class to implement dispatch tables for pair of types +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace dpctl +{ +namespace tensor +{ + +namespace type_dispatch +{ + +enum class typenum_t : int +{ + BOOL = 0, + INT8, // 1 + UINT8, + INT16, + UINT16, + INT32, // 5 + UINT32, + INT64, + UINT64, + HALF, + FLOAT, // 10 + DOUBLE, + CFLOAT, + CDOUBLE, // 13 +}; +constexpr int num_types = 14; // number of elements in typenum_t + +template + typename factory, + int _num_types> +class DispatchTableBuilder +{ +private: + template + const std::vector row_per_dst_type() const + { + std::vector per_dstTy = { + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory{}.get(), + factory>{}.get(), + factory>{}.get()}; + assert(per_dstTy.size() == _num_types); + return per_dstTy; + } + +public: + DispatchTableBuilder() = default; + ~DispatchTableBuilder() = default; + + void populate_dispatch_table(funcPtrT table[][_num_types]) const + { + const auto map_by_dst_type = {row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type(), + row_per_dst_type>(), + row_per_dst_type>()}; + assert(map_by_dst_type.size() == _num_types); + int dst_id = 0; + for (auto &row : map_by_dst_type) { + int src_id = 0; + for (auto &fn_ptr : row) { + table[dst_id][src_id] = fn_ptr; + ++src_id; + } + ++dst_id; + } + } +}; + +template + typename factory, + int _num_types> +class 
DispatchVectorBuilder +{ +private: + template const funcPtrT func_per_type() const + { + funcPtrT f = factory{}.get(); + return f; + } + +public: + DispatchVectorBuilder() = default; + ~DispatchVectorBuilder() = default; + + void populate_dispatch_vector(funcPtrT vector[]) const + { + const auto fn_map_by_type = {func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type(), + func_per_type>(), + func_per_type>()}; + assert(fn_map_by_type.size() == _num_types); + int ty_id = 0; + for (auto &fn : fn_map_by_type) { + vector[ty_id] = fn; + ++ty_id; + } + } +}; + +/*! @brief struct to define result_type typename for Ty == ArgTy */ +template +struct TypeMapResultEntry : std::bool_constant> +{ + using result_type = ResTy; +}; + +/*! @brief struct to define result_type typename for Ty1 == ArgTy1 && Ty2 == + * ArgTy2 */ +template +struct BinaryTypeMapResultEntry + : std::bool_constant, + std::is_same>> +{ + using result_type = ResTy; +}; + +/*! @brief fall-through struct with specified result_type, usually void */ +template struct DefaultResultEntry : std::true_type +{ + using result_type = Ty; +}; + +/*! @brief Utility struct to convert C++ type into typeid integer */ +template struct GetTypeid +{ + int get() + { + if constexpr (std::is_same_v) { + return static_cast(typenum_t::BOOL); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT8); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT8); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT16); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT16); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT32); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT32); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT64); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT64); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::HALF); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::FLOAT); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::DOUBLE); + } + else if constexpr (std::is_same_v>) { + return static_cast(typenum_t::CFLOAT); + } + else if constexpr (std::is_same_v>) { + return static_cast(typenum_t::CDOUBLE); + } + else if constexpr (std::is_same_v) { // special token + return -1; + } + + assert(("Unsupported type T", false)); + return -2; + } +}; + +/*! @brief Class to generate vector of null function pointers */ +template struct NullPtrVector +{ + + using value_type = FunPtrT; + using const_reference = value_type const &; + + NullPtrVector() : val(nullptr) {} + + const_reference operator[](int) const + { + return val; + } + +private: + value_type val; +}; + +/*! 
@brief Class to generate table of null function pointers */ +template struct NullPtrTable +{ + using value_type = NullPtrVector; + using const_reference = value_type const &; + + NullPtrTable() : val() {} + + const_reference operator[](int) const + { + return val; + } + +private: + value_type val; +}; + +template +struct TypePairDefinedEntry : std::bool_constant && + std::is_same_v> +{ + static constexpr bool is_defined = true; +}; + +struct NotDefinedEntry : std::true_type +{ + static constexpr bool is_defined = false; +}; + +} // namespace type_dispatch + +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/boolean_reductions.cpp b/dpctl/tensor/libtensor/source/boolean_reductions.cpp index 32deab6da9..b7ca433367 100644 --- a/dpctl/tensor/libtensor/source/boolean_reductions.cpp +++ b/dpctl/tensor/libtensor/source/boolean_reductions.cpp @@ -37,10 +37,13 @@ #include "dpctl4pybind11.hpp" #include "kernels/boolean_reductions.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "utils/type_utils.hpp" namespace py = pybind11; +static_assert(std::is_same_v); + namespace dpctl { namespace tensor diff --git a/dpctl/tensor/libtensor/source/clip.cpp b/dpctl/tensor/libtensor/source/clip.cpp index ac494c19ae..96af65771d 100644 --- a/dpctl/tensor/libtensor/source/clip.cpp +++ b/dpctl/tensor/libtensor/source/clip.cpp @@ -24,12 +24,12 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include #include #include +#include #include #include "clip.hpp" diff --git a/dpctl/tensor/libtensor/source/clip.hpp b/dpctl/tensor/libtensor/source/clip.hpp index d4b8af2cf5..e8ed1e83fb 100644 --- a/dpctl/tensor/libtensor/source/clip.hpp +++ b/dpctl/tensor/libtensor/source/clip.hpp @@ -24,7 +24,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp deleted file mode 100644 index 9ab7c0807c..0000000000 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ /dev/null @@ -1,5155 +0,0 @@ -//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions, -/// specifically functions for elementwise operations. 
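For orientation, here is a minimal, self-contained sketch of the factory/dispatch-vector pattern that the new type_dispatch_building.hpp header provides and that the deleted elementwise sources below repeat for every operator. All names in this sketch (copy_kernel, CopyFactory, dispatch_copy) are illustrative and not part of this patch; the real builders are DispatchVectorBuilder/DispatchTableBuilder, which (assuming the usual fnT/T factory parameter order) instantiate a factory's get() for every typenum_t value and record the resulting pointers, or nullptr for unsupported types, in arrays indexed by type id.

#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <type_traits>

// Type ids and their count, mirroring typenum_t / num_types from the header.
enum class typenum_t : int { BOOL = 0, INT8, UINT8, INT16, UINT16, INT32,
                             UINT32, INT64, UINT64, HALF, FLOAT, DOUBLE,
                             CFLOAT, CDOUBLE };
constexpr int num_types = 14;

using unary_fn_ptr_t = void (*)(const void *, void *, std::size_t);

// A trivial per-type kernel standing in for abs, acos, sqrt, ...
template <typename T>
void copy_kernel(const void *src_p, void *dst_p, std::size_t n)
{
    const T *src = static_cast<const T *>(src_p);
    T *dst = static_cast<T *>(dst_p);
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src[i];
    }
}

// A factory in the shape the builders consume: get() yields a kernel pointer
// for supported types and nullptr otherwise.
template <typename fnT, typename T> struct CopyFactory
{
    fnT get()
    {
        if constexpr (std::is_arithmetic_v<T>) {
            return copy_kernel<T>;
        }
        else {
            return nullptr;
        }
    }
};

static unary_fn_ptr_t copy_dispatch_vector[num_types];

// DispatchVectorBuilder automates this loop over all fourteen types;
// two representative entries are written out by hand here.
void populate_copy_dispatch_vector()
{
    copy_dispatch_vector[static_cast<int>(typenum_t::FLOAT)] =
        CopyFactory<unary_fn_ptr_t, float>{}.get();
    copy_dispatch_vector[static_cast<int>(typenum_t::DOUBLE)] =
        CopyFactory<unary_fn_ptr_t, double>{}.get();
    // ... remaining typenum_t values ...
}

// Runtime dispatch: the array's runtime type id selects the kernel.
void dispatch_copy(int type_id, const void *src, void *dst, std::size_t n)
{
    assert(type_id >= 0 && type_id < num_types);
    unary_fn_ptr_t fn = copy_dispatch_vector[type_id];
    if (fn == nullptr) {
        throw std::runtime_error("operation not defined for this type");
    }
    fn(src, dst, n);
}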
-//===----------------------------------------------------------------------===// - -#include "dpctl4pybind11.hpp" -#include -#include -#include -#include -#include - -#include "elementwise_functions.hpp" -#include "utils/type_dispatch.hpp" - -#include "kernels/elementwise_functions/abs.hpp" -#include "kernels/elementwise_functions/acos.hpp" -#include "kernels/elementwise_functions/acosh.hpp" -#include "kernels/elementwise_functions/add.hpp" -#include "kernels/elementwise_functions/asin.hpp" -#include "kernels/elementwise_functions/asinh.hpp" -#include "kernels/elementwise_functions/atan.hpp" -#include "kernels/elementwise_functions/atan2.hpp" -#include "kernels/elementwise_functions/atanh.hpp" -#include "kernels/elementwise_functions/bitwise_and.hpp" -#include "kernels/elementwise_functions/bitwise_invert.hpp" -#include "kernels/elementwise_functions/bitwise_left_shift.hpp" -#include "kernels/elementwise_functions/bitwise_or.hpp" -#include "kernels/elementwise_functions/bitwise_right_shift.hpp" -#include "kernels/elementwise_functions/bitwise_xor.hpp" -#include "kernels/elementwise_functions/cbrt.hpp" -#include "kernels/elementwise_functions/ceil.hpp" -#include "kernels/elementwise_functions/conj.hpp" -#include "kernels/elementwise_functions/copysign.hpp" -#include "kernels/elementwise_functions/cos.hpp" -#include "kernels/elementwise_functions/cosh.hpp" -#include "kernels/elementwise_functions/equal.hpp" -#include "kernels/elementwise_functions/exp.hpp" -#include "kernels/elementwise_functions/exp2.hpp" -#include "kernels/elementwise_functions/expm1.hpp" -#include "kernels/elementwise_functions/floor.hpp" -#include "kernels/elementwise_functions/floor_divide.hpp" -#include "kernels/elementwise_functions/greater.hpp" -#include "kernels/elementwise_functions/greater_equal.hpp" -#include "kernels/elementwise_functions/hypot.hpp" -#include "kernels/elementwise_functions/imag.hpp" -#include "kernels/elementwise_functions/isfinite.hpp" -#include "kernels/elementwise_functions/isinf.hpp" -#include "kernels/elementwise_functions/isnan.hpp" -#include "kernels/elementwise_functions/less.hpp" -#include "kernels/elementwise_functions/less_equal.hpp" -#include "kernels/elementwise_functions/log.hpp" -#include "kernels/elementwise_functions/log10.hpp" -#include "kernels/elementwise_functions/log1p.hpp" -#include "kernels/elementwise_functions/log2.hpp" -#include "kernels/elementwise_functions/logaddexp.hpp" -#include "kernels/elementwise_functions/logical_and.hpp" -#include "kernels/elementwise_functions/logical_not.hpp" -#include "kernels/elementwise_functions/logical_or.hpp" -#include "kernels/elementwise_functions/logical_xor.hpp" -#include "kernels/elementwise_functions/maximum.hpp" -#include "kernels/elementwise_functions/minimum.hpp" -#include "kernels/elementwise_functions/multiply.hpp" -#include "kernels/elementwise_functions/negative.hpp" -#include "kernels/elementwise_functions/not_equal.hpp" -#include "kernels/elementwise_functions/positive.hpp" -#include "kernels/elementwise_functions/pow.hpp" -#include "kernels/elementwise_functions/proj.hpp" -#include "kernels/elementwise_functions/real.hpp" -#include "kernels/elementwise_functions/remainder.hpp" -#include "kernels/elementwise_functions/round.hpp" -#include "kernels/elementwise_functions/rsqrt.hpp" -#include "kernels/elementwise_functions/sign.hpp" -#include "kernels/elementwise_functions/signbit.hpp" -#include "kernels/elementwise_functions/sin.hpp" -#include "kernels/elementwise_functions/sinh.hpp" -#include 
"kernels/elementwise_functions/sqrt.hpp" -#include "kernels/elementwise_functions/square.hpp" -#include "kernels/elementwise_functions/subtract.hpp" -#include "kernels/elementwise_functions/tan.hpp" -#include "kernels/elementwise_functions/tanh.hpp" -#include "kernels/elementwise_functions/true_divide.hpp" -#include "kernels/elementwise_functions/trunc.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -namespace td_ns = dpctl::tensor::type_dispatch; - -py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t) -{ - switch (dst_typenum_t) { - case td_ns::typenum_t::BOOL: - return py::dtype("?"); - case td_ns::typenum_t::INT8: - return py::dtype("i1"); - case td_ns::typenum_t::UINT8: - return py::dtype("u1"); - case td_ns::typenum_t::INT16: - return py::dtype("i2"); - case td_ns::typenum_t::UINT16: - return py::dtype("u2"); - case td_ns::typenum_t::INT32: - return py::dtype("i4"); - case td_ns::typenum_t::UINT32: - return py::dtype("u4"); - case td_ns::typenum_t::INT64: - return py::dtype("i8"); - case td_ns::typenum_t::UINT64: - return py::dtype("u8"); - case td_ns::typenum_t::HALF: - return py::dtype("f2"); - case td_ns::typenum_t::FLOAT: - return py::dtype("f4"); - case td_ns::typenum_t::DOUBLE: - return py::dtype("f8"); - case td_ns::typenum_t::CFLOAT: - return py::dtype("c8"); - case td_ns::typenum_t::CDOUBLE: - return py::dtype("c16"); - default: - throw py::value_error("Unrecognized dst_typeid"); - } -} - -int _result_typeid(int arg_typeid, const int *fn_output_id) -{ - if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) { - throw py::value_error("Input typeid " + std::to_string(arg_typeid) + - " is outside of expected bounds."); - } - - return fn_output_id[arg_typeid]; -} - -namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; -using ew_cmn_ns::binary_contig_impl_fn_ptr_t; -using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_strided_impl_fn_ptr_t; -using ew_cmn_ns::unary_contig_impl_fn_ptr_t; -using ew_cmn_ns::unary_strided_impl_fn_ptr_t; - -using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; -using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; -using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; - -// U01: ==== ABS (x) -namespace impl -{ - -namespace abs_fn_ns = dpctl::tensor::kernels::abs; - -static unary_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; -static int abs_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - abs_strided_dispatch_vector[td_ns::num_types]; - -void populate_abs_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = abs_fn_ns; - - using fn_ns::AbsContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); - - using fn_ns::AbsStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); - - using fn_ns::AbsTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(abs_output_typeid_vector); -}; - -} // namespace impl - -// U02: ==== ACOS (x) -namespace impl -{ - -namespace acos_fn_ns = dpctl::tensor::kernels::acos; - -static unary_contig_impl_fn_ptr_t acos_contig_dispatch_vector[td_ns::num_types]; -static int acos_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - acos_strided_dispatch_vector[td_ns::num_types]; - -void populate_acos_dispatch_vectors(void) -{ - using namespace 
td_ns; - namespace fn_ns = acos_fn_ns; - - using fn_ns::AcosContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(acos_contig_dispatch_vector); - - using fn_ns::AcosStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(acos_strided_dispatch_vector); - - using fn_ns::AcosTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(acos_output_typeid_vector); -} - -} // namespace impl - -// U03: ===== ACOSH (x) -namespace impl -{ - -namespace acosh_fn_ns = dpctl::tensor::kernels::acosh; - -static unary_contig_impl_fn_ptr_t - acosh_contig_dispatch_vector[td_ns::num_types]; -static int acosh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - acosh_strided_dispatch_vector[td_ns::num_types]; - -void populate_acosh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = acosh_fn_ns; - - using fn_ns::AcoshContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(acosh_contig_dispatch_vector); - - using fn_ns::AcoshStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(acosh_strided_dispatch_vector); - - using fn_ns::AcoshTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(acosh_output_typeid_vector); -} - -} // namespace impl - -// B01: ===== ADD (x1, x2) -namespace impl -{ -namespace add_fn_ns = dpctl::tensor::kernels::add; - -static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int add_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// add(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -// add(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - add_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - add_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - add_inplace_row_matrix_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_add_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = add_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::AddTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(add_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::AddStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(add_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::AddContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(add_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::AddContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - AddContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - add_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for 
operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::AddContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - AddContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - add_contig_row_contig_matrix_broadcast_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::AddInplaceStridedFactory; - DispatchTableBuilder - dtb6; - dtb6.populate_dispatch_table(add_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::AddInplaceContigFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(add_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::AddInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// U04: ===== ASIN (x) -namespace impl -{ - -namespace asin_fn_ns = dpctl::tensor::kernels::asin; - -static unary_contig_impl_fn_ptr_t asin_contig_dispatch_vector[td_ns::num_types]; -static int asin_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - asin_strided_dispatch_vector[td_ns::num_types]; - -void populate_asin_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = asin_fn_ns; - - using fn_ns::AsinContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(asin_contig_dispatch_vector); - - using fn_ns::AsinStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(asin_strided_dispatch_vector); - - using fn_ns::AsinTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(asin_output_typeid_vector); -} - -} // namespace impl - -// U05: ===== ASINH (x) -namespace impl -{ - -namespace asinh_fn_ns = dpctl::tensor::kernels::asinh; - -static unary_contig_impl_fn_ptr_t - asinh_contig_dispatch_vector[td_ns::num_types]; -static int asinh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - asinh_strided_dispatch_vector[td_ns::num_types]; - -void populate_asinh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = asinh_fn_ns; - - using fn_ns::AsinhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(asinh_contig_dispatch_vector); - - using fn_ns::AsinhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(asinh_strided_dispatch_vector); - - using fn_ns::AsinhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(asinh_output_typeid_vector); -} - -} // namespace impl - -// U06: ===== ATAN (x) -namespace impl -{ - -namespace atan_fn_ns = dpctl::tensor::kernels::atan; - -static unary_contig_impl_fn_ptr_t atan_contig_dispatch_vector[td_ns::num_types]; -static int atan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - atan_strided_dispatch_vector[td_ns::num_types]; - -void populate_atan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = atan_fn_ns; - - using fn_ns::AtanContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(atan_contig_dispatch_vector); - - using fn_ns::AtanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(atan_strided_dispatch_vector); - - using fn_ns::AtanTypeMapFactory; - DispatchVectorBuilder dvb3; - 
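The add(matrix, row) and add(row, matrix) broadcast tables populated above (dtb4/dtb5) exist because those cases have dedicated kernels when both operands are contiguous. Below is a hedged sketch of the kind of shape/contiguity test that makes such a specialization applicable; the names are hypothetical and the actual selection happens in the shared elementwise binding code, not in this hunk.

#include <cstddef>

// Minimal stand-in for the metadata queried from a usm_ndarray.
struct array_view
{
    int ndim;
    std::size_t shape[2];
    bool is_c_contiguous;
};

// add(matrix, row): matrix is a C-contiguous (n, m) array, row is a
// contiguous (m,) array, so one contiguous-row kernel can be reused for
// every row of the matrix instead of the generic strided implementation.
bool can_use_matrix_row_broadcast(const array_view &mat, const array_view &row)
{
    return mat.ndim == 2 && row.ndim == 1 && mat.is_c_contiguous &&
           row.is_c_contiguous && mat.shape[1] == row.shape[0];
}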
dvb3.populate_dispatch_vector(atan_output_typeid_vector); -} - -} // namespace impl - -// B02: ===== ATAN2 (x1, x2) -namespace impl -{ -namespace atan2_fn_ns = dpctl::tensor::kernels::atan2; - -static binary_contig_impl_fn_ptr_t - atan2_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int atan2_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - atan2_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_atan2_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = atan2_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::Atan2TypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(atan2_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::Atan2StridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(atan2_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::Atan2ContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(atan2_contig_dispatch_table); -}; - -} // namespace impl - -// U07: ===== ATANH (x) -namespace impl -{ - -namespace atanh_fn_ns = dpctl::tensor::kernels::atanh; - -static unary_contig_impl_fn_ptr_t - atanh_contig_dispatch_vector[td_ns::num_types]; -static int atanh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - atanh_strided_dispatch_vector[td_ns::num_types]; - -void populate_atanh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = atanh_fn_ns; - - using fn_ns::AtanhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(atanh_contig_dispatch_vector); - - using fn_ns::AtanhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(atanh_strided_dispatch_vector); - - using fn_ns::AtanhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(atanh_output_typeid_vector); -} - -} // namespace impl - -// B03: ===== BITWISE_AND (x1, x2) -namespace impl -{ -namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; - -static binary_contig_impl_fn_ptr_t - bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_and_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_and_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseAndTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_and_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseAndStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_and_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseAndContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_and_contig_dispatch_table); -}; - -} // namespace impl - -// B04: ===== BITWISE_LEFT_SHIFT (x1, x2) -namespace impl -{ -namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; - -static binary_contig_impl_fn_ptr_t - bitwise_left_shift_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int 
bitwise_left_shift_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_left_shift_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_bitwise_left_shift_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_left_shift_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseLeftShiftTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_left_shift_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseLeftShiftStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_left_shift_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseLeftShiftContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_left_shift_contig_dispatch_table); -}; - -} // namespace impl - -// U08: ===== BITWISE_INVERT (x) -namespace impl -{ - -namespace bitwise_invert_fn_ns = dpctl::tensor::kernels::bitwise_invert; - -static unary_contig_impl_fn_ptr_t - bitwise_invert_contig_dispatch_vector[td_ns::num_types]; -static int bitwise_invert_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - bitwise_invert_strided_dispatch_vector[td_ns::num_types]; - -void populate_bitwise_invert_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_invert_fn_ns; - - using fn_ns::BitwiseInvertContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(bitwise_invert_contig_dispatch_vector); - - using fn_ns::BitwiseInvertStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(bitwise_invert_strided_dispatch_vector); - - using fn_ns::BitwiseInvertTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(bitwise_invert_output_typeid_vector); -}; - -} // namespace impl - -// B05: ===== BITWISE_OR (x1, x2) -namespace impl -{ -namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; - -static binary_contig_impl_fn_ptr_t - bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_or_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_or_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseOrTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_or_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseOrStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_or_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseOrContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_or_contig_dispatch_table); -}; -} // namespace impl - -// B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) -namespace impl -{ -namespace bitwise_right_shift_fn_ns = - dpctl::tensor::kernels::bitwise_right_shift; - -static binary_contig_impl_fn_ptr_t - bitwise_right_shift_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int bitwise_right_shift_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static 
binary_strided_impl_fn_ptr_t - bitwise_right_shift_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_bitwise_right_shift_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_right_shift_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseRightShiftTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_right_shift_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseRightShiftStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_right_shift_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseRightShiftContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_right_shift_contig_dispatch_table); -}; - -} // namespace impl - -// B07: ===== BITWISE_XOR (x1, x2) -namespace impl -{ -namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; - -static binary_contig_impl_fn_ptr_t - bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_bitwise_xor_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = bitwise_xor_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::BitwiseXorTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(bitwise_xor_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::BitwiseXorStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(bitwise_xor_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::BitwiseXorContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(bitwise_xor_contig_dispatch_table); -}; -} // namespace impl - -// U09: ==== CEIL (x) -namespace impl -{ - -namespace ceil_fn_ns = dpctl::tensor::kernels::ceil; - -static unary_contig_impl_fn_ptr_t ceil_contig_dispatch_vector[td_ns::num_types]; -static int ceil_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - ceil_strided_dispatch_vector[td_ns::num_types]; - -void populate_ceil_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = ceil_fn_ns; - - using fn_ns::CeilContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(ceil_contig_dispatch_vector); - - using fn_ns::CeilStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(ceil_strided_dispatch_vector); - - using fn_ns::CeilTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(ceil_output_typeid_vector); -} - -} // namespace impl - -// U10: ==== CONJ (x) -namespace impl -{ - -namespace conj_fn_ns = dpctl::tensor::kernels::conj; - -static unary_contig_impl_fn_ptr_t conj_contig_dispatch_vector[td_ns::num_types]; -static int conj_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - conj_strided_dispatch_vector[td_ns::num_types]; - -void populate_conj_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = conj_fn_ns; - - using fn_ns::ConjContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(conj_contig_dispatch_vector); - - using 
fn_ns::ConjStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(conj_strided_dispatch_vector); - - using fn_ns::ConjTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(conj_output_typeid_vector); -} -} // namespace impl - -// U11: ==== COS (x) -namespace impl -{ - -namespace cos_fn_ns = dpctl::tensor::kernels::cos; - -static unary_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; -static int cos_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cos_strided_dispatch_vector[td_ns::num_types]; - -void populate_cos_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cos_fn_ns; - - using fn_ns::CosContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); - - using fn_ns::CosStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); - - using fn_ns::CosTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cos_output_typeid_vector); -} - -} // namespace impl - -// U12: ==== COSH (x) -namespace impl -{ - -namespace cosh_fn_ns = dpctl::tensor::kernels::cosh; - -static unary_contig_impl_fn_ptr_t cosh_contig_dispatch_vector[td_ns::num_types]; -static int cosh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cosh_strided_dispatch_vector[td_ns::num_types]; - -void populate_cosh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cosh_fn_ns; - - using fn_ns::CoshContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cosh_contig_dispatch_vector); - - using fn_ns::CoshStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cosh_strided_dispatch_vector); - - using fn_ns::CoshTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cosh_output_typeid_vector); -} - -} // namespace impl - -// B08: ==== DIVIDE (x1, x2) -namespace impl -{ -namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; - -static binary_contig_impl_fn_ptr_t - true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; -static int true_divide_inplace_output_id_table[td_ns::num_types] - [td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// divide(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - true_divide_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// divide(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - true_divide_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - true_divide_inplace_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - true_divide_inplace_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - true_divide_inplace_row_matrix_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_true_divide_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = true_divide_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::TrueDivideTypeMapFactory; - DispatchTableBuilder dtb1; - 
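The *TypeMapFactory factories referenced throughout (AddTypeMapFactory, TrueDivideTypeMapFactory, and so on) answer "which input types are supported, and what is the type of the result" at compile time, and the builders record the answers as integer type ids. A simplified sketch of the underlying idiom follows, using stand-ins for the TypeMapResultEntry/DefaultResultEntry helpers from type_dispatch_building.hpp; all names here are illustrative.

#include <complex>
#include <type_traits>

// Simplified analogs of TypeMapResultEntry (matches one argument type) and
// DefaultResultEntry (always matches, used as the fall-through).
template <typename Ty, typename ArgTy, typename ResTy = ArgTy>
struct MapEntry : std::bool_constant<std::is_same_v<Ty, ArgTy>>
{
    using result_type = ResTy;
};

template <typename Ty> struct Fallback : std::true_type
{
    using result_type = Ty;
};

// Example map: float/double and their complex variants map to themselves,
// anything else maps to void (i.e. "not supported").
template <typename T> struct ExampleOutputType
{
    using value_type = typename std::disjunction<
        MapEntry<T, float>,
        MapEntry<T, double>,
        MapEntry<T, std::complex<float>>,
        MapEntry<T, std::complex<double>>,
        Fallback<void>>::result_type;
};

static_assert(std::is_same_v<ExampleOutputType<double>::value_type, double>);
static_assert(std::is_same_v<ExampleOutputType<int>::value_type, void>);

// A type-map factory then reports GetTypeid<value_type>{}.get() for each
// input type; the special unsupported token is reported as -1, which the
// output-id tables treat as "not supported".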
dtb1.populate_dispatch_table(true_divide_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::TrueDivideStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::TrueDivideContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - TrueDivideContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - true_divide_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - TrueDivideContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - - // which input types are supported, and what is the type of the result - using fn_ns::TrueDivideInplaceTypeMapFactory; - DispatchTableBuilder dtb6; - dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::TrueDivideInplaceStridedFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(true_divide_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::TrueDivideInplaceContigFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(true_divide_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::TrueDivideInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb9; - dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// B09: ==== EQUAL (x1, x2) -namespace impl -{ -namespace equal_fn_ns = dpctl::tensor::kernels::equal; - -static binary_contig_impl_fn_ptr_t - equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::EqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::EqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::EqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(equal_contig_dispatch_table); -}; -} // namespace impl - -// U13: ==== EXP (x) -namespace impl -{ - -namespace exp_fn_ns = dpctl::tensor::kernels::exp; - -static unary_contig_impl_fn_ptr_t 
exp_contig_dispatch_vector[td_ns::num_types]; -static int exp_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - exp_strided_dispatch_vector[td_ns::num_types]; - -void populate_exp_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = exp_fn_ns; - - using fn_ns::ExpContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(exp_contig_dispatch_vector); - - using fn_ns::ExpStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(exp_strided_dispatch_vector); - - using fn_ns::ExpTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(exp_output_typeid_vector); -} - -} // namespace impl - -// U14: ==== EXPM1 (x) -namespace impl -{ - -namespace expm1_fn_ns = dpctl::tensor::kernels::expm1; - -static unary_contig_impl_fn_ptr_t - expm1_contig_dispatch_vector[td_ns::num_types]; -static int expm1_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - expm1_strided_dispatch_vector[td_ns::num_types]; - -void populate_expm1_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = expm1_fn_ns; - - using fn_ns::Expm1ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(expm1_contig_dispatch_vector); - - using fn_ns::Expm1StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(expm1_strided_dispatch_vector); - - using fn_ns::Expm1TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(expm1_output_typeid_vector); -} - -} // namespace impl - -// U15: ==== FLOOR (x) -namespace impl -{ - -namespace floor_fn_ns = dpctl::tensor::kernels::floor; - -static unary_contig_impl_fn_ptr_t - floor_contig_dispatch_vector[td_ns::num_types]; -static int floor_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - floor_strided_dispatch_vector[td_ns::num_types]; - -void populate_floor_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = floor_fn_ns; - - using fn_ns::FloorContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(floor_contig_dispatch_vector); - - using fn_ns::FloorStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(floor_strided_dispatch_vector); - - using fn_ns::FloorTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(floor_output_typeid_vector); -} - -} // namespace impl - -// B10: ==== FLOOR_DIVIDE (x1, x2) -namespace impl -{ -namespace floor_divide_fn_ns = dpctl::tensor::kernels::floor_divide; - -static binary_contig_impl_fn_ptr_t - floor_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - floor_divide_inplace_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - floor_divide_inplace_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_floor_divide_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = floor_divide_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::FloorDivideTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(floor_divide_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::FloorDivideStridedFactory; - 
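A hedged sketch of how the populated 2D tables are meant to be consulted when a binary operation is actually invoked: the two argument type ids first index the output-id table to learn whether the combination is supported and what the result type is, then the same pair indexes the contiguous (or strided) table to fetch the kernel. The names below (resolve_binary_fn and the example tables) are illustrative, not the patch's API.

#include <cstddef>
#include <stdexcept>

constexpr int example_num_types = 14;

using binary_fn_ptr_t =
    void (*)(const void *, const void *, void *, std::size_t);

// Stand-ins for e.g. floor_divide_output_id_table / *_contig_dispatch_table.
static int example_output_id_table[example_num_types][example_num_types] = {};
static binary_fn_ptr_t
    example_contig_dispatch_table[example_num_types][example_num_types] = {};

binary_fn_ptr_t
resolve_binary_fn(int arg1_typeid, int arg2_typeid, int &res_typeid)
{
    if (arg1_typeid < 0 || arg1_typeid >= example_num_types ||
        arg2_typeid < 0 || arg2_typeid >= example_num_types)
    {
        throw std::invalid_argument("argument type id out of bounds");
    }
    res_typeid = example_output_id_table[arg1_typeid][arg2_typeid];
    if (res_typeid < 0) {
        // -1 comes from the type-map factory: unsupported input combination
        throw std::invalid_argument("unsupported input types");
    }
    binary_fn_ptr_t fn =
        example_contig_dispatch_table[arg1_typeid][arg2_typeid];
    if (fn == nullptr) {
        // the real code would fall back to the strided dispatch table here
        throw std::invalid_argument("no contiguous kernel for these types");
    }
    return fn;
}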
DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(floor_divide_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::FloorDivideContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(floor_divide_contig_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::FloorDivideInplaceStridedFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(floor_divide_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::FloorDivideInplaceContigFactory; - DispatchTableBuilder - dtb5; - dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); -}; - -} // namespace impl - -// B11: ==== GREATER (x1, x2) -namespace impl -{ -namespace greater_fn_ns = dpctl::tensor::kernels::greater; - -static binary_contig_impl_fn_ptr_t - greater_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int greater_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - greater_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_greater_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = greater_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::GreaterTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(greater_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::GreaterStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(greater_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::GreaterContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(greater_contig_dispatch_table); -}; -} // namespace impl - -// B12: ==== GREATER_EQUAL (x1, x2) -namespace impl -{ -namespace greater_equal_fn_ns = dpctl::tensor::kernels::greater_equal; - -static binary_contig_impl_fn_ptr_t - greater_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int greater_equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - greater_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_greater_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = greater_equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::GreaterEqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(greater_equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::GreaterEqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(greater_equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::GreaterEqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(greater_equal_contig_dispatch_table); -}; -} // namespace impl - -// U16: ==== IMAG (x) -namespace impl -{ - -namespace imag_fn_ns = dpctl::tensor::kernels::imag; - -static unary_contig_impl_fn_ptr_t imag_contig_dispatch_vector[td_ns::num_types]; -static int imag_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - imag_strided_dispatch_vector[td_ns::num_types]; - -void populate_imag_dispatch_vectors(void) -{ - using namespace td_ns; - 
namespace fn_ns = imag_fn_ns; - - using fn_ns::ImagContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(imag_contig_dispatch_vector); - - using fn_ns::ImagStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(imag_strided_dispatch_vector); - - using fn_ns::ImagTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(imag_output_typeid_vector); -} -} // namespace impl - -// U17: ==== ISFINITE (x) -namespace impl -{ -namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; - -static unary_contig_impl_fn_ptr_t - isfinite_contig_dispatch_vector[td_ns::num_types]; -static int isfinite_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isfinite_strided_dispatch_vector[td_ns::num_types]; - -void populate_isfinite_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isfinite_fn_ns; - - using fn_ns::IsFiniteContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isfinite_contig_dispatch_vector); - - using fn_ns::IsFiniteStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isfinite_strided_dispatch_vector); - - using fn_ns::IsFiniteTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isfinite_output_typeid_vector); -} - -} // namespace impl - -// U18: ==== ISINF (x) -namespace impl -{ -namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; - -static unary_contig_impl_fn_ptr_t - isinf_contig_dispatch_vector[td_ns::num_types]; -static int isinf_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isinf_strided_dispatch_vector[td_ns::num_types]; - -void populate_isinf_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isinf_fn_ns; - - using fn_ns::IsInfContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isinf_contig_dispatch_vector); - - using fn_ns::IsInfStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isinf_strided_dispatch_vector); - - using fn_ns::IsInfTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isinf_output_typeid_vector); -} - -} // namespace impl - -// U19: ==== ISNAN (x) -namespace impl -{ -namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; - -static unary_contig_impl_fn_ptr_t - isnan_contig_dispatch_vector[td_ns::num_types]; -static int isnan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - isnan_strided_dispatch_vector[td_ns::num_types]; - -void populate_isnan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = isnan_fn_ns; - - using fn_ns::IsNanContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); - - using fn_ns::IsNanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); - - using fn_ns::IsNanTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(isnan_output_typeid_vector); -} - -} // namespace impl - -// B13: ==== LESS (x1, x2) -namespace impl -{ -namespace less_fn_ns = dpctl::tensor::kernels::less; - -static binary_contig_impl_fn_ptr_t less_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int less_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - less_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_less_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = 
less_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LessTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(less_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LessStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(less_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LessContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(less_contig_dispatch_table); -}; -} // namespace impl - -// B14: ==== LESS_EQUAL (x1, x2) -namespace impl -{ -namespace less_equal_fn_ns = dpctl::tensor::kernels::less_equal; - -static binary_contig_impl_fn_ptr_t - less_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int less_equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - less_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_less_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = less_equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LessEqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(less_equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LessEqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(less_equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LessEqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(less_equal_contig_dispatch_table); -}; -} // namespace impl - -// U20: ==== LOG (x) -namespace impl -{ - -namespace log_fn_ns = dpctl::tensor::kernels::log; - -static unary_contig_impl_fn_ptr_t log_contig_dispatch_vector[td_ns::num_types]; -static int log_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log_strided_dispatch_vector[td_ns::num_types]; - -void populate_log_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log_fn_ns; - - using fn_ns::LogContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log_contig_dispatch_vector); - - using fn_ns::LogStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log_strided_dispatch_vector); - - using fn_ns::LogTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log_output_typeid_vector); -} - -} // namespace impl - -// U21: ==== LOG1P (x) -namespace impl -{ - -namespace log1p_fn_ns = dpctl::tensor::kernels::log1p; - -static unary_contig_impl_fn_ptr_t - log1p_contig_dispatch_vector[td_ns::num_types]; -static int log1p_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log1p_strided_dispatch_vector[td_ns::num_types]; - -void populate_log1p_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log1p_fn_ns; - - using fn_ns::Log1pContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log1p_contig_dispatch_vector); - - using fn_ns::Log1pStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log1p_strided_dispatch_vector); - - using fn_ns::Log1pTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log1p_output_typeid_vector); -} - -} // namespace impl - -// U22: ==== LOG2 (x) -namespace impl -{ - -namespace 
log2_fn_ns = dpctl::tensor::kernels::log2; - -static unary_contig_impl_fn_ptr_t log2_contig_dispatch_vector[td_ns::num_types]; -static int log2_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log2_strided_dispatch_vector[td_ns::num_types]; - -void populate_log2_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log2_fn_ns; - - using fn_ns::Log2ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log2_contig_dispatch_vector); - - using fn_ns::Log2StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log2_strided_dispatch_vector); - - using fn_ns::Log2TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log2_output_typeid_vector); -}; - -} // namespace impl - -// U23: ==== LOG10 (x) -namespace impl -{ - -namespace log10_fn_ns = dpctl::tensor::kernels::log10; - -static unary_contig_impl_fn_ptr_t - log10_contig_dispatch_vector[td_ns::num_types]; -static int log10_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - log10_strided_dispatch_vector[td_ns::num_types]; - -void populate_log10_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = log10_fn_ns; - - using fn_ns::Log10ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(log10_contig_dispatch_vector); - - using fn_ns::Log10StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(log10_strided_dispatch_vector); - - using fn_ns::Log10TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(log10_output_typeid_vector); -}; - -} // namespace impl - -// B15: ==== LOGADDEXP (x1, x2) -namespace impl -{ -namespace logaddexp_fn_ns = dpctl::tensor::kernels::logaddexp; - -static binary_contig_impl_fn_ptr_t - logaddexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logaddexp_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logaddexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logaddexp_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logaddexp_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogAddExpTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logaddexp_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogAddExpStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logaddexp_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogAddExpContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logaddexp_contig_dispatch_table); -}; -} // namespace impl - -// B16: ==== LOGICAL_AND (x1, x2) -namespace impl -{ -namespace logical_and_fn_ns = dpctl::tensor::kernels::logical_and; - -static binary_contig_impl_fn_ptr_t - logical_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_and_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logical_and_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_and_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalAndTypeMapFactory; - DispatchTableBuilder dtb1; - 
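For the comparison and logical families (equal, greater, logical_and, logical_or, ...), the type-map factories built here differ from the arithmetic ones in one respect: every supported input pair maps to a bool result. A small illustrative analog of a BinaryTypeMapResultEntry-style map with a fixed bool result type; the names are hypothetical and the pairs listed are only an example, not dpctl's actual support matrix.

#include <type_traits>

// Simplified analog of BinaryTypeMapResultEntry: matches a pair of argument
// types and carries the result type; Default is the always-true fall-through.
template <typename T1, typename A1, typename T2, typename A2, typename ResT>
struct BinEntry : std::bool_constant<std::is_same_v<T1, A1> &&
                                     std::is_same_v<T2, A2>>
{
    using result_type = ResT;
};

template <typename Ty> struct Default : std::true_type
{
    using result_type = Ty;
};

// Logical/comparison ops: supported pairs produce bool, others produce void.
template <typename T1, typename T2> struct ExampleLogicalOutputType
{
    using value_type = typename std::disjunction<
        BinEntry<T1, float, T2, float, bool>,
        BinEntry<T1, double, T2, double, bool>,
        Default<void>>::result_type;
};

static_assert(
    std::is_same_v<ExampleLogicalOutputType<double, double>::value_type, bool>);
static_assert(
    std::is_same_v<ExampleLogicalOutputType<int, int>::value_type, void>);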
dtb1.populate_dispatch_table(logical_and_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalAndStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logical_and_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalAndContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_and_contig_dispatch_table); -}; -} // namespace impl - -// U24: ==== LOGICAL_NOT (x) -namespace impl -{ -namespace logical_not_fn_ns = dpctl::tensor::kernels::logical_not; - -static unary_contig_impl_fn_ptr_t - logical_not_contig_dispatch_vector[td_ns::num_types]; -static int logical_not_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - logical_not_strided_dispatch_vector[td_ns::num_types]; - -void populate_logical_not_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = logical_not_fn_ns; - - using fn_ns::LogicalNotContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(logical_not_contig_dispatch_vector); - - using fn_ns::LogicalNotStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(logical_not_strided_dispatch_vector); - - using fn_ns::LogicalNotTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(logical_not_output_typeid_vector); -}; -} // namespace impl - -// B17: ==== LOGICAL_OR (x1, x2) -namespace impl -{ -namespace logical_or_fn_ns = dpctl::tensor::kernels::logical_or; - -static binary_contig_impl_fn_ptr_t - logical_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_or_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logical_or_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_or_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalOrTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logical_or_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalOrStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(logical_or_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalOrContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_or_contig_dispatch_table); -}; -} // namespace impl - -// B18: ==== LOGICAL_XOR (x1, x2) -namespace impl -{ -namespace logical_xor_fn_ns = dpctl::tensor::kernels::logical_xor; - -static binary_contig_impl_fn_ptr_t - logical_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int logical_xor_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - logical_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_logical_xor_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = logical_xor_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::LogicalXorTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(logical_xor_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::LogicalXorStridedFactory; - DispatchTableBuilder - dtb2; - 
dtb2.populate_dispatch_table(logical_xor_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::LogicalXorContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(logical_xor_contig_dispatch_table); -}; -} // namespace impl - -// B??: ==== MAXIMUM (x1, x2) -namespace impl -{ - -namespace maximum_fn_ns = dpctl::tensor::kernels::maximum; - -static binary_contig_impl_fn_ptr_t - maximum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int maximum_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - maximum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_maximum_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = maximum_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::MaximumTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(maximum_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::MaximumStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(maximum_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::MaximumContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(maximum_contig_dispatch_table); -}; - -} // namespace impl - -// B??: ==== MINIMUM (x1, x2) -namespace impl -{ - -namespace minimum_fn_ns = dpctl::tensor::kernels::minimum; - -static binary_contig_impl_fn_ptr_t - minimum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int minimum_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - minimum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_minimum_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = minimum_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::MinimumTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(minimum_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::MinimumStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(minimum_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::MinimumContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(minimum_contig_dispatch_table); -}; - -} // namespace impl - -// B19: ==== MULTIPLY (x1, x2) -namespace impl -{ - -namespace multiply_fn_ns = dpctl::tensor::kernels::multiply; - -static binary_contig_impl_fn_ptr_t - multiply_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int multiply_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - multiply_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// mul(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - multiply_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// mul(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - multiply_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - multiply_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - 
multiply_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - multiply_inplace_row_matrix_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_multiply_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = multiply_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::MultiplyTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(multiply_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::MultiplyStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(multiply_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::MultiplyContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(multiply_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::MultiplyContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - MultiplyContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - multiply_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::MultiplyContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - MultiplyContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - multiply_contig_row_contig_matrix_broadcast_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::MultiplyInplaceStridedFactory; - DispatchTableBuilder - dtb6; - dtb6.populate_dispatch_table(multiply_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::MultiplyInplaceContigFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(multiply_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::MultiplyInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(multiply_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// U25: ==== NEGATIVE (x) -namespace impl -{ - -namespace negative_fn_ns = dpctl::tensor::kernels::negative; - -static unary_contig_impl_fn_ptr_t - negative_contig_dispatch_vector[td_ns::num_types]; -static int negative_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - negative_strided_dispatch_vector[td_ns::num_types]; - -void populate_negative_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = negative_fn_ns; - - using fn_ns::NegativeContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(negative_contig_dispatch_vector); - - using fn_ns::NegativeStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(negative_strided_dispatch_vector); - - using fn_ns::NegativeTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(negative_output_typeid_vector); -} - -} // namespace impl - -// B20: ==== NOT_EQUAL (x1, x2) -namespace impl -{ -namespace not_equal_fn_ns = dpctl::tensor::kernels::not_equal; - -static 
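// The in-place builders in the populate_multiply_dispatch_tables() hunk above
// (dtb6..dtb8) are not fully legible here.  Judging from the fully visible
// builders (dtb4/dtb5), they would follow the same
// <function-pointer-type, Factory, num_types> pattern; a sketch of how dtb6
// and dtb8 would read (an assumption based on that pattern, not the verbatim
// removed source):

    using fn_ns::MultiplyInplaceStridedFactory;
    DispatchTableBuilder<binary_inplace_strided_impl_fn_ptr_t,
                         MultiplyInplaceStridedFactory, num_types>
        dtb6;
    dtb6.populate_dispatch_table(multiply_inplace_strided_dispatch_table);

    using fn_ns::MultiplyInplaceRowMatrixBroadcastFactory;
    DispatchTableBuilder<binary_inplace_row_matrix_broadcast_impl_fn_ptr_t,
                         MultiplyInplaceRowMatrixBroadcastFactory, num_types>
        dtb8;
    dtb8.populate_dispatch_table(multiply_inplace_row_matrix_dispatch_table);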
binary_contig_impl_fn_ptr_t - not_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int not_equal_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - not_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_not_equal_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = not_equal_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::NotEqualTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(not_equal_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::NotEqualStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(not_equal_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::NotEqualContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(not_equal_contig_dispatch_table); -}; -} // namespace impl - -// U26: ==== POSITIVE (x) -namespace impl -{ - -namespace positive_fn_ns = dpctl::tensor::kernels::positive; - -static unary_contig_impl_fn_ptr_t - positive_contig_dispatch_vector[td_ns::num_types]; -static int positive_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - positive_strided_dispatch_vector[td_ns::num_types]; - -void populate_positive_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = positive_fn_ns; - - using fn_ns::PositiveContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(positive_contig_dispatch_vector); - - using fn_ns::PositiveStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(positive_strided_dispatch_vector); - - using fn_ns::PositiveTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(positive_output_typeid_vector); -} - -} // namespace impl - -// B21: ==== POW (x1, x2) -namespace impl -{ - -namespace pow_fn_ns = dpctl::tensor::kernels::pow; - -static binary_contig_impl_fn_ptr_t pow_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int pow_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - pow_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_pow_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = pow_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::PowTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(pow_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::PowStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(pow_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::PowContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(pow_contig_dispatch_table); -}; - -} // namespace impl - -// U??: ==== PROJ (x) -namespace impl -{ - -namespace proj_fn_ns = dpctl::tensor::kernels::proj; - -static unary_contig_impl_fn_ptr_t proj_contig_dispatch_vector[td_ns::num_types]; -static int proj_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - proj_strided_dispatch_vector[td_ns::num_types]; - -void populate_proj_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = proj_fn_ns; - - using fn_ns::ProjContigFactory; - DispatchVectorBuilder - dvb1; - 
dvb1.populate_dispatch_vector(proj_contig_dispatch_vector); - - using fn_ns::ProjStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(proj_strided_dispatch_vector); - - using fn_ns::ProjTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(proj_output_typeid_vector); -} -} // namespace impl - -// U27: ==== REAL (x) -namespace impl -{ - -namespace real_fn_ns = dpctl::tensor::kernels::real; - -static unary_contig_impl_fn_ptr_t real_contig_dispatch_vector[td_ns::num_types]; -static int real_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - real_strided_dispatch_vector[td_ns::num_types]; - -void populate_real_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = real_fn_ns; - - using fn_ns::RealContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(real_contig_dispatch_vector); - - using fn_ns::RealStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(real_strided_dispatch_vector); - - using fn_ns::RealTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(real_output_typeid_vector); -} -} // namespace impl - -// B22: ==== REMAINDER (x1, x2) -namespace impl -{ - -namespace remainder_fn_ns = dpctl::tensor::kernels::remainder; - -static binary_contig_impl_fn_ptr_t - remainder_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int remainder_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - remainder_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_remainder_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = remainder_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::RemainderTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(remainder_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::RemainderStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(remainder_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::RemainderContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(remainder_contig_dispatch_table); -} - -} // namespace impl - -// U28: ==== ROUND (x) -namespace impl -{ - -namespace round_fn_ns = dpctl::tensor::kernels::round; - -static unary_contig_impl_fn_ptr_t - round_contig_dispatch_vector[td_ns::num_types]; -static int round_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - round_strided_dispatch_vector[td_ns::num_types]; - -void populate_round_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = round_fn_ns; - - using fn_ns::RoundContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(round_contig_dispatch_vector); - - using fn_ns::RoundStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(round_strided_dispatch_vector); - - using fn_ns::RoundTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(round_output_typeid_vector); -} - -} // namespace impl - -// U29: ==== SIGN (x) -namespace impl -{ - -namespace sign_fn_ns = dpctl::tensor::kernels::sign; - -static unary_contig_impl_fn_ptr_t sign_contig_dispatch_vector[td_ns::num_types]; -static int sign_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sign_strided_dispatch_vector[td_ns::num_types]; - -void 
populate_sign_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sign_fn_ns; - - using fn_ns::SignContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sign_contig_dispatch_vector); - - using fn_ns::SignStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sign_strided_dispatch_vector); - - using fn_ns::SignTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sign_output_typeid_vector); -} - -} // namespace impl - -// ==== SIGNBIT (x) -namespace impl -{ - -namespace signbit_fn_ns = dpctl::tensor::kernels::signbit; - -static unary_contig_impl_fn_ptr_t - signbit_contig_dispatch_vector[td_ns::num_types]; -static int signbit_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - signbit_strided_dispatch_vector[td_ns::num_types]; - -void populate_signbit_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = signbit_fn_ns; - - using fn_ns::SignbitContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(signbit_contig_dispatch_vector); - - using fn_ns::SignbitStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(signbit_strided_dispatch_vector); - - using fn_ns::SignbitTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(signbit_output_typeid_vector); -} - -} // namespace impl - -// U30: ==== SIN (x) -namespace impl -{ - -namespace sin_fn_ns = dpctl::tensor::kernels::sin; - -static unary_contig_impl_fn_ptr_t sin_contig_dispatch_vector[td_ns::num_types]; -static int sin_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sin_strided_dispatch_vector[td_ns::num_types]; - -void populate_sin_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sin_fn_ns; - - using fn_ns::SinContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sin_contig_dispatch_vector); - - using fn_ns::SinStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sin_strided_dispatch_vector); - - using fn_ns::SinTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sin_output_typeid_vector); -} - -} // namespace impl - -// U31: ==== SINH (x) -namespace impl -{ - -namespace sinh_fn_ns = dpctl::tensor::kernels::sinh; - -static unary_contig_impl_fn_ptr_t sinh_contig_dispatch_vector[td_ns::num_types]; -static int sinh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sinh_strided_dispatch_vector[td_ns::num_types]; - -void populate_sinh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sinh_fn_ns; - - using fn_ns::SinhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sinh_contig_dispatch_vector); - - using fn_ns::SinhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sinh_strided_dispatch_vector); - - using fn_ns::SinhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sinh_output_typeid_vector); -} - -} // namespace impl - -// U32: ==== SQUARE (x) -namespace impl -{ - -namespace square_fn_ns = dpctl::tensor::kernels::square; - -static unary_contig_impl_fn_ptr_t - square_contig_dispatch_vector[td_ns::num_types]; -static int square_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - square_strided_dispatch_vector[td_ns::num_types]; - -void populate_square_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = square_fn_ns; - - using 
fn_ns::SquareContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(square_contig_dispatch_vector); - - using fn_ns::SquareStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(square_strided_dispatch_vector); - - using fn_ns::SquareTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(square_output_typeid_vector); -} - -} // namespace impl - -// U33: ==== SQRT (x) -namespace impl -{ - -namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt; - -static unary_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; -static int sqrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sqrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_sqrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sqrt_fn_ns; - - using fn_ns::SqrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector); - - using fn_ns::SqrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector); - - using fn_ns::SqrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sqrt_output_typeid_vector); -} - -} // namespace impl - -// B23: ==== SUBTRACT (x1, x2) -namespace impl -{ -namespace subtract_fn_ns = dpctl::tensor::kernels::subtract; - -static binary_contig_impl_fn_ptr_t - subtract_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int subtract_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - subtract_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// sub(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - subtract_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// sub(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - subtract_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - subtract_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - subtract_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - subtract_inplace_row_matrix_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_subtract_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = subtract_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::SubtractTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(subtract_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::SubtractStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(subtract_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::SubtractContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(subtract_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::SubtractContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - SubtractContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - 
subtract_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::SubtractContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - SubtractContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - subtract_contig_row_contig_matrix_broadcast_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::SubtractInplaceStridedFactory; - DispatchTableBuilder - dtb6; - dtb6.populate_dispatch_table(subtract_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::SubtractInplaceContigFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(subtract_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::SubtractInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(subtract_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// U34: ==== TAN (x) -namespace impl -{ - -namespace tan_fn_ns = dpctl::tensor::kernels::tan; - -static unary_contig_impl_fn_ptr_t tan_contig_dispatch_vector[td_ns::num_types]; -static int tan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - tan_strided_dispatch_vector[td_ns::num_types]; - -void populate_tan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = tan_fn_ns; - - using fn_ns::TanContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(tan_contig_dispatch_vector); - - using fn_ns::TanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(tan_strided_dispatch_vector); - - using fn_ns::TanTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(tan_output_typeid_vector); -} - -} // namespace impl - -// U35: ==== TANH (x) -namespace impl -{ - -namespace tanh_fn_ns = dpctl::tensor::kernels::tanh; - -static unary_contig_impl_fn_ptr_t tanh_contig_dispatch_vector[td_ns::num_types]; -static int tanh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - tanh_strided_dispatch_vector[td_ns::num_types]; - -void populate_tanh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = tanh_fn_ns; - - using fn_ns::TanhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(tanh_contig_dispatch_vector); - - using fn_ns::TanhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(tanh_strided_dispatch_vector); - - using fn_ns::TanhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(tanh_output_typeid_vector); -} - -} // namespace impl - -// U36: ==== TRUNC (x) -namespace impl -{ - -namespace trunc_fn_ns = dpctl::tensor::kernels::trunc; - -static unary_contig_impl_fn_ptr_t - trunc_contig_dispatch_vector[td_ns::num_types]; -static int trunc_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - trunc_strided_dispatch_vector[td_ns::num_types]; - -void populate_trunc_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = trunc_fn_ns; - - using fn_ns::TruncContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(trunc_contig_dispatch_vector); - - using fn_ns::TruncStridedFactory; - DispatchVectorBuilder - dvb2; - 
dvb2.populate_dispatch_vector(trunc_strided_dispatch_vector); - - using fn_ns::TruncTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(trunc_output_typeid_vector); -} - -} // namespace impl - -// B24: ==== HYPOT (x1, x2) -namespace impl -{ -namespace hypot_fn_ns = dpctl::tensor::kernels::hypot; - -static binary_contig_impl_fn_ptr_t - hypot_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int hypot_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - hypot_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_hypot_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = hypot_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::HypotTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(hypot_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::HypotStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(hypot_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::HypotContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(hypot_contig_dispatch_table); -}; - -} // namespace impl - -// U37: ==== CBRT (x) -namespace impl -{ - -namespace cbrt_fn_ns = dpctl::tensor::kernels::cbrt; - -static unary_contig_impl_fn_ptr_t cbrt_contig_dispatch_vector[td_ns::num_types]; -static int cbrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cbrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_cbrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cbrt_fn_ns; - - using fn_ns::CbrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cbrt_contig_dispatch_vector); - - using fn_ns::CbrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cbrt_strided_dispatch_vector); - - using fn_ns::CbrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cbrt_output_typeid_vector); -} - -} // namespace impl - -// B24: ==== COPYSIGN (x1, x2) -namespace impl -{ -namespace copysign_fn_ns = dpctl::tensor::kernels::copysign; - -static binary_contig_impl_fn_ptr_t - copysign_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int copysign_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - copysign_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_copysign_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = copysign_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::CopysignTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(copysign_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::CopysignStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(copysign_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::CopysignContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(copysign_contig_dispatch_table); -}; - -} // namespace impl - -// U38: ==== EXP2 (x) -namespace impl -{ - -namespace exp2_fn_ns = dpctl::tensor::kernels::exp2; - -static unary_contig_impl_fn_ptr_t exp2_contig_dispatch_vector[td_ns::num_types]; -static int 
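// Taken together, the populate_* hunks above all follow one scheme: each
// elementwise function registers a typeid-indexed vector (unary) or table
// (binary) of kernel function pointers, with nullptr marking an unsupported
// type combination or a missing specialization.  A minimal, self-contained
// illustration of that dispatch idea (all names and types below are invented
// for this sketch and are not dpctl's):

#include <cstddef>

namespace dispatch_sketch
{

using kernel_fn_ptr_t = void (*)(const void *, void *, std::size_t);

// dpctl sizes these arrays with td_ns::num_types; a small constant stands in
constexpr int num_types_sketch = 4;

static kernel_fn_ptr_t contig_table[num_types_sketch][num_types_sketch] = {};
static kernel_fn_ptr_t strided_table[num_types_sketch][num_types_sketch] = {};

// prefer the contiguous specialization when one was registered,
// otherwise fall back to the general strided kernel
inline kernel_fn_ptr_t select_kernel(int src1_typeid, int src2_typeid)
{
    kernel_fn_ptr_t fn = contig_table[src1_typeid][src2_typeid];
    return (fn != nullptr) ? fn : strided_table[src1_typeid][src2_typeid];
}

} // namespace dispatch_sketch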
exp2_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - exp2_strided_dispatch_vector[td_ns::num_types]; - -void populate_exp2_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = exp2_fn_ns; - - using fn_ns::Exp2ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(exp2_contig_dispatch_vector); - - using fn_ns::Exp2StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(exp2_strided_dispatch_vector); - - using fn_ns::Exp2TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(exp2_output_typeid_vector); -} - -} // namespace impl - -// U39: ==== RSQRT (x) -namespace impl -{ - -namespace rsqrt_fn_ns = dpctl::tensor::kernels::rsqrt; - -static unary_contig_impl_fn_ptr_t - rsqrt_contig_dispatch_vector[td_ns::num_types]; -static int rsqrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - rsqrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_rsqrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = rsqrt_fn_ns; - - using fn_ns::RsqrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(rsqrt_contig_dispatch_vector); - - using fn_ns::RsqrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(rsqrt_strided_dispatch_vector); - - using fn_ns::RsqrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(rsqrt_output_typeid_vector); -} - -} // namespace impl - -// ========================================================================================== -// // - -namespace py = pybind11; - -void init_elementwise_functions(py::module_ m) -{ - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - // U01: ==== ABS (x) - { - impl::populate_abs_dispatch_vectors(); - using impl::abs_contig_dispatch_vector; - using impl::abs_output_typeid_vector; - using impl::abs_strided_dispatch_vector; - - auto abs_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, abs_output_typeid_vector, - abs_contig_dispatch_vector, abs_strided_dispatch_vector); - }; - m.def("_abs", abs_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto abs_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector); - }; - m.def("_abs_result_type", abs_result_type_pyapi); - } - - // U02: ==== ACOS (x) - { - impl::populate_acos_dispatch_vectors(); - using impl::acos_contig_dispatch_vector; - using impl::acos_output_typeid_vector; - using impl::acos_strided_dispatch_vector; - - auto acos_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, acos_output_typeid_vector, - acos_contig_dispatch_vector, acos_strided_dispatch_vector); - }; - m.def("_acos", acos_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto acos_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, acos_output_typeid_vector); - }; - m.def("_acos_result_type", acos_result_type_pyapi); - } - - // U03: ===== ACOSH (x) - { - impl::populate_acosh_dispatch_vectors(); - using impl::acosh_contig_dispatch_vector; - using impl::acosh_output_typeid_vector; - using impl::acosh_strided_dispatch_vector; 
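// In the init_elementwise_functions() hunk above, the element type of the
// event_vecT alias is not legible; given how the `depends` arguments are used
// throughout this file it would be a vector of SYCL events.  A sketch of the
// aliases and of the unary binding pattern they feed (assuming the usual
// dpctl convention; not the verbatim removed source):

    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    // each unary function gets a compute entry point plus a result-type query
    auto abs_pyapi = [&](const arrayT &src, const arrayT &dst,
                         sycl::queue &exec_q, const event_vecT &depends = {}) {
        return py_unary_ufunc(src, dst, exec_q, depends,
                              abs_output_typeid_vector,
                              abs_contig_dispatch_vector,
                              abs_strided_dispatch_vector);
    };
    m.def("_abs", abs_pyapi, "", py::arg("src"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());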
- - auto acosh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, acosh_output_typeid_vector, - acosh_contig_dispatch_vector, acosh_strided_dispatch_vector); - }; - m.def("_acosh", acosh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto acosh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - acosh_output_typeid_vector); - }; - m.def("_acosh_result_type", acosh_result_type_pyapi); - } - - // B01: ===== ADD (x1, x2) - { - impl::populate_add_dispatch_tables(); - using impl::add_contig_dispatch_table; - using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::add_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::add_output_id_table; - using impl::add_strided_dispatch_table; - - auto add_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, add_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - add_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - add_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - add_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - add_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto add_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - add_output_id_table); - }; - m.def("_add", add_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_add_result_type", add_result_type_pyapi, ""); - - using impl::add_inplace_contig_dispatch_table; - using impl::add_inplace_row_matrix_dispatch_table; - using impl::add_inplace_strided_dispatch_table; - - auto add_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, add_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - add_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - add_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - add_inplace_row_matrix_dispatch_table); - }; - m.def("_add_inplace", add_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U04: ===== ASIN (x) - { - impl::populate_asin_dispatch_vectors(); - using impl::asin_contig_dispatch_vector; - using impl::asin_output_typeid_vector; - using impl::asin_strided_dispatch_vector; - - auto asin_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - 
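// The binary and in-place bindings carry the same `depends` parameter: the
// `const std::vector &depends = {}` seen above is understood here as
// `const std::vector<sycl::event> &depends = {}`.  Sketch of the add binding
// with that parameter written out (an assumption from the surrounding
// pattern, not the verbatim removed source):

    auto add_pyapi = [&](const dpctl::tensor::usm_ndarray &src1,
                         const dpctl::tensor::usm_ndarray &src2,
                         const dpctl::tensor::usm_ndarray &dst,
                         sycl::queue &exec_q,
                         const std::vector<sycl::event> &depends = {}) {
        return py_binary_ufunc(
            src1, src2, dst, exec_q, depends, add_output_id_table,
            add_contig_dispatch_table, add_strided_dispatch_table,
            add_contig_matrix_contig_row_broadcast_dispatch_table,
            add_contig_row_contig_matrix_broadcast_dispatch_table);
    };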
src, dst, exec_q, depends, asin_output_typeid_vector, - asin_contig_dispatch_vector, asin_strided_dispatch_vector); - }; - m.def("_asin", asin_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto asin_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, asin_output_typeid_vector); - }; - m.def("_asin_result_type", asin_result_type_pyapi); - } - - // U05: ===== ASINH (x) - { - impl::populate_asinh_dispatch_vectors(); - using impl::asinh_contig_dispatch_vector; - using impl::asinh_output_typeid_vector; - using impl::asinh_strided_dispatch_vector; - - auto asinh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, asinh_output_typeid_vector, - asinh_contig_dispatch_vector, asinh_strided_dispatch_vector); - }; - m.def("_asinh", asinh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto asinh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - asinh_output_typeid_vector); - }; - m.def("_asinh_result_type", asinh_result_type_pyapi); - } - - // U06: ===== ATAN (x) - { - impl::populate_atan_dispatch_vectors(); - using impl::atan_contig_dispatch_vector; - using impl::atan_output_typeid_vector; - using impl::atan_strided_dispatch_vector; - - auto atan_pyapi = [&](arrayT src, arrayT dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, atan_output_typeid_vector, - atan_contig_dispatch_vector, atan_strided_dispatch_vector); - }; - m.def("_atan", atan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto atan_result_type_pyapi = [&](py::dtype dtype) { - return py_unary_ufunc_result_type(dtype, atan_output_typeid_vector); - }; - m.def("_atan_result_type", atan_result_type_pyapi); - } - - // B02: ===== ATAN2 (x1, x2) - { - impl::populate_atan2_dispatch_tables(); - using impl::atan2_contig_dispatch_table; - using impl::atan2_output_id_table; - using impl::atan2_strided_dispatch_table; - - auto atan2_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, atan2_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - atan2_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - atan2_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto atan2_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - atan2_output_id_table); - }; - m.def("_atan2", atan2_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_atan2_result_type", 
atan2_result_type_pyapi, ""); - } - - // U07: ===== ATANH (x) - { - impl::populate_atanh_dispatch_vectors(); - using impl::atanh_contig_dispatch_vector; - using impl::atanh_output_typeid_vector; - using impl::atanh_strided_dispatch_vector; - - auto atanh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, atanh_output_typeid_vector, - atanh_contig_dispatch_vector, atanh_strided_dispatch_vector); - }; - m.def("_atanh", atanh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto atanh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - atanh_output_typeid_vector); - }; - m.def("_atanh_result_type", atanh_result_type_pyapi); - } - - // B03: ===== BITWISE_AND (x1, x2) - { - impl::populate_bitwise_and_dispatch_tables(); - using impl::bitwise_and_contig_dispatch_table; - using impl::bitwise_and_output_id_table; - using impl::bitwise_and_strided_dispatch_table; - - auto bitwise_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_and_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_and_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_and_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_and_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_and_output_id_table); - }; - m.def("_bitwise_and", bitwise_and_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); - } - - // B04: ===== BITWISE_LEFT_SHIFT (x1, x2) - { - impl::populate_bitwise_left_shift_dispatch_tables(); - using impl::bitwise_left_shift_contig_dispatch_table; - using impl::bitwise_left_shift_output_id_table; - using impl::bitwise_left_shift_strided_dispatch_table; - - auto bitwise_left_shift_pyapi = [&](const dpctl::tensor::usm_ndarray - &src1, - const dpctl::tensor::usm_ndarray - &src2, - const dpctl::tensor::usm_ndarray - &dst, - sycl::queue &exec_q, - const std::vector - &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, - bitwise_left_shift_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_left_shift_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_left_shift_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - 
// function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_left_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_left_shift_output_id_table); - }; - m.def("_bitwise_left_shift", bitwise_left_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_left_shift_result_type", - bitwise_left_shift_result_type_pyapi, ""); - } - - // U08: ===== BITWISE_INVERT (x) - { - impl::populate_bitwise_invert_dispatch_vectors(); - using impl::bitwise_invert_contig_dispatch_vector; - using impl::bitwise_invert_output_typeid_vector; - using impl::bitwise_invert_strided_dispatch_vector; - - auto bitwise_invert_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - bitwise_invert_output_typeid_vector, - bitwise_invert_contig_dispatch_vector, - bitwise_invert_strided_dispatch_vector); - }; - m.def("_bitwise_invert", bitwise_invert_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto bitwise_invert_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type( - dtype, bitwise_invert_output_typeid_vector); - }; - m.def("_bitwise_invert_result_type", bitwise_invert_result_type_pyapi); - } - - // B05: ===== BITWISE_OR (x1, x2) - { - impl::populate_bitwise_or_dispatch_tables(); - using impl::bitwise_or_contig_dispatch_table; - using impl::bitwise_or_output_id_table; - using impl::bitwise_or_strided_dispatch_table; - - auto bitwise_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_or_output_id_table); - }; - m.def("_bitwise_or", bitwise_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); - } - - // B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) - { - impl::populate_bitwise_right_shift_dispatch_tables(); - using impl::bitwise_right_shift_contig_dispatch_table; - using impl::bitwise_right_shift_output_id_table; - using impl::bitwise_right_shift_strided_dispatch_table; - - auto bitwise_right_shift_pyapi = 
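// The td_ns::NullPtrTable<...>{} arguments passed above and below signal that
// these operations (comparisons, bitwise ops, and similar) ship no
// contiguous-matrix / contiguous-row broadcast specialization: py_binary_ufunc
// receives a table whose entries are all nullptr and therefore never takes
// the specialized path.  A rough, self-contained stand-in for what such an
// all-null table provides (names invented for this sketch; not dpctl's actual
// definition):

template <typename fnT, int N> struct all_null_table_sketch
{
    fnT data[N][N] = {}; // value-initialized: every entry is nullptr

    const fnT *operator[](int i) const { return data[i]; }
};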
[&](const dpctl::tensor::usm_ndarray - &src1, - const dpctl::tensor::usm_ndarray - &src2, - const dpctl::tensor::usm_ndarray - &dst, - sycl::queue &exec_q, - const std::vector - &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, - bitwise_right_shift_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_right_shift_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_right_shift_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_right_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_right_shift_output_id_table); - }; - m.def("_bitwise_right_shift", bitwise_right_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_right_shift_result_type", - bitwise_right_shift_result_type_pyapi, ""); - } - - // B07: ===== BITWISE_XOR (x1, x2) - { - impl::populate_bitwise_xor_dispatch_tables(); - using impl::bitwise_xor_contig_dispatch_table; - using impl::bitwise_xor_output_id_table; - using impl::bitwise_xor_strided_dispatch_table; - - auto bitwise_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_xor_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_xor_output_id_table); - }; - m.def("_bitwise_xor", bitwise_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); - } - - // U09: ==== CEIL (x) - { - impl::populate_ceil_dispatch_vectors(); - using impl::ceil_contig_dispatch_vector; - using impl::ceil_output_typeid_vector; - using impl::ceil_strided_dispatch_vector; - - auto ceil_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, ceil_output_typeid_vector, - ceil_contig_dispatch_vector, ceil_strided_dispatch_vector); - 
}; - m.def("_ceil", ceil_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto ceil_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, ceil_output_typeid_vector); - }; - m.def("_ceil_result_type", ceil_result_type_pyapi); - } - - // U10: ==== CONJ (x) - { - impl::populate_conj_dispatch_vectors(); - using impl::conj_contig_dispatch_vector; - using impl::conj_output_typeid_vector; - using impl::conj_strided_dispatch_vector; - - auto conj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, conj_output_typeid_vector, - conj_contig_dispatch_vector, conj_strided_dispatch_vector); - }; - m.def("_conj", conj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto conj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, conj_output_typeid_vector); - }; - m.def("_conj_result_type", conj_result_type_pyapi); - } - - // U11: ==== COS (x) - { - impl::populate_cos_dispatch_vectors(); - using impl::cos_contig_dispatch_vector; - using impl::cos_output_typeid_vector; - using impl::cos_strided_dispatch_vector; - - auto cos_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cos_output_typeid_vector, - cos_contig_dispatch_vector, cos_strided_dispatch_vector); - }; - m.def("_cos", cos_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cos_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cos_output_typeid_vector); - }; - m.def("_cos_result_type", cos_result_type_pyapi); - } - - // U12: ==== COSH (x) - { - impl::populate_cosh_dispatch_vectors(); - using impl::cosh_contig_dispatch_vector; - using impl::cosh_output_typeid_vector; - using impl::cosh_strided_dispatch_vector; - - auto cosh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cosh_output_typeid_vector, - cosh_contig_dispatch_vector, cosh_strided_dispatch_vector); - }; - m.def("_cosh", cosh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cosh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cosh_output_typeid_vector); - }; - m.def("_cosh_result_type", cosh_result_type_pyapi); - } - - // B08: ==== DIVIDE (x1, x2) - { - impl::populate_true_divide_dispatch_tables(); - using impl::true_divide_contig_dispatch_table; - using impl:: - true_divide_contig_matrix_contig_row_broadcast_dispatch_table; - using impl:: - true_divide_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::true_divide_output_id_table; - using impl::true_divide_strided_dispatch_table; - - auto divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, true_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - true_divide_contig_dispatch_table, - // function pointers to handle operation 
on strided arrays (most - // general case) - true_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - true_divide_output_id_table); - }; - m.def("_divide", divide_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_divide_result_type", divide_result_type_pyapi, ""); - - using impl::true_divide_inplace_contig_dispatch_table; - using impl::true_divide_inplace_output_id_table; - using impl::true_divide_inplace_row_matrix_dispatch_table; - using impl::true_divide_inplace_strided_dispatch_table; - - auto divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, - true_divide_inplace_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - true_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - true_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - true_divide_inplace_row_matrix_dispatch_table); - }; - m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B09: ==== EQUAL (x1, x2) - { - impl::populate_equal_dispatch_tables(); - using impl::equal_contig_dispatch_table; - using impl::equal_output_id_table; - using impl::equal_strided_dispatch_table; - - auto equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - equal_output_id_table); - }; - m.def("_equal", equal_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_equal_result_type", equal_result_type_pyapi, ""); - } - - // U13: ==== EXP (x) - { - 
impl::populate_exp_dispatch_vectors(); - using impl::exp_contig_dispatch_vector; - using impl::exp_output_typeid_vector; - using impl::exp_strided_dispatch_vector; - - auto exp_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp_output_typeid_vector, - exp_contig_dispatch_vector, exp_strided_dispatch_vector); - }; - m.def("_exp", exp_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp_output_typeid_vector); - }; - m.def("_exp_result_type", exp_result_type_pyapi); - } - - // U14: ==== EXPM1 (x) - { - impl::populate_expm1_dispatch_vectors(); - using impl::expm1_contig_dispatch_vector; - using impl::expm1_output_typeid_vector; - using impl::expm1_strided_dispatch_vector; - - auto expm1_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, expm1_output_typeid_vector, - expm1_contig_dispatch_vector, expm1_strided_dispatch_vector); - }; - m.def("_expm1", expm1_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto expm1_result_type_pyapi = [&](const py::dtype dtype) { - return py_unary_ufunc_result_type(dtype, - expm1_output_typeid_vector); - }; - m.def("_expm1_result_type", expm1_result_type_pyapi); - } - - // U15: ==== FLOOR (x) - { - impl::populate_floor_dispatch_vectors(); - using impl::floor_contig_dispatch_vector; - using impl::floor_output_typeid_vector; - using impl::floor_strided_dispatch_vector; - - auto floor_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, floor_output_typeid_vector, - floor_contig_dispatch_vector, floor_strided_dispatch_vector); - }; - m.def("_floor", floor_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto floor_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - floor_output_typeid_vector); - }; - m.def("_floor_result_type", floor_result_type_pyapi); - } - - // B10: ==== FLOOR_DIVIDE (x1, x2) - { - impl::populate_floor_divide_dispatch_tables(); - using impl::floor_divide_contig_dispatch_table; - using impl::floor_divide_output_id_table; - using impl::floor_divide_strided_dispatch_table; - - auto floor_divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - floor_divide_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - floor_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - 
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto floor_divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - floor_divide_output_id_table); - }; - m.def("_floor_divide", floor_divide_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); - - using impl::floor_divide_inplace_contig_dispatch_table; - using impl::floor_divide_inplace_strided_dispatch_table; - - auto floor_divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - floor_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - floor_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; - m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", - py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B11: ==== GREATER (x1, x2) - { - impl::populate_greater_dispatch_tables(); - using impl::greater_contig_dispatch_table; - using impl::greater_output_id_table; - using impl::greater_strided_dispatch_table; - - auto greater_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_output_id_table); - }; - m.def("_greater", greater_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_result_type", greater_result_type_pyapi, ""); - } - - // B12: ==== GREATER_EQUAL (x1, x2) - { - impl::populate_greater_equal_dispatch_tables(); - using impl::greater_equal_contig_dispatch_table; - using impl::greater_equal_output_id_table; - using impl::greater_equal_strided_dispatch_table; - - auto greater_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const 
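// Illustrative aside: each m.def(...) call in this file registers a private
// binding with keyword arguments and a default empty dependency list.
// Stripped of the dpctl kernel machinery, the registration pattern reduces
// to the sketch below; the module name, function name, and argument types
// are placeholders.

#include <pybind11/pybind11.h>

namespace py = pybind11;

static py::object _example_binding(const py::object &src,
                                   const py::object &dst,
                                   const py::object &sycl_queue,
                                   const py::list &depends)
{
    // a real binding would submit kernels to the queue and return events;
    // only the keyword-argument plumbing is shown here
    (void)src;
    (void)dst;
    (void)sycl_queue;
    return depends; // echo the dependency list back
}

PYBIND11_MODULE(_example_module, m)
{
    m.def("_example", &_example_binding, "", py::arg("src"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());
}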
std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_equal_output_id_table); - }; - m.def("_greater_equal", greater_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_equal_result_type", greater_equal_result_type_pyapi, - ""); - } - - // U16: ==== IMAG (x) - { - impl::populate_imag_dispatch_vectors(); - using impl::imag_contig_dispatch_vector; - using impl::imag_output_typeid_vector; - using impl::imag_strided_dispatch_vector; - - auto imag_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, imag_output_typeid_vector, - imag_contig_dispatch_vector, imag_strided_dispatch_vector); - }; - m.def("_imag", imag_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto imag_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, imag_output_typeid_vector); - }; - m.def("_imag_result_type", imag_result_type_pyapi); - } - - // U17: ==== ISFINITE (x) - { - impl::populate_isfinite_dispatch_vectors(); - - using impl::isfinite_contig_dispatch_vector; - using impl::isfinite_output_typeid_vector; - using impl::isfinite_strided_dispatch_vector; - auto isfinite_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - isfinite_output_typeid_vector, - isfinite_contig_dispatch_vector, - isfinite_strided_dispatch_vector); - }; - auto isfinite_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isfinite_output_typeid_vector); - }; - m.def("_isfinite", isfinite_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isfinite_result_type", isfinite_result_type_pyapi, ""); - } - - // U18: ==== ISINF (x) - { - impl::populate_isinf_dispatch_vectors(); - - using impl::isinf_contig_dispatch_vector; - using impl::isinf_output_typeid_vector; - using impl::isinf_strided_dispatch_vector; - auto isinf_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, isinf_output_typeid_vector, - isinf_contig_dispatch_vector, isinf_strided_dispatch_vector); - }; - auto isinf_result_type_pyapi = [&](const py::dtype &dtype) { - return 
py_unary_ufunc_result_type(dtype, - isinf_output_typeid_vector); - }; - m.def("_isinf", isinf_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isinf_result_type", isinf_result_type_pyapi, ""); - } - - // U19: ==== ISNAN (x) - { - impl::populate_isnan_dispatch_vectors(); - - using impl::isnan_contig_dispatch_vector; - using impl::isnan_output_typeid_vector; - using impl::isnan_strided_dispatch_vector; - auto isnan_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, isnan_output_typeid_vector, - isnan_contig_dispatch_vector, isnan_strided_dispatch_vector); - }; - auto isnan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isnan_output_typeid_vector); - }; - m.def("_isnan", isnan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isnan_result_type", isnan_result_type_pyapi, ""); - } - - // B13: ==== LESS (x1, x2) - { - impl::populate_less_dispatch_tables(); - using impl::less_contig_dispatch_table; - using impl::less_output_id_table; - using impl::less_strided_dispatch_table; - - auto less_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - less_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_output_id_table); - }; - m.def("_less", less_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_result_type", less_result_type_pyapi, ""); - } - - // B14: ==== LESS_EQUAL (x1, x2) - { - impl::populate_less_equal_dispatch_tables(); - using impl::less_equal_contig_dispatch_table; - using impl::less_equal_output_id_table; - using impl::less_equal_strided_dispatch_table; - - auto less_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - less_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - 
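// Illustrative aside: the unary bindings above (_isfinite, _isinf, _isnan,
// like _exp and _floor earlier) consult a per-function vector that maps an
// input type id to the type id of the result, next to contiguous and strided
// kernel vectors. A minimal analogue of the result-type lookup, assuming a
// negative entry marks an unsupported input type (names are illustrative):

#include <array>
#include <cstddef>
#include <stdexcept>

inline constexpr std::size_t n_supported_types = 14; // assumed type-id count

using output_typeid_vector_t = std::array<int, n_supported_types>;

inline int unary_result_typeid(const output_typeid_vector_t &v, int src_typeid)
{
    const int out = v[static_cast<std::size_t>(src_typeid)];
    if (out < 0) {
        throw std::invalid_argument("input type is not supported");
    }
    return out;
}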
td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_equal_output_id_table); - }; - m.def("_less_equal", less_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_equal_result_type", less_equal_result_type_pyapi, ""); - } - - // U20: ==== LOG (x) - { - impl::populate_log_dispatch_vectors(); - using impl::log_contig_dispatch_vector; - using impl::log_output_typeid_vector; - using impl::log_strided_dispatch_vector; - - auto log_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log_output_typeid_vector, - log_contig_dispatch_vector, log_strided_dispatch_vector); - }; - m.def("_log", log_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log_output_typeid_vector); - }; - m.def("_log_result_type", log_result_type_pyapi); - } - - // U21: ==== LOG1P (x) - { - impl::populate_log1p_dispatch_vectors(); - using impl::log1p_contig_dispatch_vector; - using impl::log1p_output_typeid_vector; - using impl::log1p_strided_dispatch_vector; - - auto log1p_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log1p_output_typeid_vector, - log1p_contig_dispatch_vector, log1p_strided_dispatch_vector); - }; - m.def("_log1p", log1p_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log1p_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log1p_output_typeid_vector); - }; - m.def("_log1p_result_type", log1p_result_type_pyapi); - } - - // U22: ==== LOG2 (x) - { - impl::populate_log2_dispatch_vectors(); - - using impl::log2_contig_dispatch_vector; - using impl::log2_output_typeid_vector; - using impl::log2_strided_dispatch_vector; - auto log2_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log2_output_typeid_vector, - log2_contig_dispatch_vector, log2_strided_dispatch_vector); - }; - auto log2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log2_output_typeid_vector); - }; - m.def("_log2", log2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log2_result_type", log2_result_type_pyapi, ""); - } - - // U23: ==== LOG10 (x) - { - impl::populate_log10_dispatch_vectors(); - - using impl::log10_contig_dispatch_vector; - using impl::log10_output_typeid_vector; - using impl::log10_strided_dispatch_vector; - auto log10_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, 
depends, log10_output_typeid_vector, - log10_contig_dispatch_vector, log10_strided_dispatch_vector); - }; - auto log10_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log10_output_typeid_vector); - }; - m.def("_log10", log10_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log10_result_type", log10_result_type_pyapi, ""); - } - - // B15: ==== LOGADDEXP (x1, x2) - { - impl::populate_logaddexp_dispatch_tables(); - using impl::logaddexp_contig_dispatch_table; - using impl::logaddexp_output_id_table; - using impl::logaddexp_strided_dispatch_table; - - auto logaddexp_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logaddexp_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logaddexp_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logaddexp_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logaddexp_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logaddexp_output_id_table); - }; - m.def("_logaddexp", logaddexp_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logaddexp_result_type", logaddexp_result_type_pyapi, ""); - } - - // B16: ==== LOGICAL_AND (x1, x2) - { - impl::populate_logical_and_dispatch_tables(); - using impl::logical_and_contig_dispatch_table; - using impl::logical_and_output_id_table; - using impl::logical_and_strided_dispatch_table; - - auto logical_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_and_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_and_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_and_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_and_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_and_output_id_table); - }; - m.def("_logical_and", logical_and_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), 
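// Illustrative aside: functions such as _logaddexp and _logical_and register
// no specialized matrix/row broadcast kernels, so their bindings pass
// td_ns::NullPtrTable<...>{}, i.e. a table whose entries are all nullptr,
// which the dispatcher treats as "no fast path available". A self-contained
// analogue of such an all-null table (the size and the names are assumptions):

#include <array>
#include <cstddef>

template <typename FnPtrT, std::size_t N = 14>
struct all_null_table
{
    // value-initialized array of function pointers: every entry is nullptr
    std::array<std::array<FnPtrT, N>, N> table{};

    const std::array<FnPtrT, N> &operator[](std::size_t i) const
    {
        return table[i];
    }
};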
py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_and_result_type", logical_and_result_type_pyapi, ""); - } - - // U24: ==== LOGICAL_NOT (x) - { - impl::populate_logical_not_dispatch_vectors(); - using impl::logical_not_contig_dispatch_vector; - using impl::logical_not_output_typeid_vector; - using impl::logical_not_strided_dispatch_vector; - - auto logical_not_pyapi = [&](const arrayT &src, arrayT dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - logical_not_output_typeid_vector, - logical_not_contig_dispatch_vector, - logical_not_strided_dispatch_vector); - }; - m.def("_logical_not", logical_not_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto logical_not_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - logical_not_output_typeid_vector); - }; - m.def("_logical_not_result_type", logical_not_result_type_pyapi); - } - - // B17: ==== LOGICAL_OR (x1, x2) - { - impl::populate_logical_or_dispatch_tables(); - using impl::logical_or_contig_dispatch_table; - using impl::logical_or_output_id_table; - using impl::logical_or_strided_dispatch_table; - - auto logical_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_or_output_id_table); - }; - m.def("_logical_or", logical_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_or_result_type", logical_or_result_type_pyapi, ""); - } - - // B18: ==== LOGICAL_XOR (x1, x2) - { - impl::populate_logical_xor_dispatch_tables(); - using impl::logical_xor_contig_dispatch_table; - using impl::logical_xor_output_id_table; - using impl::logical_xor_strided_dispatch_table; - - auto logical_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_xor_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be 
nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_xor_output_id_table); - }; - m.def("_logical_xor", logical_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_xor_result_type", logical_xor_result_type_pyapi, ""); - } - - // B??: ==== MAXIMUM (x1, x2) - { - impl::populate_maximum_dispatch_tables(); - using impl::maximum_contig_dispatch_table; - using impl::maximum_output_id_table; - using impl::maximum_strided_dispatch_table; - - auto maximum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, maximum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - maximum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - maximum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto maximum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - maximum_output_id_table); - }; - m.def("_maximum", maximum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_maximum_result_type", maximum_result_type_pyapi, ""); - } - - // B??: ==== MINIMUM (x1, x2) - { - impl::populate_minimum_dispatch_tables(); - using impl::minimum_contig_dispatch_table; - using impl::minimum_output_id_table; - using impl::minimum_strided_dispatch_table; - - auto minimum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, minimum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - minimum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - minimum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto minimum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype 
&dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - minimum_output_id_table); - }; - m.def("_minimum", minimum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_minimum_result_type", minimum_result_type_pyapi, ""); - } - - // B19: ==== MULTIPLY (x1, x2) - { - impl::populate_multiply_dispatch_tables(); - using impl::multiply_contig_dispatch_table; - using impl::multiply_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::multiply_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::multiply_output_id_table; - using impl::multiply_strided_dispatch_table; - - auto multiply_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - multiply_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - multiply_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto multiply_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - multiply_output_id_table); - }; - m.def("_multiply", multiply_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_multiply_result_type", multiply_result_type_pyapi, ""); - - using impl::multiply_inplace_contig_dispatch_table; - using impl::multiply_inplace_row_matrix_dispatch_table; - using impl::multiply_inplace_strided_dispatch_table; - - auto multiply_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - multiply_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - multiply_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - multiply_inplace_row_matrix_dispatch_table); - }; - m.def("_multiply_inplace", multiply_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U25: ==== NEGATIVE (x) - { - impl::populate_negative_dispatch_vectors(); - using impl::negative_contig_dispatch_vector; - using impl::negative_output_typeid_vector; - using impl::negative_strided_dispatch_vector; - - auto negative_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - negative_output_typeid_vector, - negative_contig_dispatch_vector, - 
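// Illustrative aside: the in-place bindings (_multiply_inplace above,
// _divide_inplace and _subtract_inplace elsewhere in this file) receive
// three kernel sets: contiguous, strided, and an optional "c-contig matrix
// op= c-contig row" broadcast specialization. A plausible reading of the
// selection order, sketched with illustrative types (real dpctl kernels
// additionally take queues, strides and dependency events):

#include <cstddef>

using inplace_fn_t = void (*)(std::size_t n, const void *rhs, void *lhs);

inline inplace_fn_t select_inplace_kernel(bool both_contig,
                                          bool matrix_row_broadcast,
                                          inplace_fn_t contig_fn,
                                          inplace_fn_t row_matrix_fn,
                                          inplace_fn_t strided_fn)
{
    if (matrix_row_broadcast && row_matrix_fn != nullptr) {
        return row_matrix_fn; // specialized broadcast kernel, if registered
    }
    if (both_contig && contig_fn != nullptr) {
        return contig_fn; // fast path for contiguous operands
    }
    return strided_fn; // most general case, always provided
}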
negative_strided_dispatch_vector); - }; - m.def("_negative", negative_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto negative_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - negative_output_typeid_vector); - }; - m.def("_negative_result_type", negative_result_type_pyapi); - } - - // B20: ==== NOT_EQUAL (x1, x2) - { - impl::populate_not_equal_dispatch_tables(); - using impl::not_equal_contig_dispatch_table; - using impl::not_equal_output_id_table; - using impl::not_equal_strided_dispatch_table; - - auto not_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, not_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - not_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - not_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto not_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - not_equal_output_id_table); - }; - m.def("_not_equal", not_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_not_equal_result_type", not_equal_result_type_pyapi, ""); - } - - // U26: ==== POSITIVE (x) - { - impl::populate_positive_dispatch_vectors(); - using impl::positive_contig_dispatch_vector; - using impl::positive_output_typeid_vector; - using impl::positive_strided_dispatch_vector; - - auto positive_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - positive_output_typeid_vector, - positive_contig_dispatch_vector, - positive_strided_dispatch_vector); - }; - m.def("_positive", positive_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto positive_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - positive_output_typeid_vector); - }; - m.def("_positive_result_type", positive_result_type_pyapi); - } - - // B21: ==== POW (x1, x2) - { - impl::populate_pow_dispatch_tables(); - using impl::pow_contig_dispatch_table; - using impl::pow_output_id_table; - using impl::pow_strided_dispatch_table; - - auto pow_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, pow_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - pow_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - 
pow_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto pow_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - pow_output_id_table); - }; - m.def("_pow", pow_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_pow_result_type", pow_result_type_pyapi, ""); - } - - // U??: ==== PROJ (x) - { - impl::populate_proj_dispatch_vectors(); - using impl::proj_contig_dispatch_vector; - using impl::proj_output_typeid_vector; - using impl::proj_strided_dispatch_vector; - - auto proj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, proj_output_typeid_vector, - proj_contig_dispatch_vector, proj_strided_dispatch_vector); - }; - m.def("_proj", proj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto proj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, proj_output_typeid_vector); - }; - m.def("_proj_result_type", proj_result_type_pyapi); - } - - // U27: ==== REAL (x) - { - impl::populate_real_dispatch_vectors(); - using impl::real_contig_dispatch_vector; - using impl::real_output_typeid_vector; - using impl::real_strided_dispatch_vector; - - auto real_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, real_output_typeid_vector, - real_contig_dispatch_vector, real_strided_dispatch_vector); - }; - m.def("_real", real_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto real_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, real_output_typeid_vector); - }; - m.def("_real_result_type", real_result_type_pyapi); - } - - // B22: ==== REMAINDER (x1, x2) - { - impl::populate_remainder_dispatch_tables(); - using impl::remainder_contig_dispatch_table; - using impl::remainder_output_id_table; - using impl::remainder_strided_dispatch_table; - - auto remainder_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, remainder_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - remainder_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - remainder_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - 
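// Illustrative aside: every binding in this file threads a "depends" list of
// events through to the kernels so new work is ordered after in-flight
// operations on the same arrays. In plain SYCL terms the dependency handling
// amounts to the following (the kernel body is a placeholder):

#include <sycl/sycl.hpp>
#include <cstddef>
#include <vector>

sycl::event submit_with_depends(sycl::queue &q, float *dst, std::size_t n,
                                const std::vector<sycl::event> &depends)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends); // wait for prior events before starting
        cgh.parallel_for(sycl::range<1>{n},
                         [=](sycl::id<1> i) { dst[i] = 0.0f; });
    });
}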
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto remainder_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - remainder_output_id_table); - }; - m.def("_remainder", remainder_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_remainder_result_type", remainder_result_type_pyapi, ""); - } - - // U28: ==== ROUND (x) - { - impl::populate_round_dispatch_vectors(); - using impl::round_contig_dispatch_vector; - using impl::round_output_typeid_vector; - using impl::round_strided_dispatch_vector; - - auto round_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, round_output_typeid_vector, - round_contig_dispatch_vector, round_strided_dispatch_vector); - }; - m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto round_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - round_output_typeid_vector); - }; - m.def("_round_result_type", round_result_type_pyapi); - } - - // U29: ==== SIGN (x) - { - impl::populate_sign_dispatch_vectors(); - using impl::sign_contig_dispatch_vector; - using impl::sign_output_typeid_vector; - using impl::sign_strided_dispatch_vector; - - auto sign_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sign_output_typeid_vector, - sign_contig_dispatch_vector, sign_strided_dispatch_vector); - }; - m.def("_sign", sign_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sign_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sign_output_typeid_vector); - }; - m.def("_sign_result_type", sign_result_type_pyapi); - } - - // ==== SIGNBIT (x) - { - impl::populate_signbit_dispatch_vectors(); - using impl::signbit_contig_dispatch_vector; - using impl::signbit_output_typeid_vector; - using impl::signbit_strided_dispatch_vector; - - auto signbit_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - signbit_output_typeid_vector, - signbit_contig_dispatch_vector, - signbit_strided_dispatch_vector); - }; - m.def("_signbit", signbit_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto signbit_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - signbit_output_typeid_vector); - }; - m.def("_signbit_result_type", signbit_result_type_pyapi); - } - - // U30: ==== SIN (x) - { - impl::populate_sin_dispatch_vectors(); - using impl::sin_contig_dispatch_vector; - using impl::sin_output_typeid_vector; - using impl::sin_strided_dispatch_vector; - - auto sin_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sin_output_typeid_vector, - sin_contig_dispatch_vector, sin_strided_dispatch_vector); - }; - m.def("_sin", sin_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sin_result_type_pyapi = [&](const py::dtype 
&dtype) { - return py_unary_ufunc_result_type(dtype, sin_output_typeid_vector); - }; - m.def("_sin_result_type", sin_result_type_pyapi); - } - // U31: ==== SINH (x) - { - impl::populate_sinh_dispatch_vectors(); - using impl::sinh_contig_dispatch_vector; - using impl::sinh_output_typeid_vector; - using impl::sinh_strided_dispatch_vector; - - auto sinh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sinh_output_typeid_vector, - sinh_contig_dispatch_vector, sinh_strided_dispatch_vector); - }; - m.def("_sinh", sinh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sinh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sinh_output_typeid_vector); - }; - m.def("_sinh_result_type", sinh_result_type_pyapi); - } - - // U32: ==== SQUARE (x) - { - impl::populate_square_dispatch_vectors(); - using impl::square_contig_dispatch_vector; - using impl::square_output_typeid_vector; - using impl::square_strided_dispatch_vector; - - auto square_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, square_output_typeid_vector, - square_contig_dispatch_vector, square_strided_dispatch_vector); - }; - m.def("_square", square_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto square_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - square_output_typeid_vector); - }; - m.def("_square_result_type", square_result_type_pyapi); - } - - // U33: ==== SQRT (x) - { - impl::populate_sqrt_dispatch_vectors(); - using impl::sqrt_contig_dispatch_vector; - using impl::sqrt_output_typeid_vector; - using impl::sqrt_strided_dispatch_vector; - - auto sqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sqrt_output_typeid_vector, - sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector); - }; - m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector); - }; - m.def("_sqrt_result_type", sqrt_result_type_pyapi); - } - - // B23: ==== SUBTRACT (x1, x2) - { - impl::populate_subtract_dispatch_tables(); - using impl::subtract_contig_dispatch_table; - using impl::subtract_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::subtract_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::subtract_output_id_table; - using impl::subtract_strided_dispatch_table; - - auto subtract_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - subtract_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - subtract_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with 
broadcasting (may be nullptr) - subtract_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - subtract_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto subtract_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - subtract_output_id_table); - }; - m.def("_subtract", subtract_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_subtract_result_type", subtract_result_type_pyapi, ""); - - using impl::subtract_inplace_contig_dispatch_table; - using impl::subtract_inplace_row_matrix_dispatch_table; - using impl::subtract_inplace_strided_dispatch_table; - - auto subtract_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - subtract_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - subtract_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - subtract_inplace_row_matrix_dispatch_table); - }; - m.def("_subtract_inplace", subtract_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U34: ==== TAN (x) - { - impl::populate_tan_dispatch_vectors(); - using impl::tan_contig_dispatch_vector; - using impl::tan_output_typeid_vector; - using impl::tan_strided_dispatch_vector; - - auto tan_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tan_output_typeid_vector, - tan_contig_dispatch_vector, tan_strided_dispatch_vector); - }; - m.def("_tan", tan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tan_output_typeid_vector); - }; - m.def("_tan_result_type", tan_result_type_pyapi); - } - - // U35: ==== TANH (x) - { - impl::populate_tanh_dispatch_vectors(); - using impl::tanh_contig_dispatch_vector; - using impl::tanh_output_typeid_vector; - using impl::tanh_strided_dispatch_vector; - - auto tanh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tanh_output_typeid_vector, - tanh_contig_dispatch_vector, tanh_strided_dispatch_vector); - }; - m.def("_tanh", tanh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tanh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tanh_output_typeid_vector); - }; - m.def("_tanh_result_type", tanh_result_type_pyapi); - } - - // U36: ==== TRUNC (x) - { - impl::populate_trunc_dispatch_vectors(); - using impl::trunc_contig_dispatch_vector; - using impl::trunc_output_typeid_vector; - using impl::trunc_strided_dispatch_vector; - - auto trunc_pyapi = [&](const arrayT 
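// Illustrative aside: _subtract (like _add, _multiply and _divide) registers
// extra kernels for the "C-contiguous matrix op C-contiguous row" case, so a
// row can be broadcast across the matrix without materializing a temporary.
// The essence of such a kernel, shown as a scalar CPU loop with illustrative
// names (the real kernels are SYCL device code):

#include <cstddef>

void subtract_matrix_row_broadcast(const float *mat, const float *row,
                                   float *dst, std::size_t n_rows,
                                   std::size_t n_cols)
{
    for (std::size_t i = 0; i < n_rows; ++i) {
        for (std::size_t j = 0; j < n_cols; ++j) {
            // the same row is reused for every row of the matrix
            dst[i * n_cols + j] = mat[i * n_cols + j] - row[j];
        }
    }
}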
&src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, trunc_output_typeid_vector, - trunc_contig_dispatch_vector, trunc_strided_dispatch_vector); - }; - m.def("_trunc", trunc_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto trunc_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - trunc_output_typeid_vector); - }; - m.def("_trunc_result_type", trunc_result_type_pyapi); - } - - // B24: ==== HYPOT (x1, x2) - { - impl::populate_hypot_dispatch_tables(); - using impl::hypot_contig_dispatch_table; - using impl::hypot_output_id_table; - using impl::hypot_strided_dispatch_table; - - auto hypot_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, hypot_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - hypot_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - hypot_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto hypot_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - hypot_output_id_table); - }; - m.def("_hypot", hypot_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_hypot_result_type", hypot_result_type_pyapi, ""); - } - - // U37: ==== CBRT (x) - { - impl::populate_cbrt_dispatch_vectors(); - using impl::cbrt_contig_dispatch_vector; - using impl::cbrt_output_typeid_vector; - using impl::cbrt_strided_dispatch_vector; - - auto cbrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cbrt_output_typeid_vector, - cbrt_contig_dispatch_vector, cbrt_strided_dispatch_vector); - }; - m.def("_cbrt", cbrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cbrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cbrt_output_typeid_vector); - }; - m.def("_cbrt_result_type", cbrt_result_type_pyapi); - } - - // B25: ==== COPYSIGN (x1, x2) - { - impl::populate_copysign_dispatch_tables(); - using impl::copysign_contig_dispatch_table; - using impl::copysign_output_id_table; - using impl::copysign_strided_dispatch_table; - - auto copysign_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, copysign_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - copysign_contig_dispatch_table, - // 
function pointers to handle operation on strided arrays (most - // general case) - copysign_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto copysign_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - copysign_output_id_table); - }; - m.def("_copysign", copysign_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_copysign_result_type", copysign_result_type_pyapi, ""); - } - - // U38: ==== EXP2 (x) - { - impl::populate_exp2_dispatch_vectors(); - using impl::exp2_contig_dispatch_vector; - using impl::exp2_output_typeid_vector; - using impl::exp2_strided_dispatch_vector; - - auto exp2_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp2_output_typeid_vector, - exp2_contig_dispatch_vector, exp2_strided_dispatch_vector); - }; - m.def("_exp2", exp2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp2_output_typeid_vector); - }; - m.def("_exp2_result_type", exp2_result_type_pyapi); - } - - // U39: ==== RSQRT (x) - { - impl::populate_rsqrt_dispatch_vectors(); - using impl::rsqrt_contig_dispatch_vector; - using impl::rsqrt_output_typeid_vector; - using impl::rsqrt_strided_dispatch_vector; - - auto rsqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, rsqrt_output_typeid_vector, - rsqrt_contig_dispatch_vector, rsqrt_strided_dispatch_vector); - }; - m.def("_rsqrt", rsqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto rsqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - rsqrt_output_typeid_vector); - }; - m.def("_rsqrt_result_type", rsqrt_result_type_pyapi); - } -} - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/abs.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/abs.cpp index 4b3e8b635b..fd65860690 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/abs.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/abs.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "abs.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/acos.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/acos.cpp index 011cc052fb..38ddeba9b4 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/acos.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/acos.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include 
"dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "acos.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/acosh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/acosh.cpp index 526bd44f12..48a1036528 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/acosh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/acosh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "acosh.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp index 247b8e0283..e4ca013223 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/add.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "add.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/angle.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/angle.cpp index 166b37b27b..ab1d9ed866 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/angle.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/angle.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "angle.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/asin.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/asin.cpp index 14ef5e2665..1659ff6c30 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/asin.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/asin.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "asin.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/asinh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/asinh.cpp index dd0b4e62f7..f07ecc7c74 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/asinh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/asinh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "asinh.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/atan.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/atan.cpp index 81ff00c46a..a06ee1278a 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/atan.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/atan.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "atan.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/atan2.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/atan2.cpp index d12a4ff540..49ec146ace 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/atan2.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/atan2.cpp @@ -24,10 +24,10 @@ 
//===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "atan2.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/atanh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/atanh.cpp index c42769b8d0..d97d78f79e 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/atanh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/atanh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "atanh.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp index f86f5112cd..ec227b8b0d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_and.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_invert.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_invert.cpp index 29a04cff38..a9015b213f 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_invert.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_invert.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_invert.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp index 7969bc4ffa..e3056e4dbf 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_left_shift.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp index 33a57f907c..81f976f862 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_or.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp index 3847204b1f..a0671256ce 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_right_shift.hpp" diff --git 
a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp index 71d606766f..efe1e9bda5 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "bitwise_xor.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/cbrt.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/cbrt.cpp index b42f234c0d..6841023c45 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/cbrt.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/cbrt.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "cbrt.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/ceil.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/ceil.cpp index f1bb362c5b..eeb4666959 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/ceil.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/ceil.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "ceil.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/conj.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/conj.cpp index cac84e63fb..a520f2ce1f 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/conj.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/conj.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "conj.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/copysign.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/copysign.cpp index 6a887e0345..d02438f1fe 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/copysign.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/copysign.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "copysign.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/cos.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/cos.cpp index 1986610510..d8d1958f62 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/cos.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/cos.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "cos.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/cosh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/cosh.cpp index 0bb74df979..6525ad54fe 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/cosh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/cosh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include 
"dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "cosh.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp index da0137fd5f..673af04b77 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp @@ -34,6 +34,7 @@ #include "elementwise_functions_type_utils.hpp" #include "kernels/alignment.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "simplify_iteration_space.hpp" #include "utils/memory_overlap.hpp" #include "utils/offset_utils.hpp" @@ -42,6 +43,8 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +static_assert(std::is_same_v); + namespace dpctl { namespace tensor diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp index 473048e8fa..44b83497e8 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp @@ -24,9 +24,9 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include +#include #include "elementwise_functions_type_utils.hpp" #include "utils/type_dispatch.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.hpp b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.hpp index 6dac195dc2..7f1cacdc20 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions_type_utils.hpp @@ -25,9 +25,9 @@ #pragma once #include "dpctl4pybind11.hpp" -#include #include #include +#include #include "utils/type_dispatch.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/equal.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/equal.cpp index f36ec1b446..a650d5d8fd 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/equal.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/equal.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/exp.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/exp.cpp index 51ccaaac70..f0c6ec9a62 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/exp.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/exp.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/exp2.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/exp2.cpp index 438ad0800e..a59f193644 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/exp2.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/exp2.cpp @@ -24,10 +24,10 @@ 
//===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/expm1.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/expm1.cpp index 3b9332c4f1..26c11a926b 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/expm1.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/expm1.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/floor.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/floor.cpp index 9ccf89f13a..c538cd7668 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/floor.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/floor.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp index e75fc56c67..4797198483 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/floor_divide.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/greater.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/greater.cpp index f79102df47..87589a88f9 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/greater.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/greater.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/greater_equal.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/greater_equal.cpp index 005679c3fb..bb46ceb0ec 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/greater_equal.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/greater_equal.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/hypot.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/hypot.cpp index 2442710198..b14f23ea7c 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/hypot.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/hypot.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git 
a/dpctl/tensor/libtensor/source/elementwise_functions/imag.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/imag.cpp index 4012b9206f..270504a199 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/imag.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/imag.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/isfinite.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/isfinite.cpp index 73a2be4010..6da365d5e0 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/isfinite.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/isfinite.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/isinf.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/isinf.cpp index 2600fe4f74..1c19a3587d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/isinf.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/isinf.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/isnan.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/isnan.cpp index b75618c5e0..e2b224bd5e 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/isnan.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/isnan.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/less.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/less.cpp index c34122d862..1326b9741f 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/less.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/less.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/less_equal.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/less_equal.cpp index 712b30d902..f402ad88ad 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/less_equal.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/less_equal.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/log.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/log.cpp index f73b9e2414..5258f56158 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/log.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/log.cpp @@ -24,10 +24,10 @@ 
//===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/log10.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/log10.cpp index 566dfcbcf7..d6a2815cd1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/log10.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/log10.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/log1p.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/log1p.cpp index badb474778..961e56e319 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/log1p.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/log1p.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/log2.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/log2.cpp index b5a8a39684..c307246ecc 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/log2.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/log2.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/logaddexp.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/logaddexp.cpp index 77ded230be..2dd585ab4e 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/logaddexp.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/logaddexp.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/logical_and.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/logical_and.cpp index 4c573ce508..f2680cbaff 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/logical_and.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/logical_and.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/logical_not.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/logical_not.cpp index 84362cd9ce..de9a48320b 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/logical_not.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/logical_not.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/logical_or.cpp 
b/dpctl/tensor/libtensor/source/elementwise_functions/logical_or.cpp index ebf8251b2e..15eb40d0f2 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/logical_or.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/logical_or.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/logical_xor.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/logical_xor.cpp index 9488a5615a..fd1853b927 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/logical_xor.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/logical_xor.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/maximum.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/maximum.cpp index 208bdcf47f..edcee5ded7 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/maximum.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/maximum.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/minimum.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/minimum.cpp index dc1a826ac4..ff0ee9ce9c 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/minimum.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/minimum.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp index c087abd9ff..0058dadcfc 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/multiply.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/negative.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/negative.cpp index bc659506d1..c10dfa0fc1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/negative.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/negative.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/not_equal.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/not_equal.cpp index a7a3e909cb..ba9fd3bc78 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/not_equal.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/not_equal.cpp @@ -24,10 
+24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/positive.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/positive.cpp index eaff0794d2..99cf8b821d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/positive.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/positive.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp index a8ef6cb171..a6cacba4ef 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/pow.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/proj.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/proj.cpp index 60060084e1..25a062785b 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/proj.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/proj.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/real.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/real.cpp index 890a308a4e..2c63606fb6 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/real.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/real.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/reciprocal.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/reciprocal.cpp index 5f86188c99..5c717142d8 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/reciprocal.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/reciprocal.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp index 3255ea7e7f..e2b0e38061 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/remainder.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/round.cpp 
b/dpctl/tensor/libtensor/source/elementwise_functions/round.cpp index cce730b899..510e1d53ca 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/round.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/round.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/rsqrt.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/rsqrt.cpp index 4661fdfa48..d5df041ee4 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/rsqrt.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/rsqrt.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/sign.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/sign.cpp index 7b7c2c22e5..352cc0e4e4 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/sign.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/sign.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/signbit.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/signbit.cpp index fc101dd64b..b8b917cbdd 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/signbit.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/signbit.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/sin.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/sin.cpp index 415dc15133..487bcfc0dd 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/sin.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/sin.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/sinh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/sinh.cpp index d9f92eb8f1..49064284ce 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/sinh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/sinh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/sqrt.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/sqrt.cpp index 159d45b51c..db04b01298 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/sqrt.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/sqrt.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include 
"dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/square.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/square.cpp index 184e09c19c..968262a7b0 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/square.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/square.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp index 9703182e7a..c720ab23a3 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/subtract.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/tan.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/tan.cpp index 2f1fbf55f2..c5dc7f4625 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/tan.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/tan.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/tanh.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/tanh.cpp index 033389e46d..398bac4097 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/tanh.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/tanh.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp index 22ad9bf3cb..9894ffa5a7 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/trunc.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/trunc.cpp index 5b2f451fb0..6f86b3e19c 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/trunc.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/trunc.cpp @@ -24,10 +24,10 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "elementwise_functions.hpp" diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index c8004bfae8..94d0a13100 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ 
b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -35,6 +35,7 @@ #include "utils/type_utils.hpp" #include "full_ctor.hpp" +#include "unboxing_helper.hpp" namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; @@ -48,7 +49,60 @@ namespace py_internal using dpctl::utils::keep_args_alive; -using dpctl::tensor::kernels::constructors::full_contig_fn_ptr_t; +typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &, + size_t, + const py::object &, + char *, + const std::vector &); + +/*! + * @brief Function to submit kernel to fill given contiguous memory allocation + * with specified value. + * + * @param exec_q Sycl queue to which kernel is submitted for execution. + * @param nelems Length of the sequence + * @param py_value Python object representing the value to fill the array with. + * Must be convertible to `dstTy`. + * @param dst_p Kernel accessible USM pointer to the start of array to be + * populated. + * @param depends List of events to wait for before starting computations, if + * any. + * + * @return Event to wait on to ensure that computation completes. + * @defgroup CtorKernels + */ +template +sycl::event full_contig_impl(sycl::queue &exec_q, + size_t nelems, + const py::object &py_value, + char *dst_p, + const std::vector &depends) +{ + dstTy fill_v; + + PythonObjectUnboxer unboxer{}; + try { + fill_v = unboxer(py_value); + } catch (const py::error_already_set &e) { + throw; + } + + using dpctl::tensor::kernels::constructors::full_contig_impl; + + sycl::event fill_ev = + full_contig_impl(exec_q, nelems, fill_v, dst_p, depends); + + return fill_ev; +} + +template struct FullContigFactory +{ + fnT get() + { + fnT f = full_contig_impl; + return f; + } +}; static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types]; @@ -99,7 +153,6 @@ usm_ndarray_full(const py::object &py_value, void init_full_ctor_dispatch_vectors(void) { using namespace td_ns; - using dpctl::tensor::kernels::constructors::FullContigFactory; DispatchVectorBuilder dvb; diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp new file mode 100644 index 0000000000..9a2b51497e --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp @@ -0,0 +1,857 @@ +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include "dot.hpp" +#include "dot_atomic_support.hpp" +#include "dot_dispatch.hpp" +#include "elementwise_functions/elementwise_functions_type_utils.hpp" +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" +#include "reductions/reduction_atomic_support.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +static int dot_output_id_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_impl_fn_ptr_t; +static dot_product_impl_fn_ptr_t dot_product_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static dot_product_impl_fn_ptr_t + dot_product_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_contig_impl_fn_ptr_t; +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static dot_product_contig_impl_fn_ptr_t + 
dot_product_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_impl_fn_ptr_t; +static gemm_impl_fn_ptr_t gemm_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static gemm_impl_fn_ptr_t gemm_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_contig_impl_fn_ptr_t; +static gemm_contig_impl_fn_ptr_t + gemm_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_contig_impl_fn_ptr_t + gemm_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_impl_fn_ptr_t; +static gemm_batch_impl_fn_ptr_t + gemm_batch_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_impl_fn_ptr_t + gemm_batch_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_contig_impl_fn_ptr_t; +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void init_dot_dispatch_tables(void) +{ + using dpctl::tensor::py_internal::DotTypeMapFactory; + td_ns::DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(dot_output_id_table); + + using dpctl::tensor::py_internal::GemmBatchAtomicFactory; + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(gemm_batch_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(gemm_batch_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmAtomicFactory; + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(gemm_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(gemm_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchTempsFactory; + td_ns::DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(gemm_batch_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigTempsFactory; + td_ns::DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(gemm_batch_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmTempsFactory; + td_ns::DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(gemm_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigTempsFactory; + td_ns::DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(gemm_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductAtomicFactory; + td_ns::DispatchTableBuilder + dtb10; + dtb10.populate_dispatch_table(dot_product_dispatch_table); + + using dpctl::tensor::py_internal::DotProductNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb11; + dtb11.populate_dispatch_table(dot_product_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb12; + dtb12.populate_dispatch_table(dot_product_contig_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb13; + dtb13.populate_dispatch_table(dot_product_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t dot_atomic_support_vector[td_ns::num_types]; + +void init_dot_atomic_support_vector(void) +{ 
+ + using atomic_support::DotAtomicSupportFactory; + td_ns::DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(dot_atomic_support_vector); +} + +std::pair +py_dot(const dpctl::tensor::usm_ndarray &x1, + const dpctl::tensor::usm_ndarray &x2, + int batch_dims, + int x1_outer_dims, + int x2_outer_dims, + int inner_dims, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + + if (!dst.is_writable()) { + throw py::value_error("Output array is read-only."); + } + + if (inner_dims == 0) { + throw py::value_error("No inner dimension for dot"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + int x1_nd = x1.get_ndim(); + int x2_nd = x2.get_ndim(); + if (x1_nd != (batch_dims + x1_outer_dims + inner_dims) || + x2_nd != (batch_dims + x2_outer_dims + inner_dims)) + { + throw py::value_error("Input arrays do not have dimensions consistent " + "with input dimensions"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != (batch_dims + x1_outer_dims + x2_outer_dims)) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of input dimensions"); + } + + const py::ssize_t *x1_shape_ptr = x1.get_shape_raw(); + const py::ssize_t *x2_shape_ptr = x2.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + size_t batches(1); + for (int i = 0; same_shapes && (i < batch_dims); ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]) && + (x2_shape_ptr[i] == dst_shape_ptr[i]); + batches *= x1_shape_ptr[i]; + } + size_t x1_outer_nelems(1); + for (int i = batch_dims; same_shapes && (i < (batch_dims + x1_outer_dims)); + ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]); + x1_outer_nelems *= x1_shape_ptr[i]; + } + size_t inner_nelems(1); + for (int i = batch_dims; i < (batch_dims + inner_dims); ++i) { + auto x1_shape_idx = x1_outer_dims + i; + same_shapes = + same_shapes && (x1_shape_ptr[x1_shape_idx] == x2_shape_ptr[i]); + inner_nelems *= x1_shape_ptr[x1_shape_idx]; + } + size_t x2_outer_nelems(1); + for (int i = 0; same_shapes && (i < x2_outer_dims); ++i) { + auto x2_shape_idx = batch_dims + inner_dims + i; + same_shapes = + same_shapes && (x2_shape_ptr[x2_shape_idx] == + dst_shape_ptr[batch_dims + x1_outer_dims + i]); + x2_outer_nelems *= x2_shape_ptr[x2_shape_idx]; + } + if (!same_shapes) { + throw py::value_error("Input arrays to tensor dot product do not have " + "appropriate shapes"); + } + + size_t dst_nelems = batches * x1_outer_nelems * x2_outer_nelems; + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + if (static_cast(dst.get_size()) != dst_nelems) { + throw py::value_error("dst shape and size mismatch"); + } + + // ensure that dst is sufficiently ample + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accommodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Memory addressed by the destination array can not " + "accommodate all the " + "array elements."); + } + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with x1 or x2 + if (overlap(dst, x1) || overlap(dst, x2)) { + throw py::value_error("Result array overlaps with inputs"); + } + + int x1_typenum = 
x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = dot_output_id_table[x1_typeid][x2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Result array has unexpected elemental data type."); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + bool supports_atomics = + dot_atomic_support_vector[output_typeid](exec_q, usm_type); + + const char *x1_data = x1.get_data(); + const char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + auto x1_shape_vec = x1.get_shape_vector(); + auto x1_strides_vec = x1.get_strides_vector(); + + auto x2_shape_vec = x2.get_shape_vector(); + auto x2_strides_vec = x2.get_strides_vector(); + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + bool call_vecdot = ((x1_outer_dims == 0 && x1_outer_nelems == 1) && + (x2_outer_dims == 0 && x2_outer_nelems == 1)); + + bool call_batched = (batch_dims != 0 || batches > 1); + std::vector host_task_events{}; + sycl::event dot_ev; + if (call_vecdot) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig) || + ((is_x1_f_contig && is_x2_f_contig) && !call_batched)) + { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + zero_offset, // lhs batch offset + zero_offset, // rhs batch offset + zero_offset, // res batch offset + zero_offset, // lhs reduction offset + zero_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_3; + + int inner_nd = inner_dims; + const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims; + using shT = std::vector; + shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, + std::end(x1_strides_vec)); + shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims, + std::end(x2_strides_vec)); + + shT simplified_inner_shape; + shT simplified_inner_x1_strides; + shT simplified_inner_x2_strides; + py::ssize_t inner_x1_offset(0); + py::ssize_t inner_x2_offset(0); + + simplify_iteration_space( + inner_nd, inner_shape_ptr, inner_x1_strides, inner_x2_strides, + // output + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides, inner_x1_offset, inner_x2_offset); + + const py::ssize_t *batch_shape_ptr = x1_shape_ptr; + + shT batch_x1_strides(std::begin(x1_strides_vec), + std::begin(x1_strides_vec) + batch_dims); + shT batch_x2_strides(std::begin(x2_strides_vec), + std::begin(x2_strides_vec) + 
batch_dims); + shT const &batch_dst_strides = dst_strides_vec; + + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t batch_x1_offset(0); + py::ssize_t batch_x2_offset(0); + py::ssize_t batch_dst_offset(0); + + if (batch_dims == 0) { + if (dst_nelems != 1) { + throw std::runtime_error( + "batch_dims == 0, but dst_nelems != 1"); + } + batch_dims = 1; + simplified_batch_shape.push_back(1); + simplified_batch_x1_strides.push_back(0); + simplified_batch_x2_strides.push_back(0); + simplified_batch_dst_strides.push_back(0); + } + else { + simplify_iteration_space_3( + batch_dims, batch_shape_ptr, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // output + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset); + } + + if (inner_nd == 1 && batch_dims == 1) { + bool dot_product_c_contig = false; + bool reduce_all_elems = false; + + if (simplified_inner_x1_strides[0] == 1 && + simplified_inner_x2_strides[0] == 1) { + reduce_all_elems = (simplified_batch_shape[0] == 1); + dot_product_c_contig = + (simplified_batch_dst_strides[0] == 1) && + (static_cast(simplified_batch_x1_strides[0]) == + inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + inner_nelems); + } + + if (dot_product_c_contig || reduce_all_elems) { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + batch_x1_offset, // lhs batch offset + batch_x2_offset, // rhs batch offset + batch_dst_offset, // res batch offset + inner_x1_offset, // lhs reduction offset + inner_x2_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + } + + dot_product_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = dot_product_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + // reduction metadata + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = + std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *inner_shape_stride = + temp_allocation_ptr + 4 * simplified_batch_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); 
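The span above follows a pattern that recurs throughout this patch: iteration and reduction metadata are packed into a single USM device allocation (via device_allocate_and_pack), the compute kernel is submitted with a dependency on the metadata-copy event, and the temporary allocation is released from a host_task that itself depends on the kernel. The following standalone sketch shows only that ordering; it is illustrative, not part of the diff, and uses a hypothetical helper name (pack_and_run), a trivial stand-in kernel, and assumes `out` points to USM device memory.

```cpp
// Illustrative sketch of the pack / depend / cleanup ordering (assumptions noted above).
#include <sycl/sycl.hpp>
#include <cstdint>
#include <stdexcept>
#include <vector>

sycl::event pack_and_run(sycl::queue &q,
                         const std::vector<std::int64_t> &shapes_strides,
                         std::int64_t *out)
{
    // Pack host-side metadata into one device allocation
    std::int64_t *packed =
        sycl::malloc_device<std::int64_t>(shapes_strides.size(), q);
    if (packed == nullptr) {
        throw std::runtime_error("Unable to allocate device memory");
    }

    // Asynchronous host-to-device copy of the packed metadata
    sycl::event copy_ev =
        q.memcpy(packed, shapes_strides.data(),
                 shapes_strides.size() * sizeof(std::int64_t));

    // Compute kernel waits on the metadata copy (stand-in for the dot kernel)
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        const std::size_t n = shapes_strides.size();
        cgh.parallel_for(sycl::range<1>{n}, [=](sycl::id<1> i) {
            out[i] = packed[i]; // trivial work that reads the packed data
        });
    });

    // Release the temporary allocation only after the kernel completes
    q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev);
        const sycl::context ctx = q.get_context();
        cgh.host_task([ctx, packed]() { sycl::free(packed, ctx); });
    });

    return comp_ev;
}
```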
+ all_deps.push_back(copy_metadata_ev); + + dot_ev = + fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), x2.get_data(), + dst.get_data(), batch_dims, iter_shape_and_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset, + inner_nd, // number dimensions being reduced + inner_shape_stride, inner_x1_offset, inner_x2_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + } + else { // if (!call_vecdot) + if (!call_batched) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + gemm_contig_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = gemm_contig_temps_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + gemm_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, x1_shape_vec, x1_strides_vec, + x2_shape_vec, x2_strides_vec, dst_shape_vec, + dst_strides_vec); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + py::ssize_t *x1_shape_strides = packed_shapes_strides; + py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd); + py::ssize_t *dst_shape_strides = + packed_shapes_strides + 2 * (x1_nd + x2_nd); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + // change gemm calls to pass inner dims and outer dims separately + dot_ev = + fn(exec_q, x1_data, x2_data, dst_data, x1_outer_nelems, + inner_nelems, x2_outer_nelems, inner_dims, x1_outer_dims, + x1_shape_strides, x2_outer_dims, x2_shape_strides, + x1_outer_dims + x2_outer_dims, dst_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + else { // if (call_batched) + using shT = std::vector; + // temporary asserts for matmul + assert(x1_outer_dims == 1); + assert(x2_outer_dims == 1); + assert(inner_dims == 1); + + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = 
gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + zero_offset, zero_offset, zero_offset, depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + + auto x1_outer_inner_dims = x1_nd - batch_dims; + auto x2_outer_inner_dims = x2_nd - batch_dims; + auto dst_outer_inner_dims = dst_nd - batch_dims; + + shT batch_x1_shape; + shT outer_inner_x1_shape; + shT batch_x1_strides; + shT outer_inner_x1_strides; + dpctl::tensor::py_internal::split_iteration_space( + x1_shape_vec, x1_strides_vec, batch_dims, + batch_dims + x1_outer_inner_dims, batch_x1_shape, + outer_inner_x1_shape, // 4 vectors modified + batch_x1_strides, outer_inner_x1_strides); + + shT batch_x2_shape; + shT outer_inner_x2_shape; + shT batch_x2_strides; + shT outer_inner_x2_strides; + dpctl::tensor::py_internal::split_iteration_space( + x2_shape_vec, x2_strides_vec, batch_dims, + batch_dims + x2_outer_inner_dims, batch_x2_shape, + outer_inner_x2_shape, // 4 vectors modified + batch_x2_strides, outer_inner_x2_strides); + + shT batch_dst_shape; + shT outer_inner_dst_shape; + shT batch_dst_strides; + shT outer_inner_dst_strides; + dpctl::tensor::py_internal::split_iteration_space( + dst_shape_vec, dst_strides_vec, batch_dims, + batch_dims + dst_outer_inner_dims, batch_dst_shape, + outer_inner_dst_shape, // 4 vectors modified + batch_dst_strides, outer_inner_dst_strides); + + using shT = std::vector; + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t x1_batch_offset(0); + py::ssize_t x2_batch_offset(0); + py::ssize_t dst_batch_offset(0); + + const py::ssize_t *shape = x1_shape_ptr; + + using dpctl::tensor::py_internal::simplify_iteration_space_3; + simplify_iteration_space_3( + batch_dims, shape, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // outputs + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset); + + if (batch_dims == 1 && x1_outer_dims == 1 && x2_outer_dims == 1 && + inner_dims == 1) + { + bool gemm_batch_c_contig = false; + + if ((static_cast(outer_inner_x1_strides[0]) == + inner_nelems && + outer_inner_x1_strides[1] == 1) && + (static_cast(outer_inner_x2_strides[0]) == + inner_nelems && + outer_inner_x2_strides[1] == 1) && + (static_cast(outer_inner_dst_strides[0]) == + x2_outer_nelems && + outer_inner_dst_strides[1] == 1)) + { + gemm_batch_c_contig = + (static_cast(simplified_batch_x1_strides[0]) == + x1_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + x2_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_dst_strides[0]) == + x1_outer_nelems * x2_outer_nelems); + } + + if (gemm_batch_c_contig) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + x1_batch_offset, 
x2_batch_offset, + dst_batch_offset, depends); + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, + {dot_ev}), + dot_ev); + } + } + } + + gemm_batch_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_batch_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, simplified_batch_shape, + simplified_batch_x1_strides, simplified_batch_x2_strides, + simplified_batch_dst_strides, outer_inner_x1_shape, + outer_inner_x1_strides, outer_inner_x2_shape, + outer_inner_x2_strides, outer_inner_dst_shape, + outer_inner_dst_strides, + // full shape and strides of the result array + // necessary for reduction and initialization + simplified_batch_shape, outer_inner_dst_shape, + simplified_batch_dst_strides, outer_inner_dst_strides); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + + auto batch_shape_strides = packed_shapes_strides; + auto x1_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims; + auto x2_outer_inner_shapes_strides = packed_shapes_strides + + 4 * batch_dims + + 2 * (x1_outer_inner_dims); + auto dst_outer_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims); + auto dst_full_shape_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) + + 2 * (dst_outer_inner_dims); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + dot_ev = fn( + exec_q, x1_data, x2_data, dst_data, batches, x1_outer_nelems, + inner_nelems, x2_outer_nelems, batch_dims, batch_shape_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset, inner_dims, + x1_outer_dims, x1_outer_inner_shapes_strides, x2_outer_dims, + x2_outer_inner_shapes_strides, x1_outer_dims + x2_outer_dims, + dst_outer_shapes_strides, dst_full_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + } + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, host_task_events), + dot_ev); +} + +template +py::object py_dot_result_type(const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = 
array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum; + + auto dst_typenum_t = static_cast(dst_typeid); + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +void init_dot(py::module_ m) +{ + using dpctl::tensor::py_internal::init_dot_atomic_support_vector; + init_dot_atomic_support_vector(); + using dpctl::tensor::py_internal::init_dot_dispatch_tables; + init_dot_dispatch_tables(); + + using dpctl::tensor::py_internal::py_dot; + m.def("_dot", &py_dot, "", py::arg("x1"), py::arg("x2"), + py::arg("batch_dims"), py::arg("x1_outer_dims"), + py::arg("x2_outer_dims"), py::arg("inner_dims"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using dpctl::tensor::py_internal::dot_output_id_table; + auto dot_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + using dpctl::tensor::py_internal::py_dot_result_type; + return py_dot_result_type(dtype1, dtype2, dot_output_id_table); + }; + m.def("_dot_result_type", dot_result_type_pyapi, ""); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp new file mode 100644 index 0000000000..5f8f6cf494 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_dot(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp new file mode 100644 index 0000000000..29022342a1 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "reductions/reduction_atomic_support.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ +namespace atomic_support +{ + +template struct DotAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + return atomic_support::fixed_decision; + } + else { + return atomic_support::check_atomic_support; + } + } +}; + +} // namespace atomic_support +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp new file mode 100644 index 0000000000..de59450174 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template struct DotAtomicOutputType +{ + using value_type = typename 
std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; +}; + +// add separate type support lists for atomic vs. temps +// gemm, gevm, and dot product share output type struct +template struct DotNoAtomicOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; +}; + +template struct DotTypeMapFactory +{ + /*! @brief get typeid for output type of kernels called by py_dot */ + std::enable_if_t::value, int> get() + { + using rT1 = typename DotNoAtomicOutputType::value_type; + using rT2 = typename DotAtomicOutputType::value_type; + static_assert(std::is_same_v || std::is_same_v); + return td_ns::GetTypeid{}.get(); + } +}; + +template struct GemmBatchAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_impl; + fnT fn = gemm_batch_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_impl; + fnT fn = gemm_batch_contig_impl; + return fn; + } + } +}; + +template struct GemmAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_impl; + fnT fn = gemm_impl; + return fn; + } + } +}; + +template struct GemmContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_impl; + fnT fn = gemm_contig_impl; + return fn; + } + } +}; + +template struct GemmTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_tree_impl; + fnT fn = gemm_tree_impl; + return fn; + } + } +}; + +template struct GemmContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_tree_impl; + fnT fn = gemm_contig_tree_impl; + return fn; + } + } +}; + +template struct GemmBatchTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return 
fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_tree_impl; + fnT fn = gemm_batch_tree_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_tree_impl; + fnT fn = gemm_batch_contig_tree_impl; + return fn; + } + } +}; + +template struct DotProductAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_impl; + fnT fn = dot_product_impl; + return fn; + } + } +}; + +template +struct DotProductNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_tree_impl; + fnT fn = dot_product_tree_impl; + return fn; + } + } +}; + +template +struct DotProductContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_impl; + fnT fn = dot_product_contig_impl; + return fn; + } + } +}; + +template +struct DotProductContigNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_tree_impl; + fnT fn = dot_product_contig_tree_impl; + return fn; + } + } +}; + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index 72d292df5f..e3f804b43a 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -35,6 +35,7 @@ #include "utils/type_utils.hpp" #include "linear_sequences.hpp" +#include "unboxing_helper.hpp" namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; @@ -46,13 +47,121 @@ namespace tensor namespace py_internal { -using dpctl::utils::keep_args_alive; +// Constructor to populate tensor with linear sequence defined by +// start and step data + +typedef sycl::event (*lin_space_step_fn_ptr_t)( + sycl::queue &, + size_t, // num_elements + const py::object &start, + const py::object &step, + char *, // dst_data_ptr + const std::vector &); + +/*! + * @brief Function to submit kernel to populate given contiguous memory + * allocation with linear sequence specified by starting value and increment + * given as Python objects. + * + * @param q Sycl queue to which the kernel is submitted + * @param nelems Length of the sequence + * @param start Starting value of the sequence as Python object. Must be + * convertible to array element data type `Ty`. + * @param step Increment of the sequence as Python object. Must be convertible + * to array element data type `Ty`. + * @param array_data Kernel accessible USM pointer to the start of array to be + * populated. + * @param depends List of events to wait for before starting computations, if + * any. + * + * @return Event to wait on to ensure that computation completes. 
+ * @defgroup CtorKernels + */ +template +sycl::event lin_space_step_impl(sycl::queue &exec_q, + size_t nelems, + const py::object &start, + const py::object &step, + char *array_data, + const std::vector &depends) +{ + Ty start_v; + Ty step_v; + + const auto &unboxer = PythonObjectUnboxer{}; + try { + start_v = unboxer(start); + step_v = unboxer(step); + } catch (const py::error_already_set &e) { + throw; + } -using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t; + using dpctl::tensor::kernels::constructors::lin_space_step_impl; -static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[td_ns::num_types]; + auto lin_space_step_event = lin_space_step_impl( + exec_q, nelems, start_v, step_v, array_data, depends); + + return lin_space_step_event; +} + +typedef sycl::event (*lin_space_affine_fn_ptr_t)( + sycl::queue &, + size_t, // num_elements + const py::object &start, + const py::object &end, + bool include_endpoint, + char *, // dst_data_ptr + const std::vector &); + +/*! + * @brief Function to submit kernel to populate given contiguous memory + * allocation with linear sequence specified by starting and end values given + * as Python objects. + * + * @param exec_q Sycl queue to which kernel is submitted for execution. + * @param nelems Length of the sequence + * @param start Stating value of the sequence as Python object. Must be + * convertible to array data element type `Ty`. + * @param end End-value of the sequence as Python object. Must be convertible + * to array data element type `Ty`. + * @param include_endpoint Whether the end-value is included in the sequence + * @param array_data Kernel accessible USM pointer to the start of array to be + * populated. + * @param depends List of events to wait for before starting computations, if + * any. + * + * @return Event to wait on to ensure that computation completes. + * @defgroup CtorKernels + */ +template +sycl::event lin_space_affine_impl(sycl::queue &exec_q, + size_t nelems, + const py::object &start, + const py::object &end, + bool include_endpoint, + char *array_data, + const std::vector &depends) +{ + Ty start_v, end_v; + const auto &unboxer = PythonObjectUnboxer{}; + try { + start_v = unboxer(start); + end_v = unboxer(end); + } catch (const py::error_already_set &e) { + throw; + } + + using dpctl::tensor::kernels::constructors::lin_space_affine_impl; + + auto lin_space_affine_event = lin_space_affine_impl( + exec_q, nelems, start_v, end_v, include_endpoint, array_data, depends); -using dpctl::tensor::kernels::constructors::lin_space_affine_fn_ptr_t; + return lin_space_affine_event; +} + +using dpctl::utils::keep_args_alive; + +static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[td_ns::num_types]; static lin_space_affine_fn_ptr_t lin_space_affine_dispatch_vector[td_ns::num_types]; @@ -153,11 +262,36 @@ usm_ndarray_linear_sequence_affine(const py::object &start, linspace_affine_event); } +/*! + * @brief Factor to get function pointer of type `fnT` for array with elements + * of type `Ty`. + * @defgroup CtorKernels + */ +template struct LinSpaceStepFactory +{ + fnT get() + { + fnT f = lin_space_step_impl; + return f; + } +}; + +/*! + * @brief Factory to get function pointer of type `fnT` for array data type + * `Ty`. 
+ */ +template struct LinSpaceAffineFactory +{ + fnT get() + { + fnT f = lin_space_affine_impl; + return f; + } +}; + void init_linear_sequences_dispatch_vectors(void) { using namespace td_ns; - using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory; - using dpctl::tensor::kernels::constructors::LinSpaceStepFactory; DispatchVectorBuilder diff --git a/dpctl/tensor/libtensor/source/reductions/argmax.cpp b/dpctl/tensor/libtensor/source/reductions/argmax.cpp index 1d83bf9c2d..d3e2460081 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmax.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmax.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/argmin.cpp b/dpctl/tensor/libtensor/source/reductions/argmin.cpp index c6469e6864..57d0a9ccd2 100644 --- a/dpctl/tensor/libtensor/source/reductions/argmin.cpp +++ b/dpctl/tensor/libtensor/source/reductions/argmin.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp index e3b015a4e0..4936edd17f 100644 --- a/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp +++ b/dpctl/tensor/libtensor/source/reductions/logsumexp.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/max.cpp b/dpctl/tensor/libtensor/source/reductions/max.cpp index 32c60b943b..0166857039 100644 --- a/dpctl/tensor/libtensor/source/reductions/max.cpp +++ b/dpctl/tensor/libtensor/source/reductions/max.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/min.cpp b/dpctl/tensor/libtensor/source/reductions/min.cpp index de1a81387d..f36cff2bcf 100644 --- a/dpctl/tensor/libtensor/source/reductions/min.cpp +++ b/dpctl/tensor/libtensor/source/reductions/min.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index a90d04304a..66c8bb35be 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp index c7313930b4..e7e80cc680 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp +++ 
b/dpctl/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp index 2478545efe..a6d7a274fb 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp @@ -23,8 +23,8 @@ //===--------------------------------------------------------------------===// #pragma once -#include #include +#include #include #include "utils/type_utils.hpp" diff --git a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp index 5aafe38a40..e6da120821 100644 --- a/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp +++ b/dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -25,9 +25,9 @@ #pragma once -#include #include #include +#include #include #include #include diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index 33803cfd7b..81130e9abd 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include "kernels/reductions.hpp" diff --git a/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp b/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp index 62b30ccb8f..bda9227b71 100644 --- a/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp +++ b/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp @@ -41,7 +41,7 @@ template struct ExtendedRealFPLess /* [R, nan] */ bool operator()(const fpT v1, const fpT v2) const { - return (!sycl::isnan(v1) && (sycl::isnan(v2) || (v1 < v2))); + return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2))); } }; @@ -49,7 +49,7 @@ template struct ExtendedRealFPGreater { bool operator()(const fpT v1, const fpT v2) const { - return (!sycl::isnan(v2) && (sycl::isnan(v1) || (v2 < v1))); + return (!std::isnan(v2) && (std::isnan(v1) || (v2 < v1))); } }; @@ -64,14 +64,14 @@ template struct ExtendedComplexFPLess const realT real1 = std::real(v1); const realT real2 = std::real(v2); - const bool r1_nan = sycl::isnan(real1); - const bool r2_nan = sycl::isnan(real2); + const bool r1_nan = std::isnan(real1); + const bool r2_nan = std::isnan(real2); const realT imag1 = std::imag(v1); const realT imag2 = std::imag(v2); - const bool i1_nan = sycl::isnan(imag1); - const bool i2_nan = sycl::isnan(imag2); + const bool i1_nan = std::isnan(imag1); + const bool i2_nan = std::isnan(imag2); const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0); const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 
1 : 0); diff --git a/dpctl/tensor/libtensor/source/tensor_ctors.cpp b/dpctl/tensor/libtensor/source/tensor_ctors.cpp index be2b20c18d..0f1f7a81fc 100644 --- a/dpctl/tensor/libtensor/source/tensor_ctors.cpp +++ b/dpctl/tensor/libtensor/source/tensor_ctors.cpp @@ -46,6 +46,7 @@ #include "eye_ctor.hpp" #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "linear_sequences.hpp" #include "repeat.hpp" #include "simplify_iteration_space.hpp" @@ -56,6 +57,8 @@ namespace py = pybind11; +static_assert(std::is_same_v); + namespace { diff --git a/dpctl/tensor/libtensor/source/tensor_linalg.cpp b/dpctl/tensor/libtensor/source/tensor_linalg.cpp new file mode 100644 index 0000000000..82c9893c08 --- /dev/null +++ b/dpctl/tensor/libtensor/source/tensor_linalg.cpp @@ -0,0 +1,34 @@ +//===-- tensor_linalg.cpp ---*-C++-*-/===// +// Implementation of _tensor_linalg_impl module +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include "linalg_functions/dot.hpp" +#include + +namespace py = pybind11; + +PYBIND11_MODULE(_tensor_linalg_impl, m) +{ + dpctl::tensor::py_internal::init_dot(m); +} diff --git a/dpctl/tensor/libtensor/source/unboxing_helper.hpp b/dpctl/tensor/libtensor/source/unboxing_helper.hpp new file mode 100644 index 0000000000..d7082c3e13 --- /dev/null +++ b/dpctl/tensor/libtensor/source/unboxing_helper.hpp @@ -0,0 +1,53 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template struct PythonObjectUnboxer +{ + T operator()(const py::object &o) const + { + if constexpr (std::is_same_v) { + float tmp = py::cast(o); + return static_cast(tmp); + } + else { + return py::cast(o); + } + } +}; + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py index ca296911a0..a82b23c445 100644 --- a/dpctl/tests/elementwise/test_abs.py +++ b/dpctl/tests/elementwise/test_abs.py @@ -70,9 +70,8 @@ def test_abs_usm_type(usm_type): assert np.allclose(dpt.asnumpy(Y), expected_Y) -def test_abs_types_prop(): - types = dpt.abs.types_ - assert types is None +def test_abs_types_property(): + get_queue_or_skip() types = dpt.abs.types assert isinstance(types, list) assert len(types) > 0 diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index accfbd8032..1d605248c5 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -259,8 +259,7 @@ def __sycl_usm_array_interface__(self): def test_add_types_property(): - types = dpt.add.types_ - assert types is None + get_queue_or_skip() types = dpt.add.types assert isinstance(types, list) assert len(types) > 0 diff --git a/dpctl/tests/elementwise/test_elementwise_classes.py b/dpctl/tests/elementwise/test_elementwise_classes.py index 634b1fbdea..c8680078b7 100644 --- a/dpctl/tests/elementwise/test_elementwise_classes.py +++ b/dpctl/tests/elementwise/test_elementwise_classes.py @@ -15,6 +15,7 @@ # limitations under the License. 
import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip unary_fn = dpt.negative binary_fn = dpt.divide @@ -29,6 +30,7 @@ def test_unary_class_getters(): def test_unary_class_types_property(): + get_queue_or_skip() loop_types = unary_fn.types assert isinstance(loop_types, list) assert len(loop_types) > 0 @@ -62,6 +64,7 @@ def test_binary_class_getters(): def test_binary_class_types_property(): + get_queue_or_skip() loop_types = binary_fn.types assert isinstance(loop_types, list) assert len(loop_types) > 0 diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index 25f986806c..593363c28d 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -241,7 +241,10 @@ def test_can_cast_device(): def test_acceptance_fns(): """Check type promotion acceptance functions""" - dev = dpctl.SyclDevice() + try: + dev = dpctl.SyclDevice() + except dpctl.SyclDeviceCreationError: + pytest.skip("Default device is not available") assert tu._acceptance_fn_reciprocal( dpt.float32, dpt.float32, dpt.float32, dev ) diff --git a/dpctl/tests/test_service.py b/dpctl/tests/test_service.py index f16ff51227..328e1371a0 100644 --- a/dpctl/tests/test_service.py +++ b/dpctl/tests/test_service.py @@ -173,13 +173,30 @@ def test_syclinterface(): raise RuntimeError("Unsupported system") +def test_main_include_dir(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "--include-dir"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + dir_path = res.stdout.decode("utf-8").strip() + assert os.path.exists(dir_path) + + def test_main_includes(): res = subprocess.run( [sys.executable, "-m", "dpctl", "--includes"], capture_output=True ) assert res.returncode == 0 assert res.stdout - assert res.stdout.decode("utf-8").startswith("-I") + flags = res.stdout.decode("utf-8") + res = subprocess.run( + [sys.executable, "-m", "dpctl", "--include-dir"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + dir = res.stdout.decode("utf-8") + assert flags == "-I " + dir def test_main_library(): @@ -191,6 +208,34 @@ def test_main_library(): assert res.stdout.decode("utf-8").startswith("-L") +def test_tensor_includes(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "--tensor-includes"], + capture_output=True, + ) + assert res.returncode == 0 + assert res.stdout + flags = res.stdout.decode("utf-8") + res = subprocess.run( + [sys.executable, "-m", "dpctl", "--tensor-include-dir"], + capture_output=True, + ) + assert res.returncode == 0 + assert res.stdout + dir = res.stdout.decode("utf-8") + assert flags == "-I " + dir + + +def test_main_library_dir(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "--library-dir"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + dir_path = res.stdout.decode("utf-8").strip() + assert os.path.exists(dir_path) + + def test_cmakedir(): res = subprocess.run( [sys.executable, "-m", "dpctl", "--cmakedir"], capture_output=True @@ -198,7 +243,7 @@ def test_cmakedir(): assert res.returncode == 0 assert res.stdout cmake_dir = res.stdout.decode("utf-8").strip() - assert os.path.exists(os.path.join(cmake_dir, "FindDpctl.cmake")) + assert os.path.exists(os.path.join(cmake_dir, "dpctl-config.cmake")) def test_main_full_list(): diff --git a/dpctl/tests/test_tensor_array_api_inspection.py b/dpctl/tests/test_tensor_array_api_inspection.py index 5ae0d35f8e..fc2495fe37 100644 --- 
a/dpctl/tests/test_tensor_array_api_inspection.py +++ b/dpctl/tests/test_tensor_array_api_inspection.py @@ -50,22 +50,29 @@ def __init__(self, fp16: bool, fp64: bool): def test_array_api_inspection_methods(): info = dpt.__array_namespace_info__() assert info.capabilities() - assert info.default_device() + try: + assert info.default_device() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") assert info.default_dtypes() assert info.devices() assert info.dtypes() def test_array_api_inspection_default_device(): - assert ( - dpt.__array_namespace_info__().default_device() - == dpctl.select_default_device() - ) + try: + dev = dpctl.select_default_device() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") + assert dpt.__array_namespace_info__().default_device() == dev def test_array_api_inspection_devices(): + try: + devices2 = dpctl.get_devices() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") devices1 = dpt.__array_namespace_info__().devices() - devices2 = dpctl.get_devices() assert len(devices1) == len(devices2) assert devices1 == devices2 @@ -77,7 +84,10 @@ def test_array_api_inspection_capabilities(): def test_array_api_inspection_default_dtypes(): - dev = dpctl.select_default_device() + try: + dev = dpctl.select_default_device() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") int_dt = default_device_int_type(dev) ind_dt = default_device_index_type(dev) @@ -107,7 +117,10 @@ def test_array_api_inspection_default_dtypes(): def test_array_api_inspection_default_device_dtypes(): - dev = dpctl.select_default_device() + try: + dev = dpctl.select_default_device() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") dtypes = _dtypes_no_fp16_fp64.copy() if dev.has_aspect_fp64: dtypes["float64"] = dpt.float64 @@ -128,6 +141,10 @@ def test_array_api_inspection_device_dtypes(fp16, fp64): def test_array_api_inspection_dtype_kind(): info = dpt.__array_namespace_info__() + try: + info.default_device() + except dpctl.SyclDeviceCreationError: + pytest.skip("No default device available") f_dtypes = info.dtypes(kind="real floating") assert all([_dt[1].kind == "f" for _dt in f_dtypes.items()]) diff --git a/dpctl/tests/test_tensor_asarray.py b/dpctl/tests/test_tensor_asarray.py index e73c35ce26..8278bd34eb 100644 --- a/dpctl/tests/test_tensor_asarray.py +++ b/dpctl/tests/test_tensor_asarray.py @@ -164,6 +164,21 @@ def test_asarray_input_validation(): with pytest.raises(ValueError): # sequence is not rectangular dpt.asarray([[1], 2]) + with pytest.raises(OverflowError): + # Python int too large for type + dpt.asarray(-9223372036854775809, dtype="i4") + with pytest.raises(ValueError): + # buffer to usm_ndarray requires a copy + dpt.asarray(memoryview(np.arange(5)), copy=False) + with pytest.raises(ValueError): + # Numpy array to usm_ndarray requires a copy + dpt.asarray(np.arange(5), copy=False) + with pytest.raises(ValueError): + # Python sequence to usm_ndarray requires a copy + dpt.asarray([1, 2, 3], copy=False) + with pytest.raises(ValueError): + # Python scalar to usm_ndarray requires a copy + dpt.asarray(5, copy=False) def test_asarray_input_validation2(): diff --git a/dpctl/tests/test_tensor_clip.py b/dpctl/tests/test_tensor_clip.py index 39ba35a4a1..11c93ecf1f 100644 --- a/dpctl/tests/test_tensor_clip.py +++ b/dpctl/tests/test_tensor_clip.py @@ -21,7 +21,12 @@ import dpctl import dpctl.tensor as dpt -from dpctl.tensor._type_utils 
import _can_cast +from dpctl.tensor._elementwise_common import _get_dtype +from dpctl.tensor._type_utils import ( + _can_cast, + _strong_dtype_num_kind, + _weak_type_num_kind, +) from dpctl.utils import ExecutionPlacementError _all_dtypes = [ @@ -194,6 +199,15 @@ def test_clip_out_need_temporary(): dpt.clip(x[:6], 2, 3, out=x[-6:]) assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2) + x = dpt.arange(12, dtype="i4") + dpt.clip(x[:6], out=x[-6:]) + expected = dpt.arange(6, dtype="i4") + assert dpt.all(x[:-6] == expected) and dpt.all(x[-6:] == expected) + + x = dpt.ones(10, dtype="i4") + dpt.clip(x, out=x) + assert dpt.all(x == 1) + x = dpt.full(6, 3, dtype="i4") a_min = dpt.full(10, 2, dtype="i4") a_max = dpt.asarray(4, dtype="i4") @@ -227,6 +241,21 @@ def test_clip_arg_validation(): with pytest.raises(TypeError): dpt.clip(check, x1, x2) + with pytest.raises(ValueError): + dpt.clip(x1, check, x2) + + with pytest.raises(ValueError): + dpt.clip(x1, check) + + with pytest.raises(TypeError): + dpt.clip(x1, x1, x2, out=check) + + with pytest.raises(TypeError): + dpt.clip(x1, x2, out=check) + + with pytest.raises(TypeError): + dpt.clip(x1, out=check) + @pytest.mark.parametrize( "dt1,dt2", [("i4", "i4"), ("i4", "i2"), ("i2", "i4"), ("i1", "i2")] @@ -599,22 +628,40 @@ def test_clip_max_less_than_min(): assert dpt.all(res == 0) -def test_clip_minmax_weak_types(): +@pytest.mark.parametrize("dt", ["?", "i4", "f4", "c8"]) +def test_clip_minmax_weak_types(dt): get_queue_or_skip() - x = dpt.zeros(10, dtype=dpt.bool) + x = dpt.zeros(10, dtype=dt) min_list = [False, 0, 0.0, 0.0 + 0.0j] max_list = [True, 1, 1.0, 1.0 + 0.0j] + for min_v, max_v in zip(min_list, max_list): - if isinstance(min_v, bool) and isinstance(max_v, bool): - y = dpt.clip(x, min_v, max_v) - assert isinstance(y, dpt.usm_ndarray) + st_dt = _strong_dtype_num_kind(dpt.dtype(dt)) + wk_dt1 = _weak_type_num_kind(_get_dtype(min_v, x.sycl_device)) + wk_dt2 = _weak_type_num_kind(_get_dtype(max_v, x.sycl_device)) + + if st_dt >= wk_dt1 and st_dt >= wk_dt2: + r = dpt.clip(x, min_v, max_v) + assert isinstance(r, dpt.usm_ndarray) else: with pytest.raises(ValueError): dpt.clip(x, min_v, max_v) + if st_dt >= wk_dt1: + r = dpt.clip(x, min_v) + assert isinstance(r, dpt.usm_ndarray) + + r = dpt.clip(x, None, min_v) + assert isinstance(r, dpt.usm_ndarray) + else: + with pytest.raises(ValueError): + dpt.clip(x, min_v) + with pytest.raises(ValueError): + dpt.clip(x, None, max_v) + -def test_clip_max_weak_types(): +def test_clip_max_weak_type_errors(): get_queue_or_skip() x = dpt.zeros(10, dtype="i4") @@ -626,6 +673,15 @@ def test_clip_max_weak_types(): with pytest.raises(ValueError): dpt.clip(x, 2.5, m) + with pytest.raises(ValueError): + dpt.clip(x, 2.5) + + with pytest.raises(ValueError): + dpt.clip(dpt.astype(x, "?"), 2) + + with pytest.raises(ValueError): + dpt.clip(dpt.astype(x, "f4"), complex(2)) + def test_clip_unaligned(): get_queue_or_skip() @@ -636,3 +692,59 @@ def test_clip_unaligned(): expected = dpt.full(512, 2, dtype="i4") assert dpt.all(dpt.clip(x[1:], a_min, a_max) == expected) + + +def test_clip_none_args(): + get_queue_or_skip() + + x = dpt.arange(10, dtype="i4") + r = dpt.clip(x) + assert dpt.all(x == r) + + +def test_clip_shape_errors(): + get_queue_or_skip() + + x = dpt.ones((4, 4), dtype="i4") + a_min = dpt.ones(5, dtype="i4") + a_max = dpt.ones(5, dtype="i4") + + with pytest.raises(ValueError): + dpt.clip(x, a_min, a_max) + + with pytest.raises(ValueError): + dpt.clip(x, a_min) + + with pytest.raises(ValueError): + dpt.clip(x, 0, 1, 
out=a_min) + + with pytest.raises(ValueError): + dpt.clip(x, 0, out=a_min) + + with pytest.raises(ValueError): + dpt.clip(x, out=a_min) + + +def test_clip_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = get_queue_or_skip() + + x = dpt.ones(10, dtype="i4", sycl_queue=q1) + a_min = dpt.ones(10, dtype="i4", sycl_queue=q2) + a_max = dpt.ones(10, dtype="i4", sycl_queue=q1) + res = dpt.empty_like(x, sycl_queue=q2) + + with pytest.raises(ExecutionPlacementError): + dpt.clip(x, a_min, a_max) + + with pytest.raises(ExecutionPlacementError): + dpt.clip(x, dpt.ones_like(x), a_max, out=res) + + with pytest.raises(ExecutionPlacementError): + dpt.clip(x, a_min) + + with pytest.raises(ExecutionPlacementError): + dpt.clip(x, None, a_max, out=res) + + with pytest.raises(ExecutionPlacementError): + dpt.clip(x, out=res) diff --git a/dpctl/tests/test_tensor_statistical_functions.py b/dpctl/tests/test_tensor_statistical_functions.py index 8916833f86..bb341601a2 100644 --- a/dpctl/tests/test_tensor_statistical_functions.py +++ b/dpctl/tests/test_tensor_statistical_functions.py @@ -234,6 +234,7 @@ def test_stat_function_errors(): with pytest.raises(TypeError): dpt.mean(d) + get_queue_or_skip() x = dpt.empty(1, dtype="f4") with pytest.raises(TypeError): dpt.var(x, axis=d) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7227e687af..7c5765332b 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -430,6 +430,7 @@ def test_ctor_invalid_shape(): def test_ctor_invalid_order(): + get_queue_or_skip() with pytest.raises(ValueError): dpt.usm_ndarray((5, 5, 3), order="Z") @@ -1035,6 +1036,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type): def test_setitem_broadcasting(): + "See gh-1503" get_queue_or_skip() dst = dpt.ones((2, 3, 4), dtype="u4") src = dpt.zeros((3, 1), dtype=dst.dtype) @@ -1043,6 +1045,16 @@ def test_setitem_broadcasting(): assert np.array_equal(dpt.asnumpy(dst), expected) +def test_setitem_broadcasting_offset(): + get_queue_or_skip() + dt = dpt.int32 + x = dpt.asarray([[1, 2, 3], [6, 7, 8]], dtype=dt) + y = dpt.asarray([4, 5], dtype=dt) + x[0] = y[1] + expected = dpt.asarray([[5, 5, 5], [6, 7, 8]], dtype=dt) + assert dpt.all(x == expected) + + def test_setitem_broadcasting_empty_dst_validation(): "Broadcasting rules apply, except exception" get_queue_or_skip() @@ -1301,6 +1313,20 @@ def test_astype_invalid_order(): dpt.astype(X, "i4", order="WRONG") +def test_astype_device(): + get_queue_or_skip() + q1 = dpctl.SyclQueue() + q2 = dpctl.SyclQueue() + + x = dpt.arange(5, dtype="i4", sycl_queue=q1) + r = dpt.astype(x, "f4") + assert r.sycl_queue == x.sycl_queue + assert r.sycl_device == x.sycl_device + + r = dpt.astype(x, "f4", device=q2) + assert r.sycl_queue == q2 + + def test_copy(): try: X = dpt.usm_ndarray((5, 5), "i4")[2:4, 1:4] diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 9183226be2..5da827602c 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1302,7 +1302,7 @@ def test_nonzero(): def test_nonzero_f_contig(): "See gh-1370" - get_queue_or_skip + get_queue_or_skip() mask = dpt.zeros((5, 5), dtype="?", order="F") mask[2, 3] = True @@ -1319,7 +1319,7 @@ def test_nonzero_compacting(): Test with input where dimensionality of iteration space is compacted from 3d to 2d """ - get_queue_or_skip + get_queue_or_skip() mask = dpt.zeros((5, 5, 5), dtype="?", order="F") mask[3, 2, 1] = True 
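
Note on the clip changes covered by the tests above: dpctl.tensor.clip now treats both bounds as optional, applies weak-type rules to Python scalar bounds, and allows out= to alias or overlap the input. A minimal usage sketch of that behaviour, assuming a default SYCL device is available (the array values are illustrative only):

    import dpctl.tensor as dpt

    x = dpt.arange(10, dtype="i4")

    # Both bounds omitted: clip degenerates to a copy of x.
    r0 = dpt.clip(x)

    # Only a minimum bound; pass None for the minimum to give only a maximum.
    r1 = dpt.clip(x, 2)
    r2 = dpt.clip(x, None, 7)

    # Python ints are weak types, so integer bounds work on an integer array,
    # while a float bound such as 2.5 raises ValueError (see the tests above).
    r3 = dpt.clip(x, 0, 5)

    # out= may overlap or even alias the input.
    dpt.clip(x, 0, 5, out=x)
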
diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 4023eb8ad7..b4d0ae96f1 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -14,10 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools + +import numpy as np import pytest +import dpctl import dpctl.tensor as dpt -from dpctl.tests.helper import get_queue_or_skip +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError + +_numeric_types = [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] def test_matrix_transpose(): @@ -46,3 +67,781 @@ def test_matrix_transpose_arg_validation(): X = dpt.empty((5, 5), dtype="i4") assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_simple(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n, m = 235, 17 + m1 = dpt.ones((m, n), dtype=dtype) + m2 = dpt.ones((n, m), dtype=dtype) + + for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: + r = dpt.matmul(m1[:k, :], m2[:, :k]) + assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent1(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 77 + N_mat = dpt.eye(n, k=1, dtype=dtype) + I_mat = dpt.eye(n, dtype=dtype) + R_mat = dpt.eye(n, dtype=dtype) + for _ in range(n + 1): + R_mat = I_mat + dpt.matmul(N_mat, R_mat) + + assert dpt.allclose(dpt.matmul(I_mat - N_mat, R_mat), I_mat) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent2(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 128 + u = dpt.ones((n, 1), dtype=dtype) + v = dpt.ones((1, n), dtype=dtype) + + uv = dpt.matmul(u, v) + uv_ref = u * v + + assert dpt.allclose(uv, uv_ref) + + +def test_matmul_null_axis(): + get_queue_or_skip() + n = 3 + + A_mat = dpt.ones((n, 0), dtype="f4") + B_mat = dpt.ones((0, 1), dtype="f4") + + R_mat = dpt.matmul(A_mat, B_mat) + assert R_mat.shape == (n, 1) + + R_mat = dpt.matmul(A_mat, B_mat[:, :0]) + assert R_mat.shape == (n, 0) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_matmul_dims(dtype): + get_queue_or_skip() + + n, m, k, b = 4, 5, 7, 3 + v = dpt.ones(k, dtype=dtype) + m1 = dpt.ones((n, k), dtype=dtype) + m2 = dpt.ones((k, m), dtype=dtype) + st1 = dpt.ones((b, n, k), dtype=dtype) + st2 = dpt.ones((b, k, m), dtype=dtype) + + r = dpt.matmul(v, v) + assert r.shape == tuple() + assert dpt.round(r) == k + + r = dpt.matmul(m1, v) + assert r.shape == (n,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, m2) + assert r.shape == (m,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, m2) + assert r.shape == ( + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, st2) + assert r.shape == ( + b, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, v) + assert r.shape == ( + b, + n, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, m2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + +def 
test_matmul_arg_validation(): + get_queue_or_skip() + + s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple()) + v1, v2 = dpt.ones(16), dpt.zeros(16) + + with pytest.raises(ValueError): + dpt.matmul(s1, v2) + + with pytest.raises(ValueError): + dpt.matmul(v1, s2) + + with pytest.raises(TypeError): + dpt.matmul(dict(), v2) + + with pytest.raises(TypeError): + dpt.matmul(v2, None) + + +def test_matmul_dims_validation(): + get_queue_or_skip() + + m1 = dpt.ones((16, 16)) + m2 = dpt.ones((16, 16)) + + # contraction dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1[:, :7], m2[:3, :]) + + m1 = dpt.ones((3, 4, 5)) + m2 = dpt.ones((2, 5, 3)) + # broadcasting dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1, m2) + + +def test_matmul_broadcasting(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.uint32), + ]: + m1 = dpt.ones((7, 11, 16), dtype=dt1) + m2 = dpt.ones((16, 13), dtype=dt2) + + r = dpt.matmul(m1, m2[dpt.newaxis, ...]) + + assert r.shape == (7, 11, 13) + + +@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) +def test_matmul_strided(dtype): + get_queue_or_skip() + + m1_shape = (14, 22, 32) + m1_size = 1 + for el in m1_shape: + m1_size = m1_size * el + + m1 = dpt.remainder(dpt.arange(1, m1_size + 1, dtype="i8"), 13) + m1_orig = dpt.reshape(dpt.astype(m1, dtype), m1_shape) + m2_orig = dpt.ones((14, 16, 13), dtype=dtype) + + m1 = m1_orig[::2, ::-2, ::2] + m2 = m2_orig[::2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::2, ::2, ::-2] + m2 = m2_orig[::2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::-2, ::2, ::2] + m2 = m2_orig[::-2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + +def test_matmul_out(): + get_queue_or_skip() + + m1 = ( + dpt.arange(14, dtype="f4")[:, dpt.newaxis, dpt.newaxis] + + dpt.arange(17, dtype="f4")[dpt.newaxis, :, dpt.newaxis] + + dpt.arange(128, dtype="f4")[dpt.newaxis, dpt.newaxis, :] + ) + assert m1.shape == (14, 17, 128) + m2 = dpt.tile( + dpt.reshape(dpt.asarray([1, 2], dtype="f4"), (2, 1, 1)), (7, 128, 13) + ) + assert m2.shape == (14, 128, 13) + + buf = dpt.zeros((2 * 14, 3 * 17, 13), dtype="f4") + res = dpt.matmul(m1, m2, out=buf[::-2, 1::3, :]) + + assert dpt.allclose(res, buf[::-2, 1::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 0::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 2::3, :]) + + m1_np = dpt.asnumpy(m1) + ref = np.matmul(m1_np, dpt.asnumpy(m2)) + assert np.allclose(ref, dpt.asnumpy(res)) + + res = dpt.matmul(m1[:, :10, :10], m1[:, :10, :10].mT, out=m1[:, :10, :10]) + ref = np.matmul( + m1_np[:, :10, :10], np.transpose(m1_np[:, :10, :10], (0, 2, 1)) + ) + assert np.allclose(ref, dpt.asnumpy(res)) + + +def test_matmul_dtype(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int32, dpt.int16), + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.float32), + ]: + m1 = dpt.ones((10, 10), dtype=dt1) + m2 = dpt.ones((10, 10), dtype=dt2) + + for ord in ["C", "A", "F", "K"]: + r = dpt.matmul(m1, m2, dtype=dpt.float32, order=ord) + assert r.dtype == dpt.float32 + + 
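
The matmul tests above pin down the keywords accepted by dpctl.tensor.matmul (out=, dtype=, order=) and its broadcasting over stacked batch dimensions. A short sketch of the behaviour those tests rely on, assuming a default SYCL device is available (shapes and dtypes here are illustrative, not taken from the test data):

    import dpctl.tensor as dpt

    a = dpt.ones((7, 11, 16), dtype=dpt.float32)   # batch of 7 stacked matrices
    b = dpt.ones((16, 13), dtype=dpt.int16)        # broadcast against the batch

    r = dpt.matmul(a, b)
    assert r.shape == (7, 11, 13)
    assert r.dtype == dpt.result_type(a, b)        # float32 for this pair

    # dtype= requests the output data type, order= the memory layout of the
    # newly allocated result; out= writes into an existing array instead.
    r2 = dpt.matmul(a, b, dtype=dpt.float32, order="F")

    out = dpt.zeros((7, 11, 13), dtype=r.dtype)
    dpt.matmul(a, b, out=out)
    assert dpt.all(out == 16)                      # inner dimension is 16
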
+@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +@pytest.mark.parametrize("order", ["C", "K"]) +def test_matmul_type_promotion(dt1, dt2, order): + get_queue_or_skip() + + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + b, n, k, m = 8, 10, 17, 10 + m1 = dpt.ones((1, n, k), dtype=dt1) + m2 = dpt.ones((b, k, m), dtype=dt2) + expected_dt = dpt.result_type(m1, m2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((b, n, k), dtype=dt1) + m2 = dpt.ones((1, k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((n, k), dtype=dt1) + m2 = dpt.ones((k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (n, m) + assert r.dtype == expected_dt + + +def test_matmul_invalid_dtype(): + get_queue_or_skip() + + m1 = dpt.zeros((10, 10), dtype="f4") + m2 = dpt.zeros((10, 10), dtype="f4") + m3 = dpt.zeros((10, 10), dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m3, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m3, m1, dtype="i4") + + +def test_matmul_out_errors(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = (10, 10) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + + with pytest.raises(TypeError): + dpt.matmul(m1, m2, out=dict()) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty((10,), dtype=dt, sycl_queue=q1)) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype="f4", sycl_queue=q1)) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype=dt, sycl_queue=q2)) + + +def test_matmul_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + sh2 = tuple(2 * dim for dim in sh) + n = sh[-1] + + for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]): + ar1 = dpt.ones(sh, dtype=dt1, order="C") + ar2 = dpt.ones(sh, dtype=dt2, order="C") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt1, order="F") + ar2 = dpt.ones(sh, dtype=dt2, order="F") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2] + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2] + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (n, -1) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2].mT + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2].mT + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (-1, n) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + +def test_matmul_invalid_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + dt = "i4" + + ar1 = dpt.ones(sh, dtype=dt, order="C") + ar2 = dpt.ones(sh, dtype=dt, order="C") + r = dpt.matmul(ar1, ar2, 
order="invalid") + assert r.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt, order="F") + ar2 = dpt.ones(sh, dtype=dt, order="F") + r = dpt.matmul(ar1, ar2, order="invalid") + assert r.flags.f_contiguous + + +def test_matmul_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = ( + 10, + 10, + ) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = dpt.zeros(sh, dtype=dt, sycl_queue=q2) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2) + + +def test_matmul_inplace_broadcasting(): + get_queue_or_skip() + + sh = (3, 5, 5) + dt = "i4" + + m1 = dpt.ones((3, 5, 5), dtype=dt) + m2 = dpt.ones((1, 5, 5), dtype=dt) + m1 @= m2 + assert dpt.all(m1 == dpt.full(sh, 5, dtype=dt)) + + +def test_matmul_prepend_dims(): + get_queue_or_skip() + + n = 5 + for dt1, dt2 in [ + (dpt.int32, dpt.int32), + (dpt.int32, dpt.int64), + (dpt.int64, dpt.int32), + (dpt.int32, dpt.uint32), + ]: + m = dpt.ones((n, 4), dtype=dt1) + v = dpt.ones((4,), dtype=dt2) + r = dpt.matmul(m, v) + assert r.shape == (n,) + + r = dpt.matmul(v, m.mT) + assert r.shape == (n,) + + +def test_matmul_inplace_same_tensors(): + get_queue_or_skip() + + n = 5 + sh = ( + n, + n, + ) + + ar1 = dpt.ones(sh, dtype="i4") + ar1 @= ar1 + assert dpt.all(ar1 == dpt.full(sh, n, dtype="i4")) + + ar1 = dpt.ones(sh, dtype="i8") + ar2 = dpt.ones(sh, dtype="i4") + dpt.matmul(ar1, ar2, out=ar1) + assert dpt.all(ar1 == dpt.full(sh, n, dtype=ar1.dtype)) + + ar1 = dpt.ones(sh, dtype="i4") + ar2 = dpt.ones(sh, dtype="i8") + dpt.matmul(ar1, ar2, out=ar2) + assert dpt.all(ar2 == dpt.full(sh, n, dtype=ar2.dtype)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_outer(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 12), dtype=dtype) + + r = dpt.tensordot(t1, t2, axes=0) + assert r.shape == t1.shape + t2.shape + assert dpt.allclose(r, dpt.ones_like(r)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_inner(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 8), dtype=dtype) + + r = dpt.tensordot(t1, t2.mT, axes=1) + assert r.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(r, dpt.full_like(r, fill_value=t1.shape[1])) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_double(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((2, 4, 8), dtype=dtype) + t2 = dpt.ones((3, 4, 8), dtype=dtype) + + r = dpt.tensordot(t1, dpt.permute_dims(t2, (1, 2, 0)), axes=2) + assert r.shape == t1.shape[:1] + t2.shape[:1] + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + assert dpt.allclose(r, dpt.full_like(r, fill_value=expected)) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_tensordot_axes_sequence(dtype): + get_queue_or_skip() + + r = 4 + t1 = dpt.ones((2, 2, 4, 3), dtype=dtype) + t2 = dpt.ones((3, 2, 4, 3), dtype=dtype) + + assert len(t1.shape) == r + assert len(t2.shape) == r + + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + ps1 = itertools.permutations(range(r)) + ps2 = itertools.permutations(range(r)) + + for p1 in ps1: + assert len(p1) == r + inv_p1 = sorted(range(r), key=p1.__getitem__) + u1 = dpt.permute_dims(t1, p1) + x1_axes = inv_p1[1:] + for p2 in ps2: + inv_p2 = sorted(range(r), key=p2.__getitem__) + u2 = dpt.permute_dims(t2, p2) + x2_axes = inv_p2[1:] + + tdr = dpt.tensordot(u1, u2, axes=(x1_axes, x2_axes)) 
+ assert tdr.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(tdr, dpt.full_like(tdr, fill_value=expected)) + + +def test_tensordot_validation(): + get_queue_or_skip() + + with pytest.raises(TypeError): + dpt.tensordot(dict(), dict()) + + t1 = dpt.empty((10, 10, 10)) + with pytest.raises(TypeError): + dpt.tensordot(t1, dict()) + + t2 = dpt.empty((10, 10, 10)) + q = dpctl.SyclQueue(t2.sycl_context, t2.sycl_device, property="in_order") + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.tensordot(t1, t2.to_device(q)) + + invalid_axes = ( + 1, + 2, + 3, + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = 5.2 + with pytest.raises(TypeError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = ( + (1,), + ( + 0, + 2, + ), + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + with pytest.raises(ValueError): + dpt.tensordot(t1[..., :5], t2) + + +def test_tensordot_promotion(): + get_queue_or_skip() + + t1 = dpt.zeros((10, 10), dtype="i4") + t2 = dpt.zeros((10, 10), dtype="i8") + + r1 = dpt.tensordot(t1, t2) + assert r1.dtype == t2.dtype + + r2 = dpt.tensordot(t2, t1) + assert r2.dtype == t2.dtype + + t3 = dpt.zeros((10, 10), dtype="u4") + r3 = dpt.tensordot(t1, t3) + assert r3.dtype == dpt.result_type(t1, t3) + + +def test_tensordot_axes_errors(): + get_queue_or_skip() + + m1 = dpt.zeros((10, 10), dtype="i4") + m2 = dpt.zeros((10, 10), dtype="i4") + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=-1) + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((-1,), (1,))) + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((1,), (-1,))) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_1d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 511 + v1 = dpt.ones(n, dtype=dtype) + + v2 = dpt.ones(n, dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r == n + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_3d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, m2, n), dtype=dtype) + + v2 = dpt.ones((m1, m2, n), dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_axis(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, n, m2), dtype=dtype) + + v2 = dpt.ones((m1, n, m2), dtype=dtype) + + r = dpt.vecdot(v1, v2, axis=1) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + list1 = [1, 0, 2, 0] + pattern1 = dpt.asarray(list1, dtype=dtype) + n_padded1 = pattern1.size * (1 + ((n - 1) // pattern1.size)) + v1 = dpt.tile(dpt.reshape(pattern1, (1, -1, 1)), (m1, n_padded1, m2))[ + ::-1, :n, : + ] + + list2 = [1, 2, 1, 2] + pattern2 = dpt.asarray(list2, dtype=dtype) + n_padded2 = pattern2.size * (1 + ((n - 1) // pattern2.size)) + v2 = dpt.tile(dpt.reshape(pattern2, (1, -1, 1)), (m1, n_padded2, m2))[ + :, :n, ::-1 + ] + + r = dpt.vecdot(v1, v2, axis=1) + + ref = sum( + el1 * el2 + for el1, el2 in zip((list1 * n_padded1)[:n], (list2 * n_padded1)[:n]) + ) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == ref) + + +def 
+def test_vector_arg_validation():
+    get_queue_or_skip()
+
+    s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple())
+    v1, v2 = dpt.ones(16), dpt.zeros(16)
+
+    with pytest.raises(ValueError):
+        dpt.vecdot(s1, v2)
+
+    with pytest.raises(ValueError):
+        dpt.vecdot(v1, s2)
+
+    with pytest.raises(TypeError):
+        dpt.vecdot(dict(), v2)
+
+    with pytest.raises(TypeError):
+        dpt.vecdot(v2, None)
+
+    with pytest.raises(ValueError):
+        dpt.vecdot(v1[:5], v2[:4])
+
+    with pytest.raises(ValueError):
+        dpt.vecdot(v1, v2, axis=2)
+
+    q = dpctl.SyclQueue(
+        v2.sycl_context, v2.sycl_device, property="enable_profiling"
+    )
+    with pytest.raises(dpctl.utils.ExecutionPlacementError):
+        dpt.vecdot(v1, v2.to_device(q))
+
+    m1 = dpt.empty((10, 5))
+    m2 = dpt.empty((5, 5))
+    with pytest.raises(ValueError):
+        dpt.vecdot(m1, m2, axis=-1)
+
+
+def test_vecdot_broadcast():
+    get_queue_or_skip()
+
+    for dt1, dt2 in [
+        (dpt.int32, dpt.int32),
+        (dpt.int32, dpt.int64),
+        (dpt.int64, dpt.int32),
+        (dpt.int32, dpt.uint32),
+    ]:
+        m1 = dpt.zeros((1, 5), dtype=dt1)
+        m2 = dpt.zeros((5, 5), dtype=dt2)
+        r1 = dpt.vecdot(m1, m2, axis=-1)
+        r2 = dpt.vecdot(m2, m1, axis=-1)
+        assert r1.shape == r2.shape
+
+
+@pytest.mark.parametrize("dt1", _numeric_types)
+@pytest.mark.parametrize("dt2", _numeric_types)
+def test_vecdot_type_promotion(dt1, dt2):
+    q = get_queue_or_skip()
+    skip_if_dtype_not_supported(dt1, q)
+    skip_if_dtype_not_supported(dt2, q)
+
+    v1 = dpt.ones(128, dtype=dt1)
+    v2 = dpt.ones(128, dtype=dt2)
+
+    r = dpt.vecdot(v1, v2)
+    mul = v1 * v2
+    assert r.shape == tuple()
+    assert r.dtype == mul.dtype
+    assert dpt.allclose(r, dpt.sum(mul, dtype=mul.dtype))
diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py
index e76ac667e1..4c6240bb12 100644
--- a/dpctl/tests/test_usm_ndarray_sorting.py
+++ b/dpctl/tests/test_usm_ndarray_sorting.py
@@ -97,6 +97,7 @@ def test_sort_2d(dtype):
 
 
 def test_sort_strides():
+    get_queue_or_skip()
 
     fl = dpt.roll(
         dpt.concat((dpt.ones(10000, dtype="i4"), dpt.zeros(10000, dtype="i4"))),
diff --git a/dpctl/tests/test_usm_ndarray_unique.py b/dpctl/tests/test_usm_ndarray_unique.py
index 0d4247def5..1095836996 100644
--- a/dpctl/tests/test_usm_ndarray_unique.py
+++ b/dpctl/tests/test_usm_ndarray_unique.py
@@ -295,10 +295,29 @@ def test_set_function_outputs():
 def test_set_functions_compute_follows_data():
     # tests that all intermediate calls and allocations
     # are compatible with an input with an arbitrary queue
+    get_queue_or_skip()
     q = dpctl.SyclQueue()
     x = dpt.arange(10, dtype="i4", sycl_queue=q)
-    assert isinstance(dpt.unique_values(x), dpctl.tensor.usm_ndarray)
-    assert dpt.unique_counts(x)
-    assert dpt.unique_inverse(x)
-    assert dpt.unique_all(x)
+    uv = dpt.unique_values(x)
+    assert isinstance(uv, dpctl.tensor.usm_ndarray)
+    assert uv.sycl_queue == q
+    uv, uc = dpt.unique_counts(x)
+    assert isinstance(uv, dpctl.tensor.usm_ndarray)
+    assert isinstance(uc, dpctl.tensor.usm_ndarray)
+    assert uv.sycl_queue == q
+    assert uc.sycl_queue == q
+    uv, inv_ind = dpt.unique_inverse(x)
+    assert isinstance(uv, dpctl.tensor.usm_ndarray)
+    assert isinstance(inv_ind, dpctl.tensor.usm_ndarray)
+    assert uv.sycl_queue == q
+    assert inv_ind.sycl_queue == q
+    uv, ind, inv_ind, uc = dpt.unique_all(x)
+    assert isinstance(uv, dpctl.tensor.usm_ndarray)
+    assert isinstance(ind, dpctl.tensor.usm_ndarray)
+    assert isinstance(inv_ind, dpctl.tensor.usm_ndarray)
+    assert isinstance(uc, dpctl.tensor.usm_ndarray)
+    assert uv.sycl_queue == q
+    assert ind.sycl_queue == q
+    assert inv_ind.sycl_queue == q
+    assert uc.sycl_queue == q