HalidePyTorchHelpers.h
#ifndef HL_PYTORCH_WRAPPER_H
#define HL_PYTORCH_WRAPPER_H

/** \file
 * Set of utility functions to wrap PyTorch tensors into Halide buffers,
 * making sure the data is on the correct device (CPU/GPU). This header
 * is included in each generated op by the PyTorch CodeGen.
 */

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "HalideBuffer.h"

#include "HalideRuntimeCuda.h"

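// Input-validation helpers for the generated ops (see the usage sketch at
// the end of this file).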
#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor on device " #dev)

namespace Halide {
namespace PyTorch {

using Halide::Runtime::Buffer;

inline std::vector<int> get_dims(const at::Tensor tensor) {
    int ndims = tensor.ndimension();
    std::vector<int> dims(ndims, 0);
    // PyTorch dim order is the reverse of Halide's
    for (int dim = 0; dim < ndims; ++dim) {
        dims[dim] = tensor.size(ndims - 1 - dim);
    }
    return dims;
}
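
// Example (hypothetical sizes): a PyTorch tensor of shape (N=2, C=3, H=4, W=5)
// yields Halide dims {5, 4, 3, 2}, i.e. {W, H, C, N}; Halide's fastest-varying
// dimension comes first.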

// Fallback: fires for any scalar type that has no specialization below.
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(), " not handled by Halide's PyTorch wrapper");
}

// TODO: PyTorch does not appear to expose its API version as a macro or
// variable (none found in source or documentation), so for now we sniff for
// this macro's existence to infer that we are building with v1.3+ (vs. 1.2).
#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#define HL_PYTORCH_API_VERSION 13
#else
#define HL_PYTORCH_API_VERSION 12
#endif

#if HL_PYTORCH_API_VERSION >= 13

// PyTorch 1.3+
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype)                                                     \
    template<>                                                                                   \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type does not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(HL_PT_DEFINE_TYPECHECK);
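// For example, HL_PT_DEFINE_TYPECHECK(float, Float) expands to a
// check_type<float> specialization asserting that the tensor's scalar type
// is at::ScalarType::Float.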

#undef HL_PT_DEFINE_TYPECHECK

#else  // HL_PYTORCH_API_VERSION < 13

// PyTorch 1.2
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)                                                 \
    template<>                                                                                   \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type does not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#endif  // HL_PYTORCH_API_VERSION check

template<class scalar_t>
inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    Buffer<scalar_t> buffer;

    // TODO(mgharbi): force Halide to put input/output on GPU?
    if (tensor.is_cuda()) {
        // Allocate a buffer of the right shape, then wrap the tensor's
        // existing CUDA device pointer without copying.
        buffer = Buffer<scalar_t>(dims);
        const halide_device_interface_t *cuda_interface = halide_cuda_device_interface();
        int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
        AT_ASSERTM(err == 0, "(CUDA) halide_device_wrap failed");
        // Mark the device copy as the valid one so Halide does not
        // clobber it with the (uninitialized) host allocation.
        buffer.set_device_dirty();
    } else {
        buffer = Buffer<scalar_t>(pData, dims);
    }

    return buffer;
}

}  // namespace PyTorch
}  // namespace Halide

#endif  // HL_PYTORCH_WRAPPER_H
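
As a usage sketch, here is how a generated op might call these helpers from a
PyTorch C++ extension. This is a minimal sketch, not actual CodeGen output:
`halide_op` and `run_op` are hypothetical names (the real entry point and
argument list come from the PyTorch CodeGen), and the tensors are assumed to
be contiguous float CPU tensors.

// Sketch only: `halide_op` stands in for an AOT-compiled Halide pipeline;
// generated headers declare functions of this shape.
#include <torch/extension.h>
#include "HalidePyTorchHelpers.h"

extern "C" int halide_op(halide_buffer_t *in, halide_buffer_t *out);

torch::Tensor run_op(torch::Tensor input) {
    HLPT_CHECK_CONTIGUOUS(input);
    torch::Tensor output = torch::empty_like(input);

    // wrap<float>() verifies the scalar type and reverses the dim order;
    // Halide::Runtime::Buffer converts implicitly to halide_buffer_t *.
    auto in_buf = Halide::PyTorch::wrap<float>(input);
    auto out_buf = Halide::PyTorch::wrap<float>(output);

    int err = halide_op(in_buf, out_buf);
    AT_ASSERTM(err == 0, "halide_op failed");
    return output;
}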