gnu: Add qnnpack-pytorch.

This is an internal fork of QNNPACK in the PyTorch source tree.

* gnu/packages/machine-learning.scm (%python-pytorch-version): New variable.
(%python-pytorch-src): New variable.
(qnnpack-pytorch): New variable.
This commit is contained in:
David Elsing 2024-03-23 22:05:04 +00:00 committed by Ludovic Courtès
parent 0c313244e0
commit b77c772a3e
No known key found for this signature in database
GPG key ID: 090B11993D9AEBB5

View file

@ -4334,6 +4334,133 @@ (define-public ideep-pytorch
PyTorch.")
(license license:expat)))
(define %python-pytorch-version "2.2.1")
(define %python-pytorch-src
(origin
(method git-fetch)
(uri (git-reference
(url "https://github.com/pytorch/pytorch")
(commit (string-append "v" %python-pytorch-version))))
(file-name (git-file-name "python-pytorch" %python-pytorch-version))
(sha256
(base32
"03mm0pwwb5lxdsmmiw3cch9fijgjw81kmmc4ln9rlyazkm7l1r48"))
(modules '((guix build utils)))
(snippet
'(begin
;; Bundled or unused code
(for-each
(lambda (dir)
(when (file-exists? dir)
(delete-file-recursively dir)))
'("android"
"aten/src/ATen/native/cuda/cutlass_extensions"
"aten/src/ATen/native/quantized/cpu/qnnpack"
"caffe2/mobile/contrib/libopencl-stub"
"caffe2/mobile/contrib/libvulkan-stub"
"third_party"))
;; Autogenerated files
(for-each
delete-file
'("aten/src/ATen/nnapi/nnapi_wrapper.cpp"
"aten/src/ATen/nnapi/nnapi_wrapper.h"
"caffe2/mobile/contrib/ios/mpscnn/mpscnn_kernels.h"
"caffe2/proto/caffe2_legacy_pb2.pyi"
"caffe2/proto/caffe2_pb2.pyi"
"caffe2/proto/hsm_pb2.pyi"
"caffe2/proto/metanet_pb2.pyi"
"caffe2/proto/predictor_consts_pb2.pyi"
"caffe2/proto/prof_dag_pb2.pyi"
"caffe2/proto/torch_pb2.pyi"
;; These files contain just lists of floating point values and
;; might be as well hand-written.
;; "test/cpp/api/init_baseline.h"
;; "test/cpp/api/optim_baseline.h"
"test/mobile/test_upgrader_bytecode_table_example.cpp"
"torch/csrc/jit/mobile/upgrader_mobile.cpp"
"torch/csrc/jit/runtime/decomposition_registry_util.cpp"
"torch/csrc/jit/runtime/serialized_shape_function_registry.cpp"
"torch/csrc/jit/tensorexpr/external_functions_codegen.cpp"
"torch/csrc/jit/serialization/mobile_bytecode_generated.h"))
(delete-file-recursively ".github")
(for-each
(lambda (dir)
(for-each
delete-file
(find-files dir "\\.cu$")))
'("aten/src/ATen/native/transformers/cuda/flash_attn/kernels"
"aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels"))))))
(define-public qnnpack-pytorch
(package
(inherit qnnpack)
(name "qnnpack-pytorch")
(version (string-append "pytorch-" %python-pytorch-version))
(source
(origin
(inherit %python-pytorch-src)
(patches '())
(modules '((guix build utils)
(srfi srfi-26)
(ice-9 ftw)))
(snippet
'(begin
(rename-file "aten/src/ATen/native/quantized/cpu/qnnpack"
"../qnnpack")
(let ((outdir (getcwd)))
(chdir "..")
(rename-file outdir "dummy")
(rename-file "qnnpack" outdir)
(chdir outdir)
(delete-file-recursively "deps"))))))
(arguments
(substitute-keyword-arguments (package-arguments qnnpack)
((#:phases phases #~%standard-phases)
#~(modify-phases %standard-phases
(add-after 'unpack 'patch-cmake
(lambda _
(substitute* "CMakeLists.txt"
(("project\\(.*" orig)
(apply
string-append
orig "\n"
(map (lambda (name)
(string-append
"option(" name " \"\" ON)\n"))
'("USE_SYSTEM_CPUINFO" "USE_SYSTEM_FP16" "USE_SYSTEM_FXDIV"
"USE_SYSTEM_PSIMD" "USE_SYSTEM_PTHREADPOOL"))))
(("if.*SOURCE_DIR.*")
"if(FALSE)\n")
(("if\\(NOT TARGET (clog|gtest|benchmark).*")
"if(FALSE)\n")
(("target_link_libraries.*(fxdiv|psimd|fp16)\\).*")
"")
(("(target_link_libraries.*) fp16 (.*)" _ before after)
(string-append before " " after)))))
(add-after 'unpack 'fix-cstring-include
(lambda _
(substitute* "include/pack_block_sparse.h"
(("#include.*<vector>.*" orig)
(string-append orig "\n#include <cstring>\n")))))
(add-after 'install 'install-missing-headers
(lambda _
(for-each
(lambda (name)
(install-file (string-append "../source/include/" name)
(string-append #$output "/include")))
'("pack_block_sparse.h"
"pytorch_qnnpack.h"
"qnnpack_func.h"))
(copy-recursively
"../source/src/qnnpack"
(string-append #$output "/include/qnnpack"))))))
;; Some tests occasionally fail on i686 due to floating point rounding.
((#:tests? _ #t)
(not (string-prefix? "i686" (or (%current-target-system)
(%current-system)))))))))
;; Please also update python-torchvision when updating this package.
(define-public python-pytorch
(package