Deep Neural Network Library (DNNL) 1.90.1
Performance library for Deep Learning
dnnl.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
19 
20 #ifndef DNNL_HPP
21 #define DNNL_HPP
22 
23 #include "dnnl_config.h"
24 
26 #include <algorithm>
27 #include <cstdlib>
28 #include <iterator>
29 #include <memory>
30 #include <vector>
31 #include <unordered_map>
32 
33 #include "dnnl.h"
34 
35 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
36 #include <CL/cl.h>
37 #endif
38 
39 #if DNNL_WITH_SYCL
40 #include <CL/sycl.hpp>
41 #endif
42 
44 
45 namespace dnnl {
46 
49 
52 
57 struct error : public std::exception {
58  dnnl_status_t status;
59  const char *message;
60 
65  error(dnnl_status_t astatus, const char *amessage)
66  : status(astatus), message(amessage) {}
67 
69  const char *what() const noexcept override { return message; }
70 
76  static void wrap_c_api(dnnl_status_t status, const char *message) {
77  if (status != dnnl_success) throw error(status, message);
78  }
79 };
80 
/// Traits describing how a C API handle type is released.
///
/// The primary template is intentionally empty. Each C handle type wrapped
/// by `handle` provides a specialization exposing a static `destructor`,
/// which `handle::reset` installs as the deleter.
///
/// Declared with the `struct` class-key so that it matches every
/// specialization; mixing `class` and `struct` for the same template
/// triggers MSVC warning C4099 and clang's -Wmismatched-tags.
template <typename T>
struct handle_traits {};
84 
98 template <typename T, typename traits = handle_traits<T>>
99 class handle {
100 private:
101  static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
102 
103  std::shared_ptr<typename std::remove_pointer<T>::type> _data {0};
104 
105 protected:
106  bool operator==(const T other) const { return other == _data.get(); }
107  bool operator!=(const T other) const { return !(*this == other); }
108 
109 public:
119  handle() = default;
120  handle(const handle<T, traits> &) = default;
121  handle(handle<T, traits> &&) = default;
122  handle<T, traits> &operator=(handle<T, traits> &&) = default;
123  handle<T, traits> &operator=(const handle<T, traits> &) = default;
124 
128  explicit handle(T t, bool weak = false) { reset(t, weak); }
129 
133  void reset(T t, bool weak = false) {
134  _data.reset(t, weak ? &dummy_destructor : traits::destructor);
135  }
136 
138  T get(bool allow_emtpy = false) const {
139  T result = _data.get();
140 
141  if (allow_emtpy == false && result == nullptr)
143  "attempt to use uninitialized object");
144 
145  return result;
146  }
147 
148  explicit operator T() const { return get(true); }
149 
150  explicit operator bool() const { return get(true) != nullptr; }
151 
152  bool operator==(const handle &other) const {
153  return other._data.get() == _data.get();
154  }
155  bool operator!=(const handle &other) const { return !(*this == other); }
156 };
157 
159 template <>
161  static constexpr auto destructor = &dnnl_memory_destroy;
162 };
163 
164 template <>
165 struct handle_traits<dnnl_primitive_desc_t> {
166  static constexpr auto destructor = &dnnl_primitive_desc_destroy;
167 };
168 
169 template <>
170 struct handle_traits<dnnl_primitive_t> {
171  static constexpr auto destructor = &dnnl_primitive_destroy;
172 };
173 
174 template <>
175 struct handle_traits<dnnl_primitive_desc_iterator_t> {
176  static constexpr auto destructor = &dnnl_primitive_desc_iterator_destroy;
177 };
179 
180 struct stream;
181 struct error;
182 struct memory;
183 struct primitive_desc;
184 
/// Base class for computational primitives; owns a dnnl_primitive_t through
/// handle<dnnl_primitive_t> (constructors inherited via `using handle::handle`).
186 class primitive : public handle<dnnl_primitive_t> {
187  friend struct error;
188  friend struct stream;
189  using handle::handle;
190 
191 public:
 /// Primitive kinds; each enumerator is assigned the C API constant of the
 /// same name (converted back with convert_to_c(primitive::kind)).
 // NOTE(review): only sum, lrn and rnn survive in this listing; the other
 // enumerators of the original enum are missing here.
194  enum class kind {
204  sum = dnnl_sum,
216  lrn = dnnl_lrn,
224  rnn = dnnl_rnn,
227  };
228 
 /// Constructs a primitive from a C API primitive descriptor or from a C++
 /// primitive_desc (definitions are not part of this listing).
229  primitive(const_dnnl_primitive_desc_t c_pd);
230  primitive(const primitive_desc &pd);
231 
 // NOTE(review): original lines 232-233 are missing from this listing.
234  // TODO: use the C++ API wrapper structure.
235 
 /// Executes the primitive on @p astream using the memory objects in
 /// @p args, keyed by an int argument index (presumably DNNL_ARG_* values —
 /// confirm against the definition, which is not visible here).
236  void execute(
237  stream &astream, const std::unordered_map<int, memory> &args) const;
238 
 /// DPC++-only variant: additionally takes SYCL dependency events and
 /// returns a SYCL event.
239 #ifdef DNNL_SYCL_DPCPP
240  cl::sycl::event DNNL_API execute_sycl(stream &astream,
241  const std::unordered_map<int, memory> &args,
242  const std::vector<cl::sycl::event> &deps = {}) const;
243 #endif
244 };
245 
246 inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
247  return static_cast<dnnl_primitive_kind_t>(akind);
248 }
249 
253  "could not get primitive descriptor by primitive");
254  return pd;
255 }
257 
262 
/// Scratchpad mode, as read and written by primitive_attr's scratchpad-mode
/// accessors and converted with convert_to_c(scratchpad_mode).
// NOTE(review): the enumerators (original lines 265-268) are missing from
// this listing; as shown here the enum is empty.
264 enum class scratchpad_mode {
269 };
270 
271 inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
272  return static_cast<dnnl_scratchpad_mode_t>(mode);
273 }
274 
/// Propagation kind, converted to dnnl_prop_kind_t with convert_to_c(prop_kind).
// NOTE(review): the enumerators (original lines 277-299) are missing from
// this listing; as shown here the enum is empty.
276 enum class prop_kind {
300 };
301 
302 inline dnnl_prop_kind_t convert_to_c(prop_kind kind) {
303  return static_cast<dnnl_prop_kind_t>(kind);
304 }
305 
/// Algorithm kinds; each enumerator is assigned the corresponding C API
/// dnnl_alg_kind_t constant (converted back with convert_to_c(algorithm)).
// NOTE(review): only `undef` and `pooling_avg` survive in this listing; the
// remaining lines of the original enum are missing here.
307 enum class algorithm {
308  undef = dnnl_alg_kind_undef,
353  pooling_avg = dnnl_pooling_avg,
376 };
377 
378 inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
379  return static_cast<dnnl_alg_kind_t>(aalgorithm);
380 }
381 
/// Normalization flags (underlying type `unsigned`); bitwise operators for
/// this type are generated by DNNL_DEFINE_BITMASK_OPS further down.
// NOTE(review): the enumerators of the original enum are missing from this
// listing; as shown here the enum is empty.
383 enum class normalization_flags : unsigned {
396 
411 
420 };
421 
422 inline dnnl_normalization_flags_t convert_to_c(normalization_flags aflag) {
423  return static_cast<dnnl_normalization_flags_t>(aflag);
424 }
425 
426 enum class rnn_flags : unsigned { undef = dnnl_rnn_flags_undef };
427 
428 inline dnnl_rnn_flags_t convert_to_c(rnn_flags aflag) {
429  return static_cast<dnnl_rnn_flags_t>(aflag);
430 }
431 
// Generates |, &, ^, their compound-assignment forms, and ~ for a scoped
// enum whose values are manipulated as `unsigned` bit masks.
// The binary operators do the arithmetic on the unsigned representation and
// cast back; the compound assignments are written in terms of them.
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
    inline enum_name operator|(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator&(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator^(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs | rhs; \
    } \
\
    inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs & rhs; \
    } \
\
    inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs ^ rhs; \
    } \
\
    inline enum_name operator~(enum_name rhs) { \
        const unsigned bits = ~static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    }
469 
// Instantiate the bitwise operators for the two flag enums, then undefine
// the helper macro so its name is not visible after this point.
470 DNNL_DEFINE_BITMASK_OPS(normalization_flags)
471 DNNL_DEFINE_BITMASK_OPS(rnn_flags)
472 
473 #undef DNNL_DEFINE_BITMASK_OPS
474 
475 enum class rnn_direction {
476  unidirectional_left2right = dnnl_unidirectional_left2right,
477  unidirectional_right2left = dnnl_unidirectional_right2left,
478  unidirectional = dnnl_unidirectional,
479  bidirectional_concat = dnnl_bidirectional_concat,
480  bidirectional_sum = dnnl_bidirectional_sum,
481 };
482 
483 inline dnnl_rnn_direction_t convert_to_c(rnn_direction adir) {
484  return static_cast<dnnl_rnn_direction_t>(adir);
485 }
486 
/// Primitive-descriptor query kinds, converted to dnnl_query_t with
/// convert_to_c(query).
// NOTE(review): the enumerators are missing from this listing (only blank
// numbered lines survive), although other code refers to members such as
// query::engine, query::src_md, query::diff_src_md, query::scratchpad_md
// and query::scratchpad_engine.
494 enum class query {
497 
502 
507 
516 
521 
526 
529 
556 
573 };
574 
575 inline dnnl_query_t convert_to_c(query aquery) {
576  return static_cast<dnnl_query_t>(aquery);
577 }
578 
580 
586 
588 template <>
589 struct handle_traits<dnnl_post_ops_t> {
590  static constexpr auto destructor = &dnnl_post_ops_destroy;
591 };
593 
/// Sequence of post-operations (sum and eltwise entries are visible here),
/// owning a dnnl_post_ops_t through handle<dnnl_post_ops_t>.
// NOTE(review): several call lines (original 603, 613, 640, 647, 660, 669)
// are missing from this listing; comments below describe what is visible.
597 struct post_ops : public handle<dnnl_post_ops_t> {
599 
 /// Creates a new post-ops object through the C API and takes ownership of
 /// it (the creating call on original line 603 is missing here).
601  post_ops() {
602  dnnl_post_ops_t result;
604  "could not create post operation sequence");
605  reset(result);
606  }
607 
 /// @returns The number of entries, as reported by dnnl_post_ops_len.
609  int len() const { return dnnl_post_ops_len(get()); }
610 
 /// Returns the primitive kind of entry @p index. A range check (first line
 /// missing here) reports "post_ops index is out of range".
612  primitive::kind kind(int index) const {
614  "post_ops index is out of range");
615  return static_cast<primitive::kind>(
616  dnnl_post_ops_get_kind(get(), index));
617  }
618 
 /// Appends a sum entry with the given @p scale (default 1).
639  void append_sum(float scale = 1.) {
641  dnnl_post_ops_append_sum(get(), scale), "could not append sum");
642  }
643 
 /// Reads the scale of the sum entry at @p index into @p scale.
646  void get_params_sum(int index, float &scale) const {
648  "could not get sum params");
649  }
650 
 /// Appends an eltwise entry; @p alg is converted with convert_to_c and
 /// passed to the C API together with @p scale, @p alpha and @p beta.
659  void append_eltwise(float scale, algorithm alg, float alpha, float beta) {
661  get(), scale, convert_to_c(alg), alpha, beta),
662  "could not append eltwise");
663  }
664 
 /// Reads back the parameters of the eltwise entry at @p index; the C
 /// algorithm kind is cast back to the C++ `algorithm` enum.
666  void get_params_eltwise(int index, float &scale, algorithm &alg,
667  float &alpha, float &beta) const {
668  dnnl_alg_kind_t c_alg;
670  get(), index, &scale, &c_alg, &alpha, &beta),
671  "could not get eltwise params");
672  alg = static_cast<algorithm>(c_alg);
673  }
674 };
675 
677 template <>
678 struct handle_traits<dnnl_primitive_attr_t> {
679  static constexpr auto destructor = &dnnl_primitive_attr_destroy;
680 };
682 
/// Primitive attributes (scratchpad mode, output scales, post-ops, RNN
/// quantization parameters), owning a dnnl_primitive_attr_t.
// NOTE(review): several signature/call lines (original 692, 697-700,
// 703-707, 712-714, 725, 751, 760, 768, 781-782, 810) are missing from this
// listing.
686 struct primitive_attr : public handle<dnnl_primitive_attr_t> {
688 
 /// Creates a new attribute object through the C API and takes ownership.
690  primitive_attr() {
691  dnnl_primitive_attr_t result;
693  "could not create a primitive attr");
694  reset(result);
695  }
696 
 /// Wraps an existing C attribute handle (constructor head missing here);
 /// ownership follows handle's default, i.e. the attribute is released.
701  : handle<dnnl_primitive_attr_t>(attr) {}
702 
 /// Scratchpad-mode getter (signature missing here): reads the C value and
 /// returns it as a C++ scratchpad_mode.
705  dnnl_scratchpad_mode_t result;
708  "could not get scratchpad mode");
709  return scratchpad_mode(result);
710  }
711 
 /// Scratchpad-mode setter (signature missing here).
715  get(), dnnl::convert_to_c(mode)),
716  "could not set scratchpad mode");
717  }
718 
 /// Copies the output-scaling mask and factors out of the attribute;
 /// @p scales is resized to the count reported by the C API.
721  void get_output_scales(int &mask, std::vector<float> &scales) const {
722  dnnl_dim_t count;
723  int c_mask;
724  const float *c_scales;
726  get(), &count, &c_mask, &c_scales),
727  "could not get int output scales");
728  scales.resize(count);
729 
730  mask = c_mask;
731  for (dnnl_dim_t c = 0; c < count; ++c)
732  scales[c] = c_scales[c];
733  }
734 
 /// Sets output scaling factors and mask.
 // NOTE(review): &scales[0] is undefined behaviour when `scales` is empty;
 // scales.data() would be the safe spelling.
750  void set_output_scales(int mask, const std::vector<float> &scales) {
752  (dnnl_dim_t)scales.size(), mask, &scales[0]),
753  "could not set int output scales");
754  }
755 
 /// Returns the attribute's post-ops as a *weak* (non-owning) handle:
 /// reset(..., true) installs the no-op deleter, so the result must not
 /// outlive this attribute. Note that the default-constructed `result`
 /// first creates (and then, on reset, releases) a separate post-ops object.
757  const post_ops get_post_ops() const {
758  post_ops result;
759  const_dnnl_post_ops_t c_result;
761  "could not get post operation sequence");
762  result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
763  return result;
764  }
765 
 /// Installs @p ops as the attribute's post-ops (C call missing here).
767  void set_post_ops(post_ops ops) {
769  "could not set post operation sequence");
770  }
771 
 /// Sets the RNN data quantization scale and shift (C call missing here).
780  void set_rnn_data_qparams(float scale, float shift) {
783  "could not set rnn data int scale/shift");
784  }
785 
 /// Sets RNN weights quantization scales and mask.
 // NOTE(review): &scales[0] is undefined behaviour when `scales` is empty.
809  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
811  get(), (int)scales.size(), mask, &scales[0]),
812  "could not set rnn weights int scales");
813  }
814 };
815 
817 
823 
825 template <>
826 struct handle_traits<dnnl_engine_t> {
827  static constexpr auto destructor = &dnnl_engine_destroy;
828 };
830 
/// Execution engine (CPU or GPU), wrapping a dnnl_engine_t.
// NOTE(review): a few call lines (original 863, 898, 908, 915, 917, 925,
// 952) are missing from this listing.
832 struct engine : public handle<dnnl_engine_t> {
833  friend class primitive;
834  friend struct reorder;
835 
 /// Engine kinds; values are the C API dnnl_cpu / dnnl_gpu constants.
837  enum class kind {
841  cpu = dnnl_cpu,
843  gpu = dnnl_gpu,
844  };
845 
 /// Constructs an empty engine handle.
846  engine() = default;
847 
 /// @returns The number of engines of kind @p akind, as reported by
 /// dnnl_engine_get_count.
851  static size_t get_count(kind akind) {
852  return dnnl_engine_get_count(convert_to_c(akind));
853  }
854 
 /// Creates (and owns) the engine of kind @p akind with index @p index.
861  engine(kind akind, size_t index) {
862  dnnl_engine_t aengine;
864  dnnl_engine_create(&aengine, convert_to_c(akind), index),
865  "could not create an engine");
866  reset(aengine);
867  }
868 
 /// OpenCL builds only: creates an engine from an OpenCL device and context.
869 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
870  engine(kind akind, cl_device_id device, cl_context context) {
873  dnnl_engine_t aengine;
874  error::wrap_c_api(dnnl_engine_create_ocl(&aengine, convert_to_c(akind),
875  device, context),
876  "could not create an engine");
877  reset(aengine);
878  }
879 #endif
880 
 /// SYCL builds only: creates an engine from a SYCL device and context
 /// (defined out of line).
881 #if DNNL_WITH_SYCL
882  DNNL_API engine(kind akind, const cl::sycl::device &dev,
888  const cl::sycl::context &ctx);
889 #endif
890 
 /// Wraps an existing C engine as a *weak* handle (handle(aengine, true)):
 /// the engine is not destroyed by this object.
892  explicit engine(const dnnl_engine_t &aengine) : handle(aengine, true) {}
893 
 /// Queries the engine of primitive descriptor @p pd and wraps it weakly
 /// (reset(engine_q, true)); the result must not outlive @p pd's engine.
896  engine(const handle<dnnl_primitive_desc_t> &pd) {
897  dnnl_engine_t engine_q;
899  dnnl_primitive_desc_query(pd.get(),
900  dnnl::convert_to_c(dnnl::query::engine), 0, &engine_q),
901  "could not get engine from primitive_desc");
902  reset(engine_q, true);
903  }
904 
 /// @returns The kind of this engine, cast from the C API value.
906  kind get_kind() const {
907  dnnl_engine_kind_t akind;
909  "could not get the engine kind");
910  return static_cast<engine::kind>(akind);
911  }
912 
 /// OpenCL builds only: accessors for the underlying OpenCL context and
 /// device.
913 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
914  cl_context get_ocl_context() const {
916  cl_context context = nullptr;
918  "could not get a context handle");
919  return context;
920  }
921 
923  cl_device_id get_ocl_device() const {
924  cl_device_id device = nullptr;
926  "could not get a device handle");
927  return device;
928  }
929 #endif
930 
 /// SYCL builds only: accessors for the underlying SYCL context and device
 /// (defined out of line).
931 #if DNNL_WITH_SYCL
932  cl::sycl::context DNNL_API get_sycl_context() const;
934 
936  cl::sycl::device DNNL_API get_sycl_device() const;
937 #endif
938 
 /// Returns the engine of a primitive descriptor (query::engine), wrapped
 /// weakly via engine(const dnnl_engine_t &).
939  template <class primitive_desc>
940  static engine query(const primitive_desc &pd) {
941  return query(pd, dnnl::query::engine);
942  }
943 
944 private:
 // Engine-kind conversion local to this struct (C++ kind -> C kind).
945  static dnnl_engine_kind_t convert_to_c(kind akind) {
946  return static_cast<dnnl_engine_kind_t>(akind);
947  }
948 
 // Generic engine query on a primitive descriptor; returns a weak engine.
949  template <class primitive_desc>
950  static engine query(const primitive_desc &pd, dnnl::query what) {
951  dnnl_engine_t engine_q;
953  dnnl::convert_to_c(what), 0, &engine_q),
954  "could not get engine from primitive_desc");
955 
956  return engine(engine_q);
957  }
958 };
959 
961 
967 
969 template <>
970 struct handle_traits<dnnl_stream_t> {
971  static constexpr auto destructor = &dnnl_stream_destroy;
972 };
974 
/// Execution stream bound to an engine, owning a dnnl_stream_t.
976 struct stream : public handle<dnnl_stream_t> {
977  using handle::handle;
978 
 /// Stream flags (underlying type `unsigned`), passed to the C API as
 /// dnnl_stream_flags_t.
 // NOTE(review): the enumerators (original lines 981-989) are missing from
 // this listing, although `flags::default_flags` is used below.
980  enum class flags : unsigned {
990  };
991 
 /// Constructs an empty stream handle.
992  stream() = default;
993 
 /// Creates (and owns) a stream on @p aengine with flags @p aflags.
995  stream(const engine &aengine, flags aflags = flags::default_flags) {
996  dnnl_stream_t astream;
997  error::wrap_c_api(dnnl_stream_create(&astream, aengine.get(),
998  static_cast<dnnl_stream_flags_t>(aflags)),
999  "could not create a stream");
1000  reset(astream);
1001  }
1002 
 /// OpenCL builds only: creates a stream from an OpenCL command queue and
 /// exposes the queue back (the query call on original line 1016 is missing).
1003 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
1004  stream(const engine &eng, cl_command_queue queue) {
1007  dnnl_stream_t astream;
1008  error::wrap_c_api(dnnl_stream_create_ocl(&astream, eng.get(), queue),
1009  "could not create a stream");
1010  reset(astream);
1011  }
1012 
1014  cl_command_queue get_ocl_command_queue() const {
1015  cl_command_queue queue = nullptr;
1017  "could not get OpenCL command queue");
1018  return queue;
1019  }
1020 #endif
1021 
 /// SYCL builds only: creates a stream from a SYCL queue and exposes it
 /// back (defined out of line).
1022 #if DNNL_WITH_SYCL
1023  DNNL_API stream(const engine &eng, cl::sycl::queue &aqueue);
1028 
1030  cl::sycl::queue DNNL_API get_sycl_queue() const;
1031 #endif
1032 
 /// Calls dnnl_stream_wait on this stream and returns *this for chaining.
1034  stream &wait() {
1035  error::wrap_c_api(dnnl_stream_wait(get()), "could not wait a stream");
1036  return *this;
1037  }
1038 };
1039 
1040 inline stream::flags operator|(stream::flags lhs, stream::flags rhs) {
1041  return static_cast<stream::flags>(
1042  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
1043 }
1044 
1045 inline stream::flags operator&(stream::flags lhs, stream::flags rhs) {
1046  return static_cast<stream::flags>(
1047  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs));
1048 }
1049 
1050 inline stream::flags operator^(stream::flags lhs, stream::flags rhs) {
1051  return static_cast<stream::flags>(
1052  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs));
1053 }
1054 
1055 inline stream::flags operator~(stream::flags rhs) {
1056  return static_cast<stream::flags>(~static_cast<unsigned>(rhs));
1057 }
1058 
1060 
1063 
1069 
/// Memory object owning a dnnl_memory_t. The nested types describe
/// dimensions (`dim`, `dims`), element data types, format kinds, format
/// tags, and the memory descriptor (`desc`).
// NOTE(review): numerous original lines (mostly documentation, but also some
// code such as wrap_c_api call heads and enumerators) are missing from this
// listing.
1071 struct memory : public handle<dnnl_memory_t> {
1072  typedef dnnl_dim_t dim;
1073  typedef std::vector<dim> dims;
1074 
 /// Throws error(dnnl_invalid_arguments, "invalid dimensions") when @p v has
 /// more than DNNL_MAX_NDIMS entries.
1075  template <typename T>
1076  static void validate_dims(const std::vector<T> &v) {
1077  if (v.size() > DNNL_MAX_NDIMS)
1078  throw error(dnnl_invalid_arguments, "invalid dimensions");
1079  }
1080 
 /// Element data types; values are the C API dnnl_data_type_t constants.
 // NOTE(review): the first enumerator lines (original 1083-1085) are missing
 // from this listing.
1082  enum class data_type {
1086  f16 = dnnl_f16,
1088  bf16 = dnnl_bf16,
1090  f32 = dnnl_f32,
1092  s32 = dnnl_s32,
1094  s8 = dnnl_s8,
1096  u8 = dnnl_u8,
1097  };
1098 
 /// Memory format kinds.
 // NOTE(review): the enumerators (original 1101-1113) are missing from this
 // listing.
1100  enum class format_kind {
1114  };
1115 
 /// Memory format tags; values are the C API dnnl_format_tag_t constants
 /// (converted with memory::convert_to_c(format_tag)).
 // NOTE(review): some enumerator lines (e.g. original 1119-1123, 1183,
 // 1187-1188, 1261-1262, 1269-1270, 1272-1273, 1275-1276) are missing here.
1118  enum class format_tag {
1124 
1125  // Semantic agnostic section
1126  // The physical order of dimensions is defined by the permutation of the
1127  // characters, assuming that ab..z defines the natural order.
1128 
1129  // Plain formats
1130 
1131  a = dnnl_a,
1132  ab = dnnl_ab,
1133  abc = dnnl_abc,
1134  abcd = dnnl_abcd,
1135  abcde = dnnl_abcde,
1136  abcdef = dnnl_abcdef,
1137 
1138  // Permuted plain formats
1139 
1140  abdec = dnnl_abdec,
1141  acb = dnnl_acb,
1142  acbde = dnnl_acbde,
1143  acdb = dnnl_acdb,
1144  acdeb = dnnl_acdeb,
1145  ba = dnnl_ba,
1146  bac = dnnl_bac,
1147  bacd = dnnl_bacd,
1148  bcda = dnnl_bcda,
1149  cba = dnnl_cba,
1150  cdba = dnnl_cdba,
1151  cdeba = dnnl_cdeba,
1152  decab = dnnl_decab,
1153 
1154  // Opaque blocked formats
1155 
1156  Abc16a = dnnl_Abc16a,
1157  ABc16a16b = dnnl_ABc16a16b,
1158  aBc16b = dnnl_aBc16b,
1159  ABc16b16a = dnnl_ABc16b16a,
1160  Abc4a = dnnl_Abc4a,
1161  aBc4b = dnnl_aBc4b,
1162  ABc4b16a4b = dnnl_ABc4b16a4b,
1163  ABc4b4a = dnnl_ABc4b4a,
1164  ABc8a16b2a = dnnl_ABc8a16b2a,
1165  ABc8a8b = dnnl_ABc8a8b,
1166  aBc8b = dnnl_aBc8b,
1167  ABc8b16a2b = dnnl_ABc8b16a2b,
1168  ABc8b8a = dnnl_ABc8b8a,
1169  Abcd16a = dnnl_Abcd16a,
1170  ABcd16a16b = dnnl_ABcd16a16b,
1171  aBcd16b = dnnl_aBcd16b,
1172  ABcd16b16a = dnnl_ABcd16b16a,
1173  aBCd16b16c = dnnl_aBCd16b16c,
1174  aBCd16c16b = dnnl_aBCd16c16b,
1175  Abcd4a = dnnl_Abcd4a,
1176  aBcd4b = dnnl_aBcd4b,
1177  ABcd4b16a4b = dnnl_ABcd4b16a4b,
1178  ABcd4b4a = dnnl_ABcd4b4a,
1179  aBCd4c16b4c = dnnl_aBCd4c16b4c,
1180  aBCd4c4b = dnnl_aBCd4c4b,
1181  ABcd8a16b2a = dnnl_ABcd8a16b2a,
1182  ABcd8a8b = dnnl_ABcd8a8b,
1184  aBcd8b = dnnl_aBcd8b,
1185  ABcd8b16a2b = dnnl_ABcd8b16a2b,
1186  aBCd8b16c2b = dnnl_aBCd8b16c2b,
1189  aBCd8b8c = dnnl_aBCd8b8c,
1190  aBCd8c16b2c = dnnl_aBCd8c16b2c,
1191  aBCd8c8b = dnnl_aBCd8c8b,
1192  Abcde16a = dnnl_Abcde16a,
1193  ABcde16a16b = dnnl_ABcde16a16b,
1194  aBcde16b = dnnl_aBcde16b,
1195  ABcde16b16a = dnnl_ABcde16b16a,
1196  aBCde16b16c = dnnl_aBCde16b16c,
1197  aBCde16c16b = dnnl_aBCde16c16b,
1198  aBCde2c8b4c = dnnl_aBCde2c8b4c,
1199  Abcde4a = dnnl_Abcde4a,
1200  aBcde4b = dnnl_aBcde4b,
1201  ABcde4b4a = dnnl_ABcde4b4a,
1202  aBCde4b4c = dnnl_aBCde4b4c,
1203  aBCde4c16b4c = dnnl_aBCde4c16b4c,
1204  aBCde4c4b = dnnl_aBCde4c4b,
1205  Abcde8a = dnnl_Abcde8a,
1206  ABcde8a8b = dnnl_ABcde8a8b,
1207  aBcde8b = dnnl_aBcde8b,
1208  ABcde8b16a2b = dnnl_ABcde8b16a2b,
1209  aBCde8b16c2b = dnnl_aBCde8b16c2b,
1210  ABcde8b8a = dnnl_ABcde8b8a,
1211  aBCde8b8c = dnnl_aBCde8b8c,
1212  ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1213  ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1214  aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1215  aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1216  aBCde8c16b2c = dnnl_aBCde8c16b2c,
1217  aBCde8c8b = dnnl_aBCde8c8b,
1218  aBcdef16b = dnnl_aBcdef16b,
1219  aBCdef16b16c = dnnl_aBCdef16b16c,
1220  aBCdef16c16b = dnnl_aBCdef16c16b,
1221  aBcdef4b = dnnl_aBcdef4b,
1222  aBCdef4c4b = dnnl_aBCdef4c4b,
1223  aBCdef8b8c = dnnl_aBCdef8b8c,
1224  aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1225  aBCdef8c8b = dnnl_aBCdef8c8b,
1226  aBdc16b = dnnl_aBdc16b,
1227  aBdc4b = dnnl_aBdc4b,
1228  aBdc8b = dnnl_aBdc8b,
1229  aBdec16b = dnnl_aBdec16b,
1230  aBdec4b = dnnl_aBdec4b,
1231  aBdec8b = dnnl_aBdec8b,
1232  aBdefc16b = dnnl_aBdefc16b,
1233  aCBdef16c16b = dnnl_aCBdef16c16b,
1234  aBdefc4b = dnnl_aBdefc4b,
1235  aBdefc8b = dnnl_aBdefc8b,
1236  Acb16a = dnnl_Acb16a,
1237  Acb4a = dnnl_Acb4a,
1238  Acb8a = dnnl_Acb8a,
1239  aCBd16b16c = dnnl_aCBd16b16c,
1240  aCBd16c16b = dnnl_aCBd16c16b,
1241  aCBde16b16c = dnnl_aCBde16b16c,
1242  aCBde16c16b = dnnl_aCBde16c16b,
1243  Acdb16a = dnnl_Acdb16a,
1244  Acdb4a = dnnl_Acdb4a,
1245  Acdb8a = dnnl_Acdb8a,
1246  Acdeb16a = dnnl_Acdeb16a,
1247  Acdeb4a = dnnl_Acdeb4a,
1248  Acdeb8a = dnnl_Acdeb8a,
1249  BAc16a16b = dnnl_BAc16a16b,
1250  BAc16b16a = dnnl_BAc16b16a,
1251  BAcd16a16b = dnnl_BAcd16a16b,
1252  BAcd16b16a = dnnl_BAcd16b16a,
1253  ABcd32a32b = dnnl_ABcd32a32b,
 // NOTE(review): the enumerator name below looks truncated: it maps to
 // dnnl_BAcde16b16a, so `BAcde16b16a` was probably intended. Renaming would
 // break existing users; consider adding `BAcde16b16a` as an alias instead.
1254  BAcde16b16 = dnnl_BAcde16b16a,
1255  aBdec32b = dnnl_aBdec32b,
1256  Abcdef16a = dnnl_Abcdef16a,
1257  Acdb32a = dnnl_Acdb32a,
1258  format_tag_last = dnnl_format_tag_last,
1259 
1260  x = dnnl_x,
1263  nc = dnnl_nc,
1264  cn = dnnl_cn,
1265  tn = dnnl_tn,
1266  nt = dnnl_nt,
1267  ncw = dnnl_ncw,
1268  nwc = dnnl_nwc,
1271  nchw = dnnl_nchw,
1274  nhwc = dnnl_nhwc,
1277  chwn = dnnl_chwn,
1278  ncdhw = dnnl_ncdhw,
1279  ndhwc = dnnl_ndhwc,
1280  oi = dnnl_oi,
1281  io = dnnl_io,
1282  oiw = dnnl_oiw,
1283  wio = dnnl_wio,
1284  oihw = dnnl_oihw,
1285  hwio = dnnl_hwio,
1286  ihwo = dnnl_ihwo,
1287  iohw = dnnl_iohw,
1288  oidhw = dnnl_oidhw,
1289  dhwio = dnnl_dhwio,
1290  goiw = dnnl_goiw,
1291  goihw = dnnl_goihw,
1292  hwigo = dnnl_hwigo,
1293  giohw = dnnl_giohw,
1294  goidhw = dnnl_goidhw,
1295  tnc = dnnl_tnc,
1296  ntc = dnnl_ntc,
1297  ldnc = dnnl_ldnc,
1298  ldigo = dnnl_ldigo,
1299  ldgoi = dnnl_ldgoi,
1300  ldgo = dnnl_ldgo,
1301  nCdhw16c = dnnl_nCdhw16c,
1302  nCdhw4c = dnnl_nCdhw4c,
1303  nCdhw8c = dnnl_nCdhw8c,
1304  nChw16c = dnnl_nChw16c,
1305  nChw4c = dnnl_nChw4c,
1306  nChw8c = dnnl_nChw8c,
1307  nCw16c = dnnl_nCw16c,
1308  nCw4c = dnnl_nCw4c,
1309  nCw8c = dnnl_nCw8c,
1310  NCw16n16c = dnnl_NCw16n16c,
1311  NChw16n16c = dnnl_NChw16n16c,
1312  NCdhw16n16c = dnnl_NCdhw16n16c,
1313  NChw32n32c = dnnl_NChw32n32c,
1314  IOhw16i16o = dnnl_IOhw16i16o,
1315  Ohwi32o = dnnl_Ohwi32o,
1316  IOdhw16i16o = dnnl_IOdhw16i16o,
1317  gIOhw16i16o = dnnl_gIOhw16i16o,
1318  gOhwi32o = dnnl_gOhwi32o,
1319  Goidhw16g = dnnl_Goidhw16g,
1320  IOw16o16i = dnnl_IOw16o16i,
1321  OIw16i16o = dnnl_OIw16i16o,
1322  IOw16i16o = dnnl_IOw16i16o,
1323  gIOw16i16o = dnnl_gIOw16i16o,
1324  OIw16o16i = dnnl_OIw16o16i,
1325  Oiw16o = dnnl_Oiw16o,
1326  OIw4i16o4i = dnnl_OIw4i16o4i,
1327  OIw4i4o = dnnl_OIw4i4o,
1328  Oiw4o = dnnl_Oiw4o,
1329  OIw8i16o2i = dnnl_OIw8i16o2i,
1330  OIw8i8o = dnnl_OIw8i8o,
1331  OIw8o16i2o = dnnl_OIw8o16i2o,
1332  OIw8o8i = dnnl_OIw8o8i,
1333  Owi16o = dnnl_Owi16o,
1334  Owi4o = dnnl_Owi4o,
1335  Owi8o = dnnl_Owi8o,
1336  IOhw16o16i = dnnl_IOhw16o16i,
1337  Ohwi16o = dnnl_Ohwi16o,
1338  Ohwi4o = dnnl_Ohwi4o,
1339  Ohwi8o = dnnl_Ohwi8o,
1340  OIhw16i16o = dnnl_OIhw16i16o,
1341  OIhw16o16i = dnnl_OIhw16o16i,
1342  Oihw16o = dnnl_Oihw16o,
1343  OIhw4i16o4i = dnnl_OIhw4i16o4i,
1344  OIhw4i4o = dnnl_OIhw4i4o,
1345  Oihw4o = dnnl_Oihw4o,
1346  OIhw8i16o2i = dnnl_OIhw8i16o2i,
1347  OIhw8i8o = dnnl_OIhw8i8o,
1348  OIhw8o16i2o = dnnl_OIhw8o16i2o,
1349  OIhw8o8i = dnnl_OIhw8o8i,
1350  Odhwi16o = dnnl_Odhwi16o,
1351  Odhwi4o = dnnl_Odhwi4o,
1352  Odhwi8o = dnnl_Odhwi8o,
1353  OIdhw16i16o = dnnl_OIdhw16i16o,
1354  OIdhw16o16i = dnnl_OIdhw16o16i,
1355  Oidhw16o = dnnl_Oidhw16o,
1356  OIdhw4i4o = dnnl_OIdhw4i4o,
1357  Oidhw4o = dnnl_Oidhw4o,
1358  OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1359  OIdhw8i8o = dnnl_OIdhw8i8o,
1360  OIdhw8o8i = dnnl_OIdhw8o8i,
1361  gIOw16o16i = dnnl_gIOw16o16i,
1362  gOIw16i16o = dnnl_gOIw16i16o,
1363  gOIw16o16i = dnnl_gOIw16o16i,
1364  gOiw16o = dnnl_gOiw16o,
1365  gOIw4i16o4i = dnnl_gOIw4i16o4i,
1366  gOIw4i4o = dnnl_gOIw4i4o,
1367  gOiw4o = dnnl_gOiw4o,
1368  gOIw8i16o2i = dnnl_gOIw8i16o2i,
1369  gOIw8i8o = dnnl_gOIw8i8o,
1370  gOIw8o16i2o = dnnl_gOIw8o16i2o,
1371  gOIw8o8i = dnnl_gOIw8o8i,
1372  gOwi16o = dnnl_gOwi16o,
1373  gOwi4o = dnnl_gOwi4o,
1374  gOwi8o = dnnl_gOwi8o,
1375  gIOhw16o16i = dnnl_gIOhw16o16i,
1376  gOhwi16o = dnnl_gOhwi16o,
1377  gOhwi4o = dnnl_gOhwi4o,
1378  gOhwi8o = dnnl_gOhwi8o,
1379  Goihw16g = dnnl_Goihw16g,
1380  gOIhw16i16o = dnnl_gOIhw16i16o,
1381  gOIhw16o16i = dnnl_gOIhw16o16i,
1382  gOihw16o = dnnl_gOihw16o,
1383  gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1384  gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1385  gOIhw4i4o = dnnl_gOIhw4i4o,
1386  gOIhw4o4i = dnnl_gOIhw4o4i,
1387  gOihw4o = dnnl_gOihw4o,
1388  Goihw8g = dnnl_Goihw8g,
1389  gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1390  gOIhw8i8o = dnnl_gOIhw8i8o,
1391  gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1392  OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1393  OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1394  gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1395  gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1396  gOIhw8o8i = dnnl_gOIhw8o8i,
1397  gIOdhw16i16o = dnnl_gIOdhw16i16o,
1398  gOdhwi16o = dnnl_gOdhwi16o,
1399  gOdhwi4o = dnnl_gOdhwi4o,
1400  gOdhwi8o = dnnl_gOdhwi8o,
1401  gOIdhw16i16o = dnnl_gOIdhw16i16o,
1402  gOIdhw16o16i = dnnl_gOIdhw16o16i,
1403  gOidhw16o = dnnl_gOidhw16o,
1404  gOIdhw4i4o = dnnl_gOIdhw4i4o,
1405  gOidhw4o = dnnl_gOidhw4o,
1406  gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1407  gOIdhw8i8o = dnnl_gOIdhw8i8o,
1408  gOIdhw8o8i = dnnl_gOIdhw8o8i,
1409  };
1410 
 /// Memory descriptor wrapping a C dnnl_memory_desc_t named `data`.
 // NOTE(review): the declaration of `data` (original lines 1414-1415) is
 // missing from this listing, as are the `error::wrap_c_api(` heads on
 // original lines 1427, 1442, 1461 and 1470.
1412  struct desc {
1413  friend struct memory;
1416 
 /// Value-initializes the C descriptor (ndims == 0, see is_zero()).
1418  desc() : data() {}
1419 
 /// Initializes from dimensions, data type and format tag via
 /// dnnl_memory_desc_init_by_tag; empty @p adims are passed as nullptr.
 /// @throws error when @p adims exceeds DNNL_MAX_NDIMS or the C call fails.
1425  desc(const dims &adims, data_type adata_type, format_tag aformat_tag) {
1426  validate_dims(adims);
1428  dnnl_memory_desc_init_by_tag(&data, (int)adims.size(),
1429  adims.size() == 0 ? nullptr : &adims[0],
1430  convert_to_c(adata_type),
1431  convert_to_c(aformat_tag)),
1432  "could not initialize a memory descriptor by tag");
1433  }
1434 
 /// Initializes from dimensions, data type and explicit strides via
 /// dnnl_memory_desc_init_by_strides; empty vectors are passed as nullptr.
 // NOTE(review): @p astrides is not checked against adims.size() here.
1440  desc(const dims &adims, data_type adata_type, const dims &astrides) {
1441  validate_dims(adims);
1443  dnnl_memory_desc_init_by_strides(&data, (int)adims.size(),
1444  adims.size() == 0 ? nullptr : &adims[0],
1445  convert_to_c(adata_type),
1446  astrides.size() == 0 ? nullptr : &astrides[0]),
1447  "could not initialize a memory descriptor by strides");
1448  }
1449 
 /// Copies an existing C memory descriptor.
1453  desc(const dnnl_memory_desc_t &adata) : data(adata) {}
1454 
1456  //
 /// Returns a descriptor for the sub-memory with dimensions @p adims at
 /// @p offsets.
 // NOTE(review): &adims[0] / &offsets[0] is undefined behaviour for empty
 // vectors, and adims is not passed through validate_dims. The method does
 // not modify *this and could be const.
1459  desc submemory_desc(const dims &adims, const dims &offsets) {
1460  dnnl_memory_desc_t sub_md;
1462  &sub_md, &data, &adims[0], &offsets[0]),
1463  "could not initialize a sub-memory");
1464  return desc(sub_md);
1465  }
1466 
 /// Returns a copy of this descriptor reshaped to @p adims.
 // NOTE(review): &adims[0] is undefined behaviour for an empty vector; the
 // method does not modify *this and could be const.
1468  desc reshape(const dims &adims) {
1469  dnnl_memory_desc_t out_md;
1471  (int)adims.size(), &adims[0]),
1472  "could not reshape a memory descriptor");
1473  return desc(out_md);
1474  }
1475 
 /// @returns The size reported by dnnl_memory_desc_get_size.
1478  size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
1479 
 /// @returns true when the descriptor has zero dimensions.
1481  bool is_zero() const { return data.ndims == 0; }
1482 
 /// Equality as defined by dnnl_memory_desc_equal.
1483  bool operator==(const desc &other) const {
1484  return dnnl_memory_desc_equal(&data, &other.data) != 0;
1485  }
1486 
1487  bool operator!=(const desc &other) const { return !operator==(other); }
1488  };
1489 
 /// Constructs an empty memory handle.
1490  memory() = default;
1491 
 /// Creates a memory object for descriptor @p md on @p aengine using the
 /// user-supplied @p ahandle. In SYCL builds this delegates to the private
 /// with_sycl_tag constructor, with is_usm = true only when
 /// DNNL_USE_DPCPP_USM is defined; otherwise it calls dnnl_memory_create.
1492 #if DNNL_WITH_SYCL
1493  memory(const desc &md, const engine &aengine, void *ahandle)
1499 #ifdef DNNL_USE_DPCPP_USM
1500  : memory(with_sycl_tag {}, md, aengine, ahandle, true) {
1501  }
1502 #else
1503  : memory(with_sycl_tag {}, md, aengine, ahandle, false) {
1504  }
1505 #endif
1506 #else
1507  memory(const desc &md, const engine &aengine, void *ahandle) {
1513  dnnl_memory_t result;
1515  dnnl_memory_create(&result, &md.data, aengine.get(), ahandle),
1516  "could not create a memory");
1517  reset(result);
1518  }
1519 #endif
1520 
 /// SYCL-buffer builds only: creates the memory with DNNL_MEMORY_NONE and
 /// then attaches @p buf via set_sycl_buffer.
1521 #if DNNL_WITH_SYCL && defined(DNNL_USE_SYCL_BUFFERS)
1522  template <typename T, int ndims = 1>
1528  memory(const desc &md, const engine &aengine,
1529  cl::sycl::buffer<T, ndims> &buf)
1530  : memory(md, aengine, DNNL_MEMORY_NONE) {
1531  set_sycl_buffer(buf);
1532  }
1533 #endif
1534 
 /// Creates a memory object passing DNNL_MEMORY_ALLOCATE as the handle
 /// (presumably asks the library to allocate the buffer — confirm in dnnl.h).
1539  memory(const desc &md, const engine &aengine)
1540  : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
1541 
 /// @returns A copy of this memory's descriptor.
1543  desc get_desc() const {
1544  const dnnl_memory_desc_t *cdesc;
1546  "could not get memory descriptor from a memory");
1547  return desc(*cdesc);
1548  }
1549 
 /// @returns The engine of this memory, wrapped via
 /// engine(const dnnl_engine_t &), which creates a weak (non-owning) handle.
1551  engine get_engine() const {
1552  dnnl_engine_t engine_q;
1554  "could not get engine from a memory");
1555  return engine(engine_q);
1556  }
1557 
 /// Gets / sets the native data handle (C calls on missing lines).
1561  void *get_data_handle() const {
1562  void *handle;
1564  "could not get native handle");
1565  return handle;
1566  }
1567 
1568  void set_data_handle(void *handle) const {
1570  "could not set native handle");
1571  }
1572 
 /// Maps the data with dnnl_memory_map_data and returns it as T*; pair with
 /// unmap_data.
1588  template <typename T = void>
1589  T *map_data() const {
1590  void *mapped_ptr;
1591  error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
1592  "could not map the data");
1593  return static_cast<T *>(mapped_ptr);
1594  }
1595 
 /// Unmaps a pointer previously returned by map_data.
1604  void unmap_data(void *mapped_ptr) const {
1606  "could not unmap the data");
1607  }
1608 
 /// OpenCL builds only: gets / sets the underlying OpenCL memory object.
1609 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
1610  cl_mem get_ocl_mem_object() const {
1612  cl_mem mem_object;
1614  "could not get OpenCL memory object");
1615  return mem_object;
1616  }
1617 
1619  void set_ocl_mem_object(cl_mem mem_object) {
1621  "could not set OpenCL memory object");
1622  }
1623 #endif
1624 
 /// SYCL-buffer builds only: returns the memory's buffer reinterpreted as a
 /// 1D buffer of T (only ndims == 1 is accepted). *offset, if requested, is
 /// always set to 0; a null handle yields a dummy one-element buffer.
1625 #if DNNL_WITH_SYCL && defined(DNNL_USE_SYCL_BUFFERS)
1626  template <typename T, int ndims = 1>
1631  cl::sycl::buffer<T, ndims> get_sycl_buffer(size_t *offset = nullptr) const {
1632  static_assert(ndims == 1, "only 1D buffers supported");
1633 
1634  void *handle_ptr;
1636  "could not get SYCL buffer object");
1637 
1638  // XXX: workaround for ComputeCpp
1639  // ComputeCpp fails to construct zero-range buffer
1640  if (!handle_ptr)
1641  return cl::sycl::buffer<T, ndims>(cl::sycl::range<1>(1));
1642 
1643  auto &buf_u8 = *static_cast<cl::sycl::buffer<uint8_t, 1> *>(handle_ptr);
1644  if (offset) *offset = 0;
1645  auto range = cl::sycl::range<1>(buf_u8.get_size() / sizeof(T));
1646  return buf_u8.reinterpret<T, 1>(range);
1647  }
1648 
 /// Attaches @p buf, reinterpreted as a 1D uint8_t buffer.
 // NOTE(review): the C API receives the address of the local `buf_u8`, so
 // it presumably copies the buffer object before returning — confirm.
1654  template <typename T, int ndims>
1655  void set_sycl_buffer(cl::sycl::buffer<T, ndims> &buf) {
1656  auto range = cl::sycl::range<1>(buf.get_size());
1657  auto buf_u8 = buf.template reinterpret<uint8_t, 1>(range);
1659  get(), static_cast<void *>(&buf_u8)),
1660  "could not set SYCL buffer object");
1661  }
1662 #endif
1663 
 // C++ -> C enum conversions for data types and format tags.
1664  // Must go away or be private:
1665  static dnnl_data_type_t convert_to_c(data_type adata_type) {
1666  return static_cast<dnnl_data_type_t>(adata_type);
1667  }
1668  static dnnl_format_tag_t convert_to_c(format_tag aformat) {
1669  return static_cast<dnnl_format_tag_t>(aformat);
1670  }
1671 
1672 private:
 // Tag type selecting the out-of-line SYCL constructor (is_usm chooses USM
 // pointers versus buffers).
1673 #if DNNL_WITH_SYCL
1674  struct with_sycl_tag {};
1675 
1676  DNNL_API memory(with_sycl_tag, const desc &md, const engine &aengine,
1677  void *ahandle, bool is_usm);
1678 #endif
1679 };
1680 
1681 inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
1682  return a == memory::convert_to_c(b);
1683 }
1684 inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
1685  return !(a == b);
1686 }
1687 inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
1688  return b == a;
1689 }
1690 inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
1691  return !(a == b);
1692 }
1693 
1694 inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
1695  return a == memory::convert_to_c(b);
1696 }
1697 inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
1698  return !(a == b);
1699 }
1700 inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
1701  return b == a;
1702 }
1703 inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
1704  return !(a == b);
1705 }
1706 
1708 
1711 
1714 
1716 struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
1718 
1719  primitive_desc_base() = default;
1720 
1722  engine get_engine() const { return engine::query(*this); }
1723 
1725  const char *impl_info_str() const {
1726  const char *res;
1728  get(), dnnl_query_impl_info_str, 0, &res),
1729  "could not query implementation info string");
1730  return res;
1731  }
1732 
1734  memory::dim query_s64(query q) const {
1735  memory::dim res;
1737  get(), dnnl::convert_to_c(q), 0, &res);
1738  return status == dnnl_success ? res : 0;
1739  }
1740 
1742  memory::desc query_md(query what, int idx = 0) const {
1743  std::vector<query> valid_q {query::src_md, query::diff_src_md,
1746  if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
1747  [=](query q) { return what == q; }))
1748  throw error(dnnl_invalid_arguments, "invalid memory query");
1749 
1751  get(), dnnl::convert_to_c(what), idx);
1752  return memory::desc(*cdesc);
1753  }
1754 
1759  memory::desc scratchpad_desc() const {
1760  return query_md(query::scratchpad_md, 0);
1761  }
1762 
1764  engine scratchpad_engine() const {
1765  dnnl_engine_t engine_q;
1767  dnnl::convert_to_c(query::scratchpad_engine),
1768  0, &engine_q),
1769  "could not get scratchpad engine from a primitive_desc");
1770 
1771  return engine(engine_q);
1772  }
1773 
1775  primitive_attr get_primitive_attr() const {
1776  const_dnnl_primitive_attr_t const_cattr;
1778  "could not get attributes");
1779  dnnl_primitive_attr_t cattr;
1780  error::wrap_c_api(dnnl_primitive_attr_clone(&cattr, const_cattr),
1781  "could not clone attributes");
1782 
1783  return primitive_attr(cattr);
1784  }
1785 
1786 protected:
1787  void reset_with_clone(const_dnnl_primitive_desc_t pd) {
1788  dnnl_primitive_desc_t new_pd;
1790  "could not clone primitive descriptor");
1791  reset(new_pd);
1792  }
1793 
1794  primitive_desc_base(
1796  : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
1797 
1798  primitive_desc_base(dnnl_primitive_desc_t pd,
1800  : primitive_desc_base(pd, prim_kind, prop_kind, prop_kind) {}
1801 
1808  primitive_desc_base(dnnl_primitive_desc_t pd,
1809  dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
1810  dnnl::prop_kind prop_kind2) {
1811  // It is OK to pass an empty primitive descriptor
1812  if (pd == nullptr) return;
1813 
1814  dnnl_status_t rc;
1815 
1816  dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
1817  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
1818  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
1819 
1820  // Check that primitive kind matches
1821  dnnl_primitive_kind_t pd_kind;
1823  pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
1824  error::wrap_c_api(rc,
1825  "could not get primitive kind from the primitive descriptor");
1826  if (pd_kind != c_prim_kind)
1827  throw error(dnnl_invalid_arguments,
1828  "primitive descriptor operation kind mismatch");
1829 
1830  // Check that propagation kind matches
1831  dnnl_prop_kind_t pd_prop_kind;
1833  pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
1834 
1835  // Something went wrong
1836  if (rc != dnnl_success && rc != dnnl_unimplemented)
1837  throw error(dnnl_invalid_arguments,
1838  "could not get propagation kind "
1839  "from the primitive descriptor");
1840 
1841  // Everything is fine
1842  if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
1843  || (rc == dnnl_success
1844  && (pd_prop_kind == c_prop_kind1
1845  || pd_prop_kind == c_prop_kind2))) {
1846  reset_with_clone(pd);
1847  return;
1848  }
1849 
1850  // We could get the propagation kind but there is a mismatch
1851  throw error(dnnl_invalid_arguments,
1852  "primitive descriptor propagation kind mismatch");
1853  }
1854 };
1855 
1858 
1865 
1869 struct reorder : public primitive {
1870  struct primitive_desc : public primitive_desc_base {
1871  using primitive_desc_base::primitive_desc_base;
1872 
1873  primitive_desc() = default;
1874 
1875  primitive_desc(const engine &src_engine, const memory::desc &src_md,
1876  const engine &dst_engine, const memory::desc &dst_md,
1877  const primitive_attr &aattr = primitive_attr()) {
1878  dnnl_primitive_desc_t result;
1881  src_engine.get(), &dst_md.data, dst_engine.get(),
1882  aattr.get()),
1883  "could not create a reorder primitive descriptor");
1884  reset(result);
1885  }
1886 
1887  primitive_desc(const memory &src, const memory &dst,
1888  const primitive_attr &aattr = primitive_attr()) {
1889  dnnl_primitive_desc_t result;
1890  auto src_md = src.get_desc();
1891  auto dst_md = dst.get_desc();
1894  src.get_engine().get(), &dst_md.data,
1895  dst.get_engine().get(), aattr.get()),
1896  "could not create a reorder primitive descriptor");
1897  reset(result);
1898  }
1899 
1902  primitive_desc(dnnl_primitive_desc_t pd)
1903  : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
1904 
1905  engine get_src_engine() const {
1906  return engine::query(*this, dnnl::query::reorder_src_engine);
1907  }
1908 
1909  engine get_dst_engine() const {
1910  return engine::query(*this, dnnl::query::reorder_dst_engine);
1911  }
1912  };
1913 
1914  reorder() = default;
1915 
1916  reorder(const primitive_desc &pd) : primitive(pd.get()) {}
1917 
1918  reorder(const memory &src, const memory &dst)
1919  : primitive(primitive_desc(src, dst).get()) {}
1920 
1921  using primitive::execute;
1922 
1923  void execute(stream astream, memory &src, memory &dst) {
1924  primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
1925  }
1926 
1927 #ifdef DNNL_SYCL_DPCPP
1928  using primitive::execute_sycl;
1929 
1930  cl::sycl::event execute_sycl(stream &astream, memory &src, memory &dst,
1931  const std::vector<cl::sycl::event> &deps = {}) const {
1932  return primitive::execute_sycl(astream,
1933  {{DNNL_ARG_FROM, src},
1934  { DNNL_ARG_TO,
1935  dst }},
1936  deps);
1937  }
1938 #endif
1939 };
1940 
1942 
1949 
1951 inline std::vector<dnnl_memory_desc_t> convert_to_c(
1952  const std::vector<memory::desc> &mems) {
1953  std::vector<dnnl_memory_desc_t> c_api_mems;
1954  c_api_mems.reserve(mems.size());
1955  for (const auto &s : mems)
1956  c_api_mems.push_back(s.data);
1957  return c_api_mems;
1958 }
1960 
1968 struct concat : public primitive {
1969  struct primitive_desc : public primitive_desc_base {
1970  using primitive_desc_base::primitive_desc_base;
1971 
1972  primitive_desc(const memory::desc &dst, int concat_dimension,
1973  const std::vector<memory::desc> &srcs, const engine &aengine,
1974  const primitive_attr &aattr = primitive_attr()) {
1975  auto c_api_srcs = convert_to_c(srcs);
1976 
1977  dnnl_primitive_desc_t result;
1979  dnnl_concat_primitive_desc_create(&result, &dst.data,
1980  (int)c_api_srcs.size(), concat_dimension,
1981  &c_api_srcs[0], aattr.get(), aengine.get()),
1982  "could not create a concat primitive descriptor");
1983  reset(result);
1984  }
1985 
1986  primitive_desc(int concat_dimension,
1987  const std::vector<memory::desc> &srcs, const engine &aengine,
1988  const primitive_attr &aattr = primitive_attr()) {
1989  auto c_api_srcs = convert_to_c(srcs);
1990 
1991  dnnl_primitive_desc_t result;
1993  dnnl_concat_primitive_desc_create(&result, nullptr,
1994  (int)c_api_srcs.size(), concat_dimension,
1995  &c_api_srcs[0], aattr.get(), aengine.get()),
1996  "could not create a concat primitive descriptor");
1997  reset(result);
1998  }
1999 
2002  primitive_desc(dnnl_primitive_desc_t pd)
2003  : primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
2004 
2006  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
2007  };
2008 
2009  concat() = default;
2010 
2011  concat(const primitive_desc &pd) : primitive(pd.get()) {}
2012 };
2013 
2015 
2022 
2028 struct sum : public primitive {
2029  struct primitive_desc : public primitive_desc_base {
2030  using primitive_desc_base::primitive_desc_base;
2031 
2032  primitive_desc() = default;
2033 
2034  primitive_desc(const memory::desc &dst,
2035  const std::vector<float> &scales,
2036  const std::vector<memory::desc> &srcs, const engine &aengine,
2037  const primitive_attr &aattr = primitive_attr()) {
2038  error::wrap_c_api(scales.size() == srcs.size()
2039  ? dnnl_success
2041  "number of scales not equal to number of srcs");
2042 
2043  auto c_api_srcs = convert_to_c(srcs);
2044 
2045  dnnl_primitive_desc_t result;
2047  dnnl_sum_primitive_desc_create(&result, &dst.data,
2048  (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0],
2049  aattr.get(), aengine.get()),
2050  "could not create a sum primitive descriptor");
2051  reset(result);
2052  }
2053 
2054  primitive_desc(const std::vector<float> &scales,
2055  const std::vector<memory::desc> &srcs, const engine &aengine,
2056  const primitive_attr &aattr = primitive_attr()) {
2057  error::wrap_c_api(scales.size() == srcs.size()
2058  ? dnnl_success
2060  "number of scales not equal to number of srcs");
2061 
2062  auto c_api_srcs = convert_to_c(srcs);
2063  dnnl_primitive_desc_t result;
2065  dnnl_sum_primitive_desc_create(&result, nullptr,
2066  (int)c_api_srcs.size(), &scales[0], &c_api_srcs[0],
2067  aattr.get(), aengine.get()),
2068  "could not create a sum primitive descriptor");
2069  reset(result);
2070  }
2071 
2074  primitive_desc(dnnl_primitive_desc_t pd)
2075  : primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
2076 
2078  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
2079  };
2080 
2081  sum() = default;
2082 
2083  sum(const primitive_desc &pd) : primitive(pd.get()) {}
2084 };
2085 
2087 
2089 
2092 
2095 
2098 struct primitive_desc : public primitive_desc_base {
2099  using primitive_desc_base::primitive_desc_base;
2100 
2101  primitive_desc() = default;
2102 
2108  primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr,
2109  const engine &e, const_dnnl_primitive_desc_t hint_fwd_pd,
2110  bool allow_empty = false)
2111  : allow_empty(allow_empty) {
2112  dnnl_primitive_desc_iterator_t iterator = nullptr;
2114  desc, attr ? attr->get() : nullptr, e.get(), hint_fwd_pd);
2115  if (!allow_empty)
2117  status, "could not create a primitive descriptor iterator");
2118  pd_iterator.reset(iterator);
2119  fetch_impl();
2120  }
2121 
2128  bool next_impl() {
2129  dnnl_status_t status
2130  = dnnl_primitive_desc_iterator_next(pd_iterator.get());
2131  if (status == dnnl_iterator_ends) return false;
2132  error::wrap_c_api(status, "primitive descriptor iterator next failed");
2133 
2134  fetch_impl();
2135  return true;
2136  }
2137 
2138 private:
2139  bool allow_empty = false;
2140  handle<dnnl_primitive_desc_iterator_t> pd_iterator;
2141  void fetch_impl() {
2143  pd_iterator.get(allow_empty));
2144  error::wrap_c_api(pd != nullptr || allow_empty ? dnnl_success
2146  "could not fetch a primitive descriptor from the iterator");
2147  reset(pd);
2148  }
2149 };
2150 
2152 
2160 
2165 struct convolution_forward : public primitive {
2166 
2168  struct desc {
2170 
2179  desc(prop_kind aprop_kind, algorithm aalgorithm,
2180  const memory::desc &src_desc, const memory::desc &weights_desc,
2181  const memory::desc &bias_desc, const memory::desc &dst_desc,
2182  const memory::dims &strides, const memory::dims &padding_l,
2183  const memory::dims &padding_r) {
2184  memory::validate_dims(strides);
2185  memory::validate_dims(padding_l);
2186  memory::validate_dims(padding_r);
2189  dnnl::convert_to_c(aprop_kind),
2190  convert_to_c(aalgorithm), &src_desc.data,
2191  &weights_desc.data, &bias_desc.data, &dst_desc.data,
2192  &strides[0], &padding_l[0], &padding_r[0]),
2193  "could not create a convolution forward descriptor");
2194  }
2195 
2204  desc(prop_kind aprop_kind, algorithm aalgorithm,
2205  const memory::desc &src_desc, const memory::desc &weights_desc,
2206  const memory::desc &dst_desc, const memory::dims &strides,
2207  const memory::dims &padding_l, const memory::dims &padding_r) {
2208  memory::validate_dims(strides);
2209  memory::validate_dims(padding_l);
2210  memory::validate_dims(padding_r);
2213  dnnl::convert_to_c(aprop_kind),
2214  convert_to_c(aalgorithm), &src_desc.data,
2215  &weights_desc.data, nullptr, &dst_desc.data,
2216  &strides[0], &padding_l[0], &padding_r[0]),
2217  "could not create a convolution forward descriptor");
2218  }
2219 
2228  desc(prop_kind aprop_kind, algorithm aalgorithm,
2229  const memory::desc &src_desc, const memory::desc &weights_desc,
2230  const memory::desc &bias_desc, const memory::desc &dst_desc,
2231  const memory::dims &strides, const memory::dims &dilates,
2232  const memory::dims &padding_l, const memory::dims &padding_r) {
2233  memory::validate_dims(strides);
2234  memory::validate_dims(dilates);
2235  memory::validate_dims(padding_l);
2236  memory::validate_dims(padding_r);
2238  dnnl::convert_to_c(aprop_kind),
2239  convert_to_c(aalgorithm), &src_desc.data,
2240  &weights_desc.data, &bias_desc.data,
2241  &dst_desc.data, &strides[0], &dilates[0],
2242  &padding_l[0], &padding_r[0]),
2243  "could not create a dilated convolution forward "
2244  "descriptor");
2245  }
2246 
2255  desc(prop_kind aprop_kind, algorithm aalgorithm,
2256  const memory::desc &src_desc, const memory::desc &weights_desc,
2257  const memory::desc &dst_desc, const memory::dims &strides,
2258  const memory::dims &dilates, const memory::dims &padding_l,
2259  const memory::dims &padding_r) {
2260  memory::validate_dims(strides);
2261  memory::validate_dims(dilates);
2262  memory::validate_dims(padding_l);
2263  memory::validate_dims(padding_r);
2265  dnnl::convert_to_c(aprop_kind),
2266  convert_to_c(aalgorithm), &src_desc.data,
2267  &weights_desc.data, nullptr,
2268  &dst_desc.data, &strides[0], &dilates[0],
2269  &padding_l[0], &padding_r[0]),
2270  "could not create a dilated convolution forward "
2271  "descriptor");
2272  }
2273  };
2274 
2276  struct primitive_desc : public dnnl::primitive_desc {
2277  primitive_desc() = default;
2278 
2281  primitive_desc(
2282  const desc &desc, const engine &e, bool allow_empty = false)
2283  : dnnl::primitive_desc(
2284  &desc.data, nullptr, e, nullptr, allow_empty) {}
2285 
2288  primitive_desc(const desc &desc, const primitive_attr &attr,
2289  const engine &e, bool allow_empty = false)
2290  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
2291  }
2292 
2295  primitive_desc(dnnl_primitive_desc_t pd)
2296  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2297  dnnl::prop_kind::forward_training,
2298  dnnl::prop_kind::forward_inference) {}
2299 
2301  memory::desc src_desc() const { return query_md(query::src_md, 0); }
2302 
2304  memory::desc weights_desc() const {
2305  return query_md(query::weights_md, 0);
2306  }
2307 
2312  memory::desc bias_desc() const {
2313  return query_md(query::weights_md, 1);
2314  }
2315 
2317  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
2318  };
2319 
2320  convolution_forward() = default;
2321 
2324  convolution_forward(const primitive_desc &pd) : primitive(pd) {}
2325 };
2326 
2331 struct convolution_backward_data : public primitive {
2332 
2334  struct desc {
2336 
2343  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2344  const memory::desc &weights_desc,
2345  const memory::desc &diff_dst_desc, const memory::dims &strides,
2346  const memory::dims &padding_l, const memory::dims &padding_r) {
2347  memory::validate_dims(strides);
2348  memory::validate_dims(padding_l);
2349  memory::validate_dims(padding_r);
2352  convert_to_c(aalgorithm), &diff_src_desc.data,
2353  &weights_desc.data, &diff_dst_desc.data,
2354  &strides[0], &padding_l[0], &padding_r[0]),
2355  "could not create a convolution backward data descriptor");
2356  }
2357 
2364  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2365  const memory::desc &weights_desc,
2366  const memory::desc &diff_dst_desc, const memory::dims &strides,
2367  const memory::dims &dilates, const memory::dims &padding_l,
2368  const memory::dims &padding_r) {
2369  memory::validate_dims(strides);
2370  memory::validate_dims(dilates);
2371  memory::validate_dims(padding_l);
2372  memory::validate_dims(padding_r);
2375  convert_to_c(aalgorithm), &diff_src_desc.data,
2376  &weights_desc.data, &diff_dst_desc.data,
2377  &strides[0], &dilates[0], &padding_l[0],
2378  &padding_r[0]),
2379  "could not create a convolution backward data descriptor");
2380  }
2381  };
2382 
2384  struct primitive_desc : public dnnl::primitive_desc {
2385  primitive_desc() = default;
2386 
2389  primitive_desc(const desc &desc, const engine &e,
2390  const convolution_forward::primitive_desc &hint_fwd_pd,
2391  bool allow_empty = false)
2392  : dnnl::primitive_desc(
2393  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2394 
2397  primitive_desc(const desc &desc, const primitive_attr &attr,
2398  const engine &e,
2399  const convolution_forward::primitive_desc &hint_fwd_pd,
2400  bool allow_empty = false)
2401  : dnnl::primitive_desc(
2402  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2403 
2406  primitive_desc(dnnl_primitive_desc_t pd)
2407  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2408  dnnl::prop_kind::backward_data) {}
2409 
2411  memory::desc diff_src_desc() const {
2412  return query_md(query::diff_src_md, 0);
2413  }
2414 
2416  memory::desc weights_desc() const {
2417  return query_md(query::weights_md, 0);
2418  }
2419 
2421  memory::desc diff_dst_desc() const {
2422  return query_md(query::diff_dst_md, 0);
2423  }
2424  };
2425 
2426  convolution_backward_data() = default;
2427 
2430  convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
2431 };
2432 
2437 struct convolution_backward_weights : public primitive {
2438 
2440  struct desc {
2442 
2449  desc(algorithm aalgorithm, const memory::desc &src_desc,
2450  const memory::desc &diff_weights_desc,
2451  const memory::desc &diff_bias_desc,
2452  const memory::desc &diff_dst_desc, const memory::dims &strides,
2453  const memory::dims &padding_l, const memory::dims &padding_r) {
2454  memory::validate_dims(strides);
2455  memory::validate_dims(padding_l);
2456  memory::validate_dims(padding_r);
2459  convert_to_c(aalgorithm), &src_desc.data,
2460  &diff_weights_desc.data, &diff_bias_desc.data,
2461  &diff_dst_desc.data, &strides[0], &padding_l[0],
2462  &padding_r[0]),
2463  "could not create a convolution backward weights "
2464  "descriptor");
2465  }
2466 
2473  desc(algorithm aalgorithm, const memory::desc &src_desc,
2474  const memory::desc &diff_weights_desc,
2475  const memory::desc &diff_dst_desc, const memory::dims &strides,
2476  const memory::dims &padding_l, const memory::dims &padding_r) {
2477  memory::validate_dims(strides);
2478  memory::validate_dims(padding_l);
2479  memory::validate_dims(padding_r);
2481  convert_to_c(aalgorithm), &src_desc.data,
2482  &diff_weights_desc.data, nullptr,
2483  &diff_dst_desc.data, &strides[0],
2484  &padding_l[0], &padding_r[0]),
2485  "could not create a convolution backward weights "
2486  "descriptor");
2487  }
2488 
2495  desc(algorithm aalgorithm, const memory::desc &src_desc,
2496  const memory::desc &diff_weights_desc,
2497  const memory::desc &diff_bias_desc,
2498  const memory::desc &diff_dst_desc, const memory::dims &strides,
2499  const memory::dims &dilates, const memory::dims &padding_l,
2500  const memory::dims &padding_r) {
2501  memory::validate_dims(strides);
2502  memory::validate_dims(dilates);
2503  memory::validate_dims(padding_l);
2504  memory::validate_dims(padding_r);
2507  convert_to_c(aalgorithm), &src_desc.data,
2508  &diff_weights_desc.data, &diff_bias_desc.data,
2509  &diff_dst_desc.data, &strides[0], &dilates[0],
2510  &padding_l[0], &padding_r[0]),
2511  "could not create a convolution backward weights "
2512  "descriptor");
2513  }
2514 
2521  desc(algorithm aalgorithm, const memory::desc &src_desc,
2522  const memory::desc &diff_weights_desc,
2523  const memory::desc &diff_dst_desc, const memory::dims &strides,
2524  const memory::dims &dilates, const memory::dims &padding_l,
2525  const memory::dims &padding_r) {
2526  memory::validate_dims(strides);
2527  memory::validate_dims(dilates);
2528  memory::validate_dims(padding_l);
2529  memory::validate_dims(padding_r);
2532  convert_to_c(aalgorithm), &src_desc.data,
2533  &diff_weights_desc.data, nullptr,
2534  &diff_dst_desc.data, &strides[0], &dilates[0],
2535  &padding_l[0], &padding_r[0]),
2536  "could not create a convolution backward weights "
2537  "descriptor");
2538  }
2539  };
2540 
2542  struct primitive_desc : public dnnl::primitive_desc {
2543  primitive_desc() = default;
2544 
2546  primitive_desc(const desc &desc, const engine &e,
2547  const convolution_forward::primitive_desc &hint_fwd_pd,
2548  bool allow_empty = false)
2549  : dnnl::primitive_desc(
2550  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2551 
2554  primitive_desc(const desc &desc, const primitive_attr &attr,
2555  const engine &e,
2556  const convolution_forward::primitive_desc &hint_fwd_pd,
2557  bool allow_empty = false)
2558  : dnnl::primitive_desc(
2559  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2560 
2563  primitive_desc(dnnl_primitive_desc_t pd)
2564  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
2565  dnnl::prop_kind::backward_weights) {}
2566 
2568  memory::desc src_desc() const { return query_md(query::src_md, 0); }
2569 
2571  memory::desc diff_weights_desc() const {
2572  return query_md(query::diff_weights_md, 0);
2573  }
2574 
2576  memory::desc diff_bias_desc() const {
2577  return query_md(query::diff_weights_md, 1);
2578  }
2579 
2581  memory::desc diff_dst_desc() const {
2582  return query_md(query::diff_dst_md, 0);
2583  }
2584  };
2585 
2586  convolution_backward_weights() = default;
2587 
2590  convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
2591 };
2592 
2594 //
2600 
2605 struct deconvolution_forward : public primitive {
2606 
2608  struct desc {
2610 
2619  desc(prop_kind aprop_kind, algorithm aalgorithm,
2620  const memory::desc &src_desc, const memory::desc &weights_desc,
2621  const memory::desc &bias_desc, const memory::desc &dst_desc,
2622  const memory::dims &strides, const memory::dims &padding_l,
2623  const memory::dims &padding_r) {
2624  memory::validate_dims(strides);
2625  memory::validate_dims(padding_l);
2626  memory::validate_dims(padding_r);
2629  dnnl::convert_to_c(aprop_kind),
2630  convert_to_c(aalgorithm), &src_desc.data,
2631  &weights_desc.data, &bias_desc.data, &dst_desc.data,
2632  &strides[0], &padding_l[0], &padding_r[0]),
2633  "could not create a deconvolution forward descriptor");
2634  }
2635 
2644  desc(prop_kind aprop_kind, algorithm aalgorithm,
2645  const memory::desc &src_desc, const memory::desc &weights_desc,
2646  const memory::desc &dst_desc, const memory::dims &strides,
2647  const memory::dims &padding_l, const memory::dims &padding_r) {
2648  memory::validate_dims(strides);
2649  memory::validate_dims(padding_l);
2650  memory::validate_dims(padding_r);
2653  dnnl::convert_to_c(aprop_kind),
2654  convert_to_c(aalgorithm), &src_desc.data,
2655  &weights_desc.data, nullptr, &dst_desc.data,
2656  &strides[0], &padding_l[0], &padding_r[0]),
2657  "could not create a deconvolution forward descriptor");
2658  }
2659 
2668  desc(prop_kind aprop_kind, algorithm aalgorithm,
2669  const memory::desc &src_desc, const memory::desc &weights_desc,
2670  const memory::desc &bias_desc, const memory::desc &dst_desc,
2671  const memory::dims &strides, const memory::dims &dilates,
2672  const memory::dims &padding_l, const memory::dims &padding_r) {
2673  memory::validate_dims(strides);
2674  memory::validate_dims(dilates);
2675  memory::validate_dims(padding_l);
2676  memory::validate_dims(padding_r);
2678  &data, dnnl::convert_to_c(aprop_kind),
2679  convert_to_c(aalgorithm), &src_desc.data,
2680  &weights_desc.data, &bias_desc.data,
2681  &dst_desc.data, &strides[0], &dilates[0],
2682  &padding_l[0], &padding_r[0]),
2683  "could not create a dilated deconvolution forward "
2684  "descriptor");
2685  }
2686 
2695  desc(prop_kind aprop_kind, algorithm aalgorithm,
2696  const memory::desc &src_desc, const memory::desc &weights_desc,
2697  const memory::desc &dst_desc, const memory::dims &strides,
2698  const memory::dims &dilates, const memory::dims &padding_l,
2699  const memory::dims &padding_r) {
2700  memory::validate_dims(strides);
2701  memory::validate_dims(dilates);
2702  memory::validate_dims(padding_l);
2703  memory::validate_dims(padding_r);
2705  &data, dnnl::convert_to_c(aprop_kind),
2706  convert_to_c(aalgorithm), &src_desc.data,
2707  &weights_desc.data, nullptr,
2708  &dst_desc.data, &strides[0], &dilates[0],
2709  &padding_l[0], &padding_r[0]),
2710  "could not create a dilated deconvolution forward "
2711  "descriptor");
2712  }
2713  };
2714 
2716  struct primitive_desc : public dnnl::primitive_desc {
2717  primitive_desc() = default;
2718 
2721  primitive_desc(
2722  const desc &desc, const engine &e, bool allow_empty = false)
2723  : dnnl::primitive_desc(
2724  &desc.data, nullptr, e, nullptr, allow_empty) {}
2725 
2728  primitive_desc(const desc &desc, const primitive_attr &attr,
2729  const engine &e, bool allow_empty = false)
2730  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
2731  }
2732 
2735  primitive_desc(dnnl_primitive_desc_t pd)
2736  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
2737  dnnl::prop_kind::forward_training,
2738  dnnl::prop_kind::forward_inference) {}
2739 
2741  memory::desc src_desc() const { return query_md(query::src_md, 0); }
2742 
2744  memory::desc weights_desc() const {
2745  return query_md(query::weights_md, 0);
2746  }
2747 
2752  memory::desc bias_desc() const {
2753  return query_md(query::weights_md, 1);
2754  }
2755 
2757  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
2758  };
2759 
2760  deconvolution_forward() = default;
2761 
2764  deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
2765 };
2766 
2771 struct deconvolution_backward_data : public primitive {
2772 
2774  struct desc {
2776 
2783  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2784  const memory::desc &weights_desc,
2785  const memory::desc &diff_dst_desc, const memory::dims &strides,
2786  const memory::dims &padding_l, const memory::dims &padding_r) {
2787  memory::validate_dims(strides);
2788  memory::validate_dims(padding_l);
2789  memory::validate_dims(padding_r);
2792  convert_to_c(aalgorithm), &diff_src_desc.data,
2793  &weights_desc.data, &diff_dst_desc.data,
2794  &strides[0], &padding_l[0], &padding_r[0]),
2795  "could not create a deconvolution backward data "
2796  "descriptor");
2797  }
2798 
2805  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
2806  const memory::desc &weights_desc,
2807  const memory::desc &diff_dst_desc, const memory::dims &strides,
2808  const memory::dims &dilates, const memory::dims &padding_l,
2809  const memory::dims &padding_r) {
2810  memory::validate_dims(strides);
2811  memory::validate_dims(dilates);
2812  memory::validate_dims(padding_l);
2813  memory::validate_dims(padding_r);
2816  convert_to_c(aalgorithm), &diff_src_desc.data,
2817  &weights_desc.data, &diff_dst_desc.data,
2818  &strides[0], &dilates[0], &padding_l[0],
2819  &padding_r[0]),
2820  "could not create a dilated deconvolution backward data "
2821  "descriptor");
2822  }
2823  };
2824 
2826  struct primitive_desc : public dnnl::primitive_desc {
2827  primitive_desc() = default;
2828 
2831  primitive_desc(const desc &desc, const engine &e,
2832  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2833  bool allow_empty = false)
2834  : dnnl::primitive_desc(
2835  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2836 
2839  primitive_desc(const desc &desc, const primitive_attr &attr,
2840  const engine &e,
2841  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2842  bool allow_empty = false)
2843  : dnnl::primitive_desc(
2844  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
2845 
2848  primitive_desc(dnnl_primitive_desc_t pd)
2849  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
2850  dnnl::prop_kind::backward_data) {}
2851 
2853  memory::desc diff_src_desc() const {
2854  return query_md(query::diff_src_md, 0);
2855  }
2856 
2858  memory::desc weights_desc() const {
2859  return query_md(query::weights_md, 0);
2860  }
2861 
2863  memory::desc diff_dst_desc() const {
2864  return query_md(query::diff_dst_md, 0);
2865  }
2866  };
2867 
2868  deconvolution_backward_data() = default;
2869 
2872  deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
2873 };
2874 
2879 struct deconvolution_backward_weights : public primitive {
2880 
2882  struct desc {
2884 
2891  desc(algorithm aalgorithm, const memory::desc &src_desc,
2892  const memory::desc &diff_weights_desc,
2893  const memory::desc &diff_bias_desc,
2894  const memory::desc &diff_dst_desc, const memory::dims &strides,
2895  const memory::dims &padding_l, const memory::dims &padding_r) {
2896  memory::validate_dims(strides);
2897  memory::validate_dims(padding_l);
2898  memory::validate_dims(padding_r);
2901  convert_to_c(aalgorithm), &src_desc.data,
2902  &diff_weights_desc.data, &diff_bias_desc.data,
2903  &diff_dst_desc.data, &strides[0], &padding_l[0],
2904  &padding_r[0]),
2905  "could not create a deconvolution backward weights "
2906  "descriptor");
2907  }
2908 
2915  desc(algorithm aalgorithm, const memory::desc &src_desc,
2916  const memory::desc &diff_weights_desc,
2917  const memory::desc &diff_dst_desc, const memory::dims &strides,
2918  const memory::dims &padding_l, const memory::dims &padding_r) {
2919  memory::validate_dims(strides);
2920  memory::validate_dims(padding_l);
2921  memory::validate_dims(padding_r);
2923  &data, convert_to_c(aalgorithm),
2924  &src_desc.data, &diff_weights_desc.data,
2925  nullptr, &diff_dst_desc.data, &strides[0],
2926  &padding_l[0], &padding_r[0]),
2927  "could not create a deconvolution backward weights "
2928  "descriptor");
2929  }
2930 
2937  desc(algorithm aalgorithm, const memory::desc &src_desc,
2938  const memory::desc &diff_weights_desc,
2939  const memory::desc &diff_bias_desc,
2940  const memory::desc &diff_dst_desc, const memory::dims &strides,
2941  const memory::dims &dilates, const memory::dims &padding_l,
2942  const memory::dims &padding_r) {
2943  memory::validate_dims(strides);
2944  memory::validate_dims(dilates);
2945  memory::validate_dims(padding_l);
2946  memory::validate_dims(padding_r);
2949  convert_to_c(aalgorithm), &src_desc.data,
2950  &diff_weights_desc.data, &diff_bias_desc.data,
2951  &diff_dst_desc.data, &strides[0], &dilates[0],
2952  &padding_l[0], &padding_r[0]),
2953  "could not create a dilated deconvolution backward "
2954  "weights descriptor");
2955  }
2956 
2963  desc(algorithm aalgorithm, const memory::desc &src_desc,
2964  const memory::desc &diff_weights_desc,
2965  const memory::desc &diff_dst_desc, const memory::dims &strides,
2966  const memory::dims &dilates, const memory::dims &padding_l,
2967  const memory::dims &padding_r) {
2968  memory::validate_dims(strides);
2969  memory::validate_dims(dilates);
2970  memory::validate_dims(padding_l);
2971  memory::validate_dims(padding_r);
2974  convert_to_c(aalgorithm), &src_desc.data,
2975  &diff_weights_desc.data, nullptr,
2976  &diff_dst_desc.data, &strides[0], &dilates[0],
2977  &padding_l[0], &padding_r[0]),
2978  "could not create a dilated deconvolution backward weights "
2979  "descriptor");
2980  }
2981  };
2982 
2984  struct primitive_desc : public dnnl::primitive_desc {
2985  primitive_desc() = default;
2986 
2988  primitive_desc(const desc &desc, const engine &e,
2989  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2990  bool allow_empty = false)
2991  : dnnl::primitive_desc(
2992  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
2993 
2996  primitive_desc(const desc &desc, const primitive_attr &attr,
2997  const engine &e,
2998  const deconvolution_forward::primitive_desc &hint_fwd_pd,
2999  bool allow_empty = false)
3000  : dnnl::primitive_desc(
3001  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3002 
3005  primitive_desc(dnnl_primitive_desc_t pd)
3006  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
3007  dnnl::prop_kind::backward_weights) {}
3008 
3010  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3011 
3013  memory::desc diff_weights_desc() const {
3014  return query_md(query::diff_weights_md, 0);
3015  }
3016 
3018  memory::desc diff_bias_desc() const {
3019  return query_md(query::diff_weights_md, 1);
3020  }
3021 
3023  memory::desc diff_dst_desc() const {
3024  return query_md(query::diff_dst_md, 0);
3025  }
3026  };
3027 
3028  deconvolution_backward_weights() = default;
3029 
3032  deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
3033 };
3034 
3036 
3044 
3047 struct lrn_forward : public primitive {
3048 
3050  struct desc {
3051  dnnl_lrn_desc_t data;
3052 
3058  desc(prop_kind aprop_kind, algorithm aalgorithm,
3059  const memory::desc &src_desc, memory::dim local_size,
3060  float alpha, float beta, float k = 1.f) {
3062  dnnl::convert_to_c(aprop_kind),
3063  convert_to_c(aalgorithm), &src_desc.data,
3064  local_size, alpha, beta, k),
3065  "could not create a lrn forward descriptor");
3066  }
3067  };
3068 
3071  struct primitive_desc : public dnnl::primitive_desc {
3072  primitive_desc() = default;
3073 
3074  primitive_desc(
3075  const desc &desc, const engine &e, bool allow_empty = false)
3076  : dnnl::primitive_desc(
3077  &desc.data, nullptr, e, nullptr, allow_empty) {}
3078 
3079  primitive_desc(const desc &desc, const primitive_attr &attr,
3080  const engine &e, bool allow_empty = false)
3081  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3082  }
3083 
3087  primitive_desc(dnnl_primitive_desc_t pd)
3088  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
3089  dnnl::prop_kind::forward_training,
3090  dnnl::prop_kind::forward_inference) {}
3091 
3093  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3094 
3096  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3097 
3101  memory::desc workspace_desc() const {
3102  return query_md(query::workspace_md, 0);
3103  }
3104  };
3105 
3106  lrn_forward() = default;
3107 
3108  lrn_forward(const primitive_desc &pd) : primitive(pd) {}
3109 };
3110 
3113 struct lrn_backward : public primitive {
3114 
3116  struct desc {
3117  dnnl_lrn_desc_t data;
3118 
3123  desc(algorithm aalgorithm, const memory::desc &data_desc,
3124  const memory::desc &diff_data_desc, memory::dim local_size,
3125  float alpha, float beta, float k = 1.f) {
3127  dnnl_lrn_backward_desc_init(&data, convert_to_c(aalgorithm),
3128  &diff_data_desc.data, &data_desc.data, local_size,
3129  alpha, beta, k),
3130  "could not create a lrn backward descriptor");
3131  }
3132  };
3133 
3136  struct primitive_desc : public dnnl::primitive_desc {
3137  primitive_desc() = default;
3138 
3139  primitive_desc(const desc &desc, const engine &e,
3140  const lrn_forward::primitive_desc &hint_fwd_pd,
3141  bool allow_empty = false)
3142  : dnnl::primitive_desc(
3143  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3144 
3145  primitive_desc(const desc &desc, const primitive_attr &attr,
3146  const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd,
3147  bool allow_empty = false)
3148  : dnnl::primitive_desc(
3149  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3150 
3154  primitive_desc(dnnl_primitive_desc_t pd)
3155  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
3156  dnnl::prop_kind::backward_data) {}
3157 
3159  memory::desc diff_src_desc() const {
3160  return query_md(query::diff_src_md, 0);
3161  }
3162 
3164  memory::desc diff_dst_desc() const {
3165  return query_md(query::diff_dst_md, 0);
3166  }
3167 
3171  memory::desc workspace_desc() const {
3172  return query_md(query::workspace_md, 0);
3173  }
3174  };
3175 
3176  lrn_backward() = default;
3177 
3178  lrn_backward(const primitive_desc &pd) : primitive(pd) {}
3179 };
3180 
3182 
3189 
3192 struct pooling_forward : public primitive {
3193 
3195  struct desc {
3196  dnnl_pooling_desc_t data;
3197 
3203  desc(prop_kind aprop_kind, algorithm aalgorithm,
3204  const memory::desc &src_desc, const memory::desc &dst_desc,
3205  const memory::dims &strides, const memory::dims &kernel,
3206  const memory::dims &padding_l, const memory::dims &padding_r) {
3207  memory::validate_dims(strides);
3208  memory::validate_dims(kernel);
3209  memory::validate_dims(padding_l);
3210  memory::validate_dims(padding_r);
3212  dnnl::convert_to_c(aprop_kind),
3213  convert_to_c(aalgorithm), &src_desc.data,
3214  &dst_desc.data, &strides[0], &kernel[0],
3215  &padding_l[0], &padding_r[0]),
3216  "could not init a forward pooling descriptor");
3217  }
3218  };
3219 
3221  struct primitive_desc : public dnnl::primitive_desc {
3222  primitive_desc() = default;
3223 
3224  primitive_desc(
3225  const desc &desc, const engine &e, bool allow_empty = false)
3226  : dnnl::primitive_desc(
3227  &desc.data, nullptr, e, nullptr, allow_empty) {}
3228 
3229  primitive_desc(const desc &desc, const primitive_attr &attr,
3230  const engine &e, bool allow_empty = false)
3231  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3232  }
3233 
3236  primitive_desc(dnnl_primitive_desc_t pd)
3237  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
3238  dnnl::prop_kind::forward_training,
3239  dnnl::prop_kind::forward_inference) {}
3240 
3242  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3243 
3245  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3246 
3250  memory::desc workspace_desc() const {
3251  return query_md(query::workspace_md, 0);
3252  }
3253  };
3254 
3255  pooling_forward() = default;
3256 
3257  pooling_forward(const primitive_desc &pd) : primitive(pd) {}
3258 };
3259 
3260 struct pooling_backward : public primitive {
3261 
3263  struct desc {
3264  dnnl_pooling_desc_t data;
3265 
3269  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
3270  const memory::desc &diff_dst_desc, const memory::dims &strides,
3271  const memory::dims &kernel, const memory::dims &padding_l,
3272  const memory::dims &padding_r) {
3273  memory::validate_dims(strides);
3274  memory::validate_dims(kernel);
3275  memory::validate_dims(padding_l);
3276  memory::validate_dims(padding_r);
3279  convert_to_c(aalgorithm), &diff_src_desc.data,
3280  &diff_dst_desc.data, &strides[0], &kernel[0],
3281  &padding_l[0], &padding_r[0]),
3282  "could not init a backward pooling descriptor");
3283  }
3284  };
3285 
3287  struct primitive_desc : public dnnl::primitive_desc {
3288  primitive_desc() = default;
3289 
3290  primitive_desc(const desc &desc, const engine &e,
3291  const pooling_forward::primitive_desc &hint_fwd_pd,
3292  bool allow_empty = false)
3293  : dnnl::primitive_desc(
3294  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3295 
3296  primitive_desc(const desc &desc, const primitive_attr &attr,
3297  const engine &e,
3298  const pooling_forward::primitive_desc &hint_fwd_pd,
3299  bool allow_empty = false)
3300  : dnnl::primitive_desc(
3301  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3302 
3305  primitive_desc(dnnl_primitive_desc_t pd)
3306  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
3307  dnnl::prop_kind::backward_data) {}
3308 
3310  memory::desc diff_src_desc() const {
3311  return query_md(query::diff_src_md, 0);
3312  }
3313 
3315  memory::desc diff_dst_desc() const {
3316  return query_md(query::diff_dst_md, 0);
3317  }
3318 
3322  memory::desc workspace_desc() const {
3323  return query_md(query::workspace_md, 0);
3324  }
3325  };
3326 
3327  pooling_backward() = default;
3328 
3329  pooling_backward(const primitive_desc &pd) : primitive(pd) {}
3330 };
3331 
3333 
3351 
3354 struct eltwise_forward : public primitive {
3355 
3360  struct desc {
3361  dnnl_eltwise_desc_t data;
3362  desc(prop_kind aprop_kind, algorithm aalgorithm,
3363  const memory::desc &src_desc, float alpha = 0, float beta = 0) {
3365  dnnl::convert_to_c(aprop_kind),
3366  dnnl::convert_to_c(aalgorithm),
3367  &src_desc.data, alpha, beta),
3368  "could not create a eltwise forward descriptor");
3369  }
3370  };
3371 
3373  struct primitive_desc : public dnnl::primitive_desc {
3374  primitive_desc() = default;
3375 
3376  primitive_desc(
3377  const desc &desc, const engine &e, bool allow_empty = false)
3378  : dnnl::primitive_desc(
3379  &desc.data, nullptr, e, nullptr, allow_empty) {}
3380 
3381  primitive_desc(const desc &desc, const primitive_attr &attr,
3382  const engine &e, bool allow_empty = false)
3383  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3384  }
3385 
3388  primitive_desc(dnnl_primitive_desc_t pd)
3389  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
3390  dnnl::prop_kind::forward_training,
3391  dnnl::prop_kind::forward_inference) {}
3392 
3394  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3395 
3397  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3398  };
3399 
3400  eltwise_forward() = default;
3401 
3402  eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
3403 };
3404 
3407 struct eltwise_backward : public primitive {
3408 
3412  struct desc {
3413  dnnl_eltwise_desc_t data;
3414 
3415  desc(algorithm aalgorithm, const memory::desc &diff_data_desc,
3416  const memory::desc &data_desc, float alpha = 0,
3417  float beta = 0) {
3420  dnnl::convert_to_c(aalgorithm),
3421  &diff_data_desc.data, &data_desc.data, alpha, beta),
3422  "could not create a eltwise backward descriptor");
3423  }
3424  };
3425 
3427  struct primitive_desc : public dnnl::primitive_desc {
3428  primitive_desc() = default;
3429 
3430  primitive_desc(const desc &desc, const engine &e,
3431  const eltwise_forward::primitive_desc &hint_fwd_pd,
3432  bool allow_empty = false)
3433  : dnnl::primitive_desc(
3434  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3435 
3436  primitive_desc(const desc &desc, const primitive_attr &attr,
3437  const engine &e,
3438  const eltwise_forward::primitive_desc &hint_fwd_pd,
3439  bool allow_empty = false)
3440  : dnnl::primitive_desc(
3441  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3442 
3445  primitive_desc(dnnl_primitive_desc_t pd)
3446  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
3447  dnnl::prop_kind::backward_data) {}
3448 
3450  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3451 
3453  memory::desc diff_src_desc() const {
3454  return query_md(query::diff_src_md, 0);
3455  }
3456 
3458  memory::desc diff_dst_desc() const {
3459  return query_md(query::diff_dst_md, 0);
3460  }
3461  };
3462 
3463  eltwise_backward() = default;
3464 
3465  eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
3466 };
3467 
3469 
3476 
3479 struct softmax_forward : public primitive {
3480 
3482  struct desc {
3483  dnnl_softmax_desc_t data;
3484 
3488  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3489  int softmax_axis) {
3491  dnnl::convert_to_c(aprop_kind),
3492  &data_desc.data, softmax_axis),
3493  "could not create a softmax forward descriptor");
3494  }
3495  };
3496 
3498  struct primitive_desc : public dnnl::primitive_desc {
3499  primitive_desc() = default;
3500 
3501  primitive_desc(
3502  const desc &desc, const engine &e, bool allow_empty = false)
3503  : dnnl::primitive_desc(
3504  &desc.data, nullptr, e, nullptr, allow_empty) {}
3505 
3506  primitive_desc(const desc &desc, const primitive_attr &attr,
3507  const engine &e, bool allow_empty = false)
3508  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3509  }
3510 
3513  primitive_desc(dnnl_primitive_desc_t pd)
3514  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
3515  dnnl::prop_kind::forward_training,
3516  dnnl::prop_kind::forward_inference) {}
3517 
3519  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3520 
3522  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3523  };
3524 
3525  softmax_forward() = default;
3526 
3527  softmax_forward(const primitive_desc &pd) : primitive(pd) {}
3528 };
3529 
3532 struct softmax_backward : public primitive {
3533 
3535  struct desc {
3536  dnnl_softmax_desc_t data;
3537 
3540  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
3541  int softmax_axis) {
3543  dnnl_softmax_backward_desc_init(&data, &diff_desc.data,
3544  &data_desc.data, softmax_axis),
3545  "could not init a backward softmax descriptor");
3546  }
3547  };
3548 
3550  struct primitive_desc : public dnnl::primitive_desc {
3551  primitive_desc() = default;
3552 
3553  primitive_desc(const desc &desc, const engine &e,
3554  const softmax_forward::primitive_desc &hint_fwd_pd,
3555  bool allow_empty = false)
3556  : dnnl::primitive_desc(
3557  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3558 
3559  primitive_desc(const desc &desc, const primitive_attr &attr,
3560  const engine &e,
3561  const softmax_forward::primitive_desc &hint_fwd_pd,
3562  bool allow_empty = false)
3563  : dnnl::primitive_desc(
3564  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3565 
3568  primitive_desc(dnnl_primitive_desc_t pd)
3569  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
3570  dnnl::prop_kind::backward_data) {}
3571 
3573  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3574 
3576  memory::desc diff_src_desc() const {
3577  return query_md(query::diff_src_md, 0);
3578  }
3579 
3581  memory::desc diff_dst_desc() const {
3582  return query_md(query::diff_dst_md, 0);
3583  }
3584  };
3585 
3586  softmax_backward() = default;
3587 
3588  softmax_backward(const primitive_desc &pd) : primitive(pd) {}
3589 };
3590 
3592 
3610 
3613 struct batch_normalization_forward : public primitive {
3614 
3616  struct desc {
3618 
3627  desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon,
3628  normalization_flags flags) {
3631  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3632  epsilon, convert_to_c(flags)),
3633  "could not create a batch normalization forward "
3634  "descriptor");
3635  }
3636  };
3637 
3639  struct primitive_desc : public dnnl::primitive_desc {
3640  primitive_desc() = default;
3641 
3642  primitive_desc(
3643  const desc &desc, const engine &e, bool allow_empty = false)
3644  : dnnl::primitive_desc(
3645  &desc.data, nullptr, e, nullptr, allow_empty) {}
3646 
3647  primitive_desc(const desc &desc, const primitive_attr &attr,
3648  const engine &e, bool allow_empty = false)
3649  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3650  }
3651 
3654  primitive_desc(dnnl_primitive_desc_t pd)
3655  : dnnl::primitive_desc(pd,
3656  dnnl::primitive::kind::batch_normalization,
3657  dnnl::prop_kind::forward_training,
3658  dnnl::prop_kind::forward_inference) {}
3659 
3661  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3662 
3664  memory::desc weights_desc() const {
3665  return query_md(query::weights_md, 0);
3666  }
3667 
3669  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3670 
3674  memory::desc workspace_desc() const {
3675  return query_md(query::workspace_md, 0);
3676  }
3677 
3679  memory::desc mean_desc() const { return stat_desc(mean); }
3680 
3682  memory::desc variance_desc() const { return stat_desc(var); }
3683 
3684  private:
3685  enum {
3686  mean = 1,
3687  var = 2,
3688  };
3689  memory::desc stat_desc(int kind) const {
3693  dnnl::convert_to_c(query::batch_normalization_d), 0,
3694  &p),
3695  "could not get a batch-normalization descriptor");
3696  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
3697  : query::dst_md,
3698  kind);
3699  }
3700  };
3701 
3702  batch_normalization_forward() = default;
3703 
3704  batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
3705 };
3706 
3709 struct batch_normalization_backward : public primitive {
3710 
3712  struct desc {
3714 
3723  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3724  const memory::desc &data_desc, float epsilon,
3725  normalization_flags flags) {
3727  dnnl::convert_to_c(aprop_kind),
3728  &diff_data_desc.data, &data_desc.data,
3729  epsilon, convert_to_c(flags)),
3730  "could not create a batch normalization backward "
3731  "descriptor");
3732  }
3733  };
3734 
3736  struct primitive_desc : public dnnl::primitive_desc {
3737  primitive_desc() = default;
3738 
3739  primitive_desc(const desc &desc, const engine &e,
3740  const batch_normalization_forward::primitive_desc &hint_fwd_pd,
3741  bool allow_empty = false)
3742  : dnnl::primitive_desc(
3743  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3744 
3745  primitive_desc(const desc &desc, const primitive_attr &attr,
3746  const engine &e,
3747  const batch_normalization_forward::primitive_desc &hint_fwd_pd,
3748  bool allow_empty = false)
3749  : dnnl::primitive_desc(
3750  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3751 
3754  primitive_desc(dnnl_primitive_desc_t pd)
3755  : dnnl::primitive_desc(pd,
3756  dnnl::primitive::kind::batch_normalization,
3757  dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
3758  }
3759 
3761  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3762 
3764  memory::desc mean_desc() const { return query_md(query::src_md, 1); }
3765 
3767  memory::desc variance_desc() const {
3768  return query_md(query::src_md, 2);
3769  }
3770 
3772  memory::desc weights_desc() const {
3773  return query_md(query::weights_md, 0);
3774  }
3775 
3777  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3778 
3780  memory::desc diff_dst_desc() const {
3781  return query_md(query::diff_dst_md, 0);
3782  }
3783 
3787  memory::desc workspace_desc() const {
3788  return query_md(query::workspace_md, 0);
3789  }
3790 
3792  memory::desc diff_src_desc() const {
3793  return query_md(query::diff_src_md, 0);
3794  }
3795 
3797  memory::desc diff_weights_desc() const {
3798  return query_md(query::diff_weights_md, 0);
3799  }
3800  };
3801 
3802  batch_normalization_backward() = default;
3803 
3804  batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
3805 };
3806 
3808 
3827 
3830 struct layer_normalization_forward : public primitive {
3831 
3833  struct desc {
3835 
3844  desc(prop_kind aprop_kind, const memory::desc &src_desc,
3845  const memory::desc &stat_desc, float epsilon,
3846  normalization_flags flags) {
3849  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3850  &stat_desc.data, epsilon, convert_to_c(flags)),
3851  "could not create a layer normalization forward "
3852  "descriptor");
3853  }
3854 
3855  desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon,
3856  normalization_flags flags) {
3859  dnnl::convert_to_c(aprop_kind), &src_desc.data,
3860  nullptr, epsilon, convert_to_c(flags)),
3861  "could not create a layer normalization forward "
3862  "descriptor");
3863  }
3864  };
3865 
3867  struct primitive_desc : public dnnl::primitive_desc {
3868  primitive_desc() = default;
3869 
3870  primitive_desc(
3871  const desc &desc, const engine &e, bool allow_empty = false)
3872  : dnnl::primitive_desc(
3873  &desc.data, nullptr, e, nullptr, allow_empty) {}
3874 
3875  primitive_desc(const desc &desc, const primitive_attr &attr,
3876  const engine &e, bool allow_empty = false)
3877  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
3878  }
3879 
3882  primitive_desc(dnnl_primitive_desc_t pd)
3883  : dnnl::primitive_desc(pd,
3884  dnnl::primitive::kind::layer_normalization,
3885  dnnl::prop_kind::forward_training,
3886  dnnl::prop_kind::forward_inference) {}
3887 
3889  memory::desc src_desc() const { return query_md(query::src_md, 0); }
3890 
3892  memory::desc weights_desc() const {
3893  return query_md(query::weights_md, 0);
3894  }
3895 
3897  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
3898 
3900  memory::desc mean_desc() const { return stat_desc(mean); }
3901 
3903  memory::desc variance_desc() const { return stat_desc(var); }
3904 
3908  memory::desc workspace_desc() const {
3909  return query_md(query::workspace_md, 0);
3910  }
3911 
3912  private:
3913  enum {
3914  mean = 1,
3915  var = 2,
3916  };
3917  memory::desc stat_desc(int kind) const {
3921  dnnl::convert_to_c(query::layer_normalization_d), 0,
3922  &p),
3923  "could not get a layer-normalization descriptor");
3924  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
3925  : query::dst_md,
3926  kind);
3927  }
3928  };
3929 
3930  layer_normalization_forward() = default;
3931 
3932  layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
3933 };
3934 
3937 struct layer_normalization_backward : public primitive {
3938 
3940  struct desc {
3942 
3951  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3952  const memory::desc &data_desc, const memory::desc &stat_desc,
3953  float epsilon, normalization_flags flags) {
3956  dnnl::convert_to_c(aprop_kind),
3957  &diff_data_desc.data, &data_desc.data,
3958  &stat_desc.data, epsilon, convert_to_c(flags)),
3959  "could not create a layer normalization backward "
3960  "descriptor");
3961  }
3962 
3963  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
3964  const memory::desc &data_desc, float epsilon,
3965  normalization_flags flags) {
3967  dnnl::convert_to_c(aprop_kind),
3968  &diff_data_desc.data, &data_desc.data,
3969  nullptr, epsilon, convert_to_c(flags)),
3970  "could not create a layer normalization backward "
3971  "descriptor");
3972  }
3973  };
3974 
3976  struct primitive_desc : public dnnl::primitive_desc {
3977  primitive_desc() = default;
3978 
3979  primitive_desc(const desc &desc, const engine &e,
3980  const layer_normalization_forward::primitive_desc &hint_fwd_pd,
3981  bool allow_empty = false)
3982  : dnnl::primitive_desc(
3983  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
3984 
3985  primitive_desc(const desc &desc, const primitive_attr &attr,
3986  const engine &e,
3987  const layer_normalization_forward::primitive_desc &hint_fwd_pd,
3988  bool allow_empty = false)
3989  : dnnl::primitive_desc(
3990  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
3991 
3994  primitive_desc(dnnl_primitive_desc_t pd)
3995  : dnnl::primitive_desc(pd,
3996  dnnl::primitive::kind::layer_normalization,
3997  dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
3998  }
3999 
4001  memory::desc src_desc() const { return query_md(query::src_md, 0); }
4002 
4004  memory::desc mean_desc() const { return query_md(query::src_md, 1); }
4005 
4007  memory::desc variance_desc() const {
4008  return query_md(query::src_md, 2);
4009  }
4010 
4012  memory::desc weights_desc() const {
4013  return query_md(query::weights_md, 0);
4014  }
4015 
4017  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
4018 
4020  memory::desc diff_dst_desc() const {
4021  return query_md(query::diff_dst_md, 0);
4022  }
4023 
4025  memory::desc diff_src_desc() const {
4026  return query_md(query::diff_src_md, 0);
4027  }
4028 
4030  memory::desc diff_weights_desc() const {
4031  return query_md(query::diff_weights_md, 0);
4032  }
4033 
4037  memory::desc workspace_desc() const {
4038  return query_md(query::workspace_md, 0);
4039  }
4040  };
4041 
4042  layer_normalization_backward() = default;
4043 
4044  layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
4045 };
4046 
4048 
4055 
4058 struct inner_product_forward : public primitive {
4059 
4069  struct desc {
4071  desc(prop_kind aprop_kind, const memory::desc &src_desc,
4072  const memory::desc &weights_desc, const memory::desc &bias_desc,
4073  const memory::desc &dst_desc) {
4075  dnnl::convert_to_c(aprop_kind),
4076  &src_desc.data, &weights_desc.data,
4077  &bias_desc.data, &dst_desc.data),
4078  "could not create a inner product forward descriptor");
4079  }
4080 
4081  desc(prop_kind aprop_kind, const memory::desc &src_desc,
4082  const memory::desc &weights_desc,
4083  const memory::desc &dst_desc) {
4086  dnnl::convert_to_c(aprop_kind), &src_desc.data,
4087  &weights_desc.data, nullptr, &dst_desc.data),
4088  "could not create a inner product forward descriptor");
4089  }
4090  };
4091 
4093  struct primitive_desc : public dnnl::primitive_desc {
4094  primitive_desc() = default;
4095 
4096  primitive_desc(
4097  const desc &desc, const engine &e, bool allow_empty = false)
4098  : dnnl::primitive_desc(
4099  &desc.data, nullptr, e, nullptr, allow_empty) {}
4100 
4101  primitive_desc(const desc &desc, const primitive_attr &attr,
4102  const engine &e, bool allow_empty = false)
4103  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr, allow_empty) {
4104  }
4105 
4108  primitive_desc(dnnl_primitive_desc_t pd)
4109  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
4110  dnnl::prop_kind::forward_training,
4111  dnnl::prop_kind::forward_inference) {}
4112 
4114  memory::desc src_desc() const { return query_md(query::src_md, 0); }
4115 
4117  memory::desc weights_desc() const {
4118  return query_md(query::weights_md, 0);
4119  }
4120 
4125  memory::desc bias_desc() const {
4126  return query_md(query::weights_md, 1);
4127  }
4128 
4130  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
4131  };
4132 
4133  inner_product_forward() = default;
4134 
4135  inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
4136 };
4137 
4140 struct inner_product_backward_data : public primitive {
4141 
4147  struct desc {
4149  desc(const memory::desc &diff_src_desc,
4150  const memory::desc &weights_desc,
4151  const memory::desc &diff_dst_desc) {
4153  &diff_src_desc.data, &weights_desc.data,
4154  &diff_dst_desc.data),
4155  "could not create a inner product backward data "
4156  "descriptor");
4157  }
4158  };
4159 
4162  struct primitive_desc : public dnnl::primitive_desc {
4163  primitive_desc() = default;
4164 
4165  primitive_desc(const desc &desc, const engine &e,
4166  const inner_product_forward::primitive_desc &hint_fwd_pd,
4167  bool allow_empty = false)
4168  : dnnl::primitive_desc(
4169  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4170 
4171  primitive_desc(const desc &desc, const primitive_attr &attr,
4172  const engine &e,
4173  const inner_product_forward::primitive_desc &hint_fwd_pd,
4174  bool allow_empty = false)
4175  : dnnl::primitive_desc(
4176  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4177 
4180  primitive_desc(dnnl_primitive_desc_t pd)
4181  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
4182  dnnl::prop_kind::backward_data) {}
4183 
4185  memory::desc diff_src_desc() const {
4186  return query_md(query::diff_src_md, 0);
4187  }
4188 
4190  memory::desc weights_desc() const {
4191  return query_md(query::weights_md, 0);
4192  }
4193 
4195  memory::desc diff_dst_desc() const {
4196  return query_md(query::diff_dst_md, 0);
4197  }
4198  };
4199 
4200  inner_product_backward_data() = default;
4201 
4202  inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
4203 };
4204 
4207 struct inner_product_backward_weights : public primitive {
4208 
4214  struct desc {
4216  desc(const memory::desc &src_desc,
4217  const memory::desc &diff_weights_desc,
4218  const memory::desc &diff_bias_desc,
4219  const memory::desc &diff_dst_desc) {
4222  &src_desc.data, &diff_weights_desc.data,
4223  &diff_bias_desc.data, &diff_dst_desc.data),
4224  "could not create a inner product backward weights "
4225  "descriptor");
4226  }
4227  desc(const memory::desc &src_desc,
4228  const memory::desc &diff_weights_desc,
4229  const memory::desc &diff_dst_desc) {
4232  &src_desc.data, &diff_weights_desc.data, nullptr,
4233  &diff_dst_desc.data),
4234  "could not create a inner product backward weights "
4235  "descriptor");
4236  }
4237  };
4238 
4241  struct primitive_desc : public dnnl::primitive_desc {
4242  primitive_desc() = default;
4243 
4244  primitive_desc(const desc &desc, const engine &e,
4245  const inner_product_forward::primitive_desc &hint_fwd_pd,
4246  bool allow_empty = false)
4247  : dnnl::primitive_desc(
4248  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4249 
4250  primitive_desc(const desc &desc, const primitive_attr &attr,
4251  const engine &e,
4252  const inner_product_forward::primitive_desc &hint_fwd_pd,
4253  bool allow_empty = false)
4254  : dnnl::primitive_desc(
4255  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4256 
4259  primitive_desc(dnnl_primitive_desc_t cpd)
4260  : dnnl::primitive_desc(cpd, dnnl::primitive::kind::inner_product,
4261  dnnl::prop_kind::backward_weights) {}
4262 
4264  memory::desc src_desc() const { return query_md(query::src_md, 0); }
4265 
4267  memory::desc diff_weights_desc() const {
4268  return query_md(query::diff_weights_md, 0);
4269  }
4270 
4272  memory::desc diff_bias_desc() const {
4273  return query_md(query::diff_weights_md, 1);
4274  }
4275 
4277  memory::desc diff_dst_desc() const {
4278  return query_md(query::diff_dst_md, 0);
4279  }
4280  };
4281 
4282  inner_product_backward_weights() = default;
4283 
4284  inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
4285 };
4286 
4288 
4295 
4296 struct rnn_primitive_desc_base : public primitive_desc {
4297  using primitive_desc::primitive_desc;
4298 
4299  rnn_primitive_desc_base() = default;
4300 
4301 protected:
4302  // Constructs an RNN primitive descriptor from a C counterpart while
4303  // checking that it actually describes the expected primitive.
4304  rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
4305  dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
4306  dnnl::algorithm cell_kind) {
4308  dnnl_status_t rc;
4311  rc, "could not retrieve rnn_desc from a primitive descriptor");
4312 
4313  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
4314  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
4315  dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
4316 
4317  bool ok = rnn_d->primitive_kind == dnnl_rnn
4318  && (rnn_d->prop_kind == c_prop_kind1
4319  || rnn_d->prop_kind == c_prop_kind2)
4320  && rnn_d->cell_kind == c_cell_kind;
4321 
4322  if (!ok) throw error(dnnl_invalid_arguments, "rnn descriptor mismatch");
4323 
4324  reset_with_clone(pd);
4325  }
4326 
4327  // Constructs an RNN primitive descriptor from a C counterpart while
4328  // checking that it actually describes the expected primitive.
4329  rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind prop_kind,
4330  dnnl::algorithm cell_kind)
4331  : rnn_primitive_desc_base(pd, prop_kind, prop_kind, cell_kind) {}
4332 };
4333 
4337 struct vanilla_rnn_forward : public primitive {
4338 
4340  struct desc {
4341  dnnl_rnn_desc_t data;
4342 
4361  desc(prop_kind aprop_kind, algorithm activation,
4362  rnn_direction direction, const memory::desc &src_layer_desc,
4363  const memory::desc &src_iter_desc,
4364  const memory::desc &weights_layer_desc,
4365  const memory::desc &weights_iter_desc,
4366  const memory::desc &bias_desc,
4367  const memory::desc &dst_layer_desc,
4368  const memory::desc &dst_iter_desc,
4369  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
4370  float beta = 0.0f) {
4373  dnnl::convert_to_c(aprop_kind),
4374  dnnl::convert_to_c(activation),
4375  dnnl::convert_to_c(direction), &src_layer_desc.data,
4376  &src_iter_desc.data, &weights_layer_desc.data,
4377  &weights_iter_desc.data, &bias_desc.data,
4378  &dst_layer_desc.data, &dst_iter_desc.data,
4379  dnnl::convert_to_c(flags), alpha, beta),
4380  "could not create an RNN forward descriptor");
4381  }
4382  };
4383 
4385  struct primitive_desc : public rnn_primitive_desc_base {
4386  primitive_desc() = default;
4387 
4388  primitive_desc(
4389  const desc &desc, const engine &e, bool allow_empty = false)
4390  : rnn_primitive_desc_base(
4391  &desc.data, nullptr, e, nullptr, allow_empty) {}
4392 
4393  primitive_desc(const desc &desc, const primitive_attr &attr,
4394  const engine &e, bool allow_empty = false)
4395  : rnn_primitive_desc_base(
4396  &desc.data, &attr, e, nullptr, allow_empty) {}
4397 
4400  primitive_desc(dnnl_primitive_desc_t pd)
4401  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
4403  dnnl::algorithm::vanilla_rnn) {}
4404 
4406  memory::desc src_layer_desc() const {
4407  return query_md(query::src_md, 0);
4408  }
4409 
4414  memory::desc src_iter_desc() const {
4415  return query_md(query::src_md, 1);
4416  }
4417 
4419  memory::desc weights_layer_desc() const {
4420  return query_md(query::weights_md, 0);
4421  }
4422 
4424  memory::desc weights_iter_desc() const {
4425  return query_md(query::weights_md, 1);
4426  }
4427 
4432  memory::desc bias_desc() const {
4433  return query_md(query::weights_md, 2);
4434  }
4435 
4437  memory::desc dst_layer_desc() const {
4438  return query_md(query::dst_md, 0);
4439  }
4440 
4445  memory::desc dst_iter_desc() const {
4446  return query_md(query::dst_md, 1);
4447  }
4448 
4452  memory::desc workspace_desc() const {
4453  return query_md(query::workspace_md, 0);
4454  }
4455  };
4456 
4457  vanilla_rnn_forward() = default;
4458 
4459  vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
4460 };
4461 
4465 struct vanilla_rnn_backward : public primitive {
4466 
4468  struct desc {
4469  dnnl_rnn_desc_t data;
4470 
4488  desc(prop_kind aprop_kind, algorithm activation,
4489  rnn_direction direction, const memory::desc &src_layer_desc,
4490  const memory::desc &src_iter_desc,
4491  const memory::desc &weights_layer_desc,
4492  const memory::desc &weights_iter_desc,
4493  const memory::desc &bias_desc,
4494  const memory::desc &dst_layer_desc,
4495  const memory::desc &dst_iter_desc,
4496  const memory::desc &diff_src_layer_desc,
4497  const memory::desc &diff_src_iter_desc,
4498  const memory::desc &diff_weights_layer_desc,
4499  const memory::desc &diff_weights_iter_desc,
4500  const memory::desc &diff_bias_desc,
4501  const memory::desc &diff_dst_layer_desc,
4502  const memory::desc &diff_dst_iter_desc,
4503  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
4504  float beta = 0.0f) {
4507  dnnl::convert_to_c(aprop_kind),
4508  dnnl::convert_to_c(activation),
4509  dnnl::convert_to_c(direction), &src_layer_desc.data,
4510  &src_iter_desc.data, &weights_layer_desc.data,
4511  &weights_iter_desc.data, &bias_desc.data,
4512  &dst_layer_desc.data, &dst_iter_desc.data,
4513  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
4514  &diff_weights_layer_desc.data,
4515  &diff_weights_iter_desc.data, &diff_bias_desc.data,
4516  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
4517  dnnl::convert_to_c(flags), alpha, beta),
4518  "could not create an RNN backward descriptor");
4519  }
4520  };
4521 
4523  struct primitive_desc : public rnn_primitive_desc_base {
4524  primitive_desc() = default;
4525 
4526  primitive_desc(const desc &desc, const engine &e,
4527  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
4528  bool allow_empty = false)
4529  : rnn_primitive_desc_base(
4530  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4531 
4532  primitive_desc(const desc &desc, const primitive_attr &attr,
4533  const engine &e,
4534  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
4535  bool allow_empty = false)
4536  : rnn_primitive_desc_base(
4537  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4538 
4541  primitive_desc(dnnl_primitive_desc_t pd)
4542  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
4543  dnnl::algorithm::vanilla_rnn) {}
4544 
4546  memory::desc src_layer_desc() const {
4547  return query_md(query::src_md, 0);
4548  }
4549 
4554  memory::desc src_iter_desc() const {
4555  return query_md(query::src_md, 1);
4556  }
4557 
4559  memory::desc weights_layer_desc() const {
4560  return query_md(query::weights_md, 0);
4561  }
4562 
4564  memory::desc weights_iter_desc() const {
4565  return query_md(query::weights_md, 1);
4566  }
4567 
4572  memory::desc bias_desc() const {
4573  return query_md(query::weights_md, 2);
4574  }
4575 
4577  memory::desc dst_layer_desc() const {
4578  return query_md(query::dst_md, 0);
4579  }
4580 
4585  memory::desc dst_iter_desc() const {
4586  return query_md(query::dst_md, 1);
4587  }
4588 
4592  memory::desc workspace_desc() const {
4593  return query_md(query::workspace_md, 0);
4594  }
4595 
4597  memory::desc diff_src_layer_desc() const {
4598  return query_md(query::diff_src_md, 0);
4599  }
4600 
4605  memory::desc diff_src_iter_desc() const {
4606  return query_md(query::diff_src_md, 1);
4607  }
4608 
4610  memory::desc diff_weights_layer_desc() const {
4611  return query_md(query::diff_weights_md, 0);
4612  }
4613 
4615  memory::desc diff_weights_iter_desc() const {
4616  return query_md(query::diff_weights_md, 1);
4617  }
4618 
4620  memory::desc diff_bias_desc() const {
4621  return query_md(query::diff_weights_md, 2);
4622  }
4623 
4625  memory::desc diff_dst_layer_desc() const {
4626  return query_md(query::diff_dst_md, 0);
4627  }
4628 
4633  memory::desc diff_dst_iter_desc() const {
4634  return query_md(query::diff_dst_md, 1);
4635  }
4636  };
4637 
4638  vanilla_rnn_backward() = default;
4639 
4640  vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
4641 };
4642 
4646 struct lstm_forward : public primitive {
4647 
4649  struct desc {
4650  dnnl_rnn_desc_t data;
4651 
4669  desc(prop_kind aprop_kind, rnn_direction direction,
4670  const memory::desc &src_layer_desc,
4671  const memory::desc &src_iter_desc,
4672  const memory::desc &src_iter_c_desc,
4673  const memory::desc &weights_layer_desc,
4674  const memory::desc &weights_iter_desc,
4675  const memory::desc &bias_desc,
4676  const memory::desc &dst_layer_desc,
4677  const memory::desc &dst_iter_desc,
4678  const memory::desc &dst_iter_c_desc,
4679  rnn_flags flags = rnn_flags::undef) {
4682  dnnl::convert_to_c(aprop_kind),
4683  dnnl::convert_to_c(direction), &src_layer_desc.data,
4684  &src_iter_desc.data, &src_iter_c_desc.data,
4685  &weights_layer_desc.data, &weights_iter_desc.data,
4686  &bias_desc.data, &dst_layer_desc.data,
4687  &dst_iter_desc.data, &dst_iter_c_desc.data,
4688  dnnl::convert_to_c(flags)),
4689  "could not create an LSTM forward descriptor");
4690  }
4691  };
4692 
4694  struct primitive_desc : public rnn_primitive_desc_base {
4695  primitive_desc() = default;
4696 
4697  primitive_desc(
4698  const desc &desc, const engine &e, bool allow_empty = false)
4699  : rnn_primitive_desc_base(
4700  &desc.data, nullptr, e, nullptr, allow_empty) {}
4701 
4702  primitive_desc(const desc &desc, const primitive_attr &attr,
4703  const engine &e, bool allow_empty = false)
4704  : rnn_primitive_desc_base(
4705  &desc.data, &attr, e, nullptr, allow_empty) {}
4706 
4709  primitive_desc(dnnl_primitive_desc_t pd)
4710  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
4712  dnnl::algorithm::vanilla_lstm) {}
4713 
4715  memory::desc src_layer_desc() const {
4716  return query_md(query::src_md, 0);
4717  }
4718 
4723  memory::desc src_iter_desc() const {
4724  return query_md(query::src_md, 1);
4725  }
4726 
4728  memory::desc src_iter_c_desc() const {
4729  return query_md(query::src_md, 2);
4730  }
4731 
4733  memory::desc weights_layer_desc() const {
4734  return query_md(query::weights_md, 0);
4735  }
4736 
4738  memory::desc weights_iter_desc() const {
4739  return query_md(query::weights_md, 1);
4740  }
4741 
4746  memory::desc bias_desc() const {
4747  return query_md(query::weights_md, 2);
4748  }
4749 
4751  memory::desc dst_layer_desc() const {
4752  return query_md(query::dst_md, 0);
4753  }
4754 
4759  memory::desc dst_iter_desc() const {
4760  return query_md(query::dst_md, 1);
4761  }
4762 
4764  memory::desc dst_iter_c_desc() const {
4765  return query_md(query::dst_md, 2);
4766  }
4767 
4771  memory::desc workspace_desc() const {
4772  return query_md(query::workspace_md, 0);
4773  }
4774  };
4775 
4776  lstm_forward() = default;
4777 
4778  lstm_forward(const primitive_desc &pd) : primitive(pd) {}
4779 };
4780 
4784 struct lstm_backward : public primitive {
4785 
4787  struct desc {
4788  dnnl_rnn_desc_t data;
4789 
4808  desc(prop_kind aprop_kind, rnn_direction direction,
4809  const memory::desc &src_layer_desc,
4810  const memory::desc &src_iter_desc,
4811  const memory::desc &src_iter_c_desc,
4812  const memory::desc &weights_layer_desc,
4813  const memory::desc &weights_iter_desc,
4814  const memory::desc &bias_desc,
4815  const memory::desc &dst_layer_desc,
4816  const memory::desc &dst_iter_desc,
4817  const memory::desc &dst_iter_c_desc,
4818  const memory::desc &diff_src_layer_desc,
4819  const memory::desc &diff_src_iter_desc,
4820  const memory::desc &diff_src_iter_c_desc,
4821  const memory::desc &diff_weights_layer_desc,
4822  const memory::desc &diff_weights_iter_desc,
4823  const memory::desc &diff_bias_desc,
4824  const memory::desc &diff_dst_layer_desc,
4825  const memory::desc &diff_dst_iter_desc,
4826  const memory::desc &diff_dst_iter_c_desc,
4827  rnn_flags flags = rnn_flags::undef) {
4830  dnnl::convert_to_c(aprop_kind),
4831  dnnl::convert_to_c(direction), &src_layer_desc.data,
4832  &src_iter_desc.data, &src_iter_c_desc.data,
4833  &weights_layer_desc.data, &weights_iter_desc.data,
4834  &bias_desc.data, &dst_layer_desc.data,
4835  &dst_iter_desc.data, &dst_iter_c_desc.data,
4836  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
4837  &diff_src_iter_c_desc.data,
4838  &diff_weights_layer_desc.data,
4839  &diff_weights_iter_desc.data, &diff_bias_desc.data,
4840  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
4841  &diff_dst_iter_c_desc.data,
4842  dnnl::convert_to_c(flags)),
4843  "could not create an LSTM backward descriptor");
4844  }
4845  };
4846 
4848  struct primitive_desc : public rnn_primitive_desc_base {
4849  primitive_desc() = default;
4850 
4851  primitive_desc(const desc &desc, const engine &e,
4852  const lstm_forward::primitive_desc &hint_fwd_pd,
4853  bool allow_empty = false)
4854  : rnn_primitive_desc_base(
4855  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
4856 
4857  primitive_desc(const desc &desc, const primitive_attr &attr,
4858  const engine &e,
4859  const lstm_forward::primitive_desc &hint_fwd_pd,
4860  bool allow_empty = false)
4861  : rnn_primitive_desc_base(
4862  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
4863 
4866  primitive_desc(dnnl_primitive_desc_t pd)
4867  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
4868  dnnl::algorithm::vanilla_lstm) {}
4869 
4871  memory::desc src_layer_desc() const {
4872  return query_md(query::src_md, 0);
4873  }
4874 
4879  memory::desc src_iter_desc() const {
4880  return query_md(query::src_md, 1);
4881  }
4882 
4884  memory::desc src_iter_c_desc() const {
4885  return query_md(query::src_md, 2);
4886  }
4887 
4889  memory::desc weights_layer_desc() const {
4890  return query_md(query::weights_md, 0);
4891  }
4892 
4894  memory::desc weights_iter_desc() const {
4895  return query_md(query::weights_md, 1);
4896  }
4897 
4902  memory::desc bias_desc() const {
4903  return query_md(query::weights_md, 2);
4904  }
4905 
4907  memory::desc dst_layer_desc() const {
4908  return query_md(query::dst_md, 0);
4909  }
4910 
4915  memory::desc dst_iter_desc() const {
4916  return query_md(query::dst_md, 1);
4917  }
4918 
4920  memory::desc dst_iter_c_desc() const {
4921  return query_md(query::dst_md, 2);
4922  }
4923 
4927  memory::desc workspace_desc() const {
4928  return query_md(query::workspace_md, 0);
4929  }
4930 
4932  memory::desc diff_src_layer_desc() const {
4933  return query_md(query::diff_src_md, 0);
4934  }
4935 
4940  memory::desc diff_src_iter_desc() const {
4941  return query_md(query::diff_src_md, 1);
4942  }
4943 
4945  memory::desc diff_src_iter_c_desc() const {
4946  return query_md(query::diff_src_md, 2);
4947  }
4948 
4950  memory::desc diff_weights_layer_desc() const {
4951  return query_md(query::diff_weights_md, 0);
4952  }
4953 
4955  memory::desc diff_weights_iter_desc() const {
4956  return query_md(query::diff_weights_md, 1);
4957  }
4958 
4960  memory::desc diff_bias_desc() const {
4961  return query_md(query::diff_weights_md, 2);
4962  }
4963 
4965  memory::desc diff_dst_layer_desc() const {
4966  return query_md(query::diff_dst_md, 0);
4967  }
4968 
4973  memory::desc diff_dst_iter_desc() const {
4974  return query_md(query::diff_dst_md, 1);
4975  }
4976 
4978  memory::desc diff_dst_iter_c_desc() const {
4979  return query_md(query::diff_dst_md, 2);
4980  }
4981  };
4982 
4983  lstm_backward() = default;
4984 
4985  // With last iteration (with and without input src_iter)
4986  lstm_backward(const primitive_desc &pd) : primitive(pd) {}
4987 };
4988 
4992 struct gru_forward : public primitive {
4993 
4995  struct desc {
4996  dnnl_rnn_desc_t data;
4997 
5015  desc(prop_kind aprop_kind, rnn_direction direction,
5016  const memory::desc &src_layer_desc,
5017  const memory::desc &src_iter_desc,
5018  const memory::desc &weights_layer_desc,
5019  const memory::desc &weights_iter_desc,
5020  const memory::desc &bias_desc,
5021  const memory::desc &dst_layer_desc,
5022  const memory::desc &dst_iter_desc,
5023  rnn_flags flags = rnn_flags::undef) {
5026  dnnl::convert_to_c(aprop_kind),
5027  dnnl::convert_to_c(direction), &src_layer_desc.data,
5028  &src_iter_desc.data, &weights_layer_desc.data,
5029  &weights_iter_desc.data, &bias_desc.data,
5030  &dst_layer_desc.data, &dst_iter_desc.data,
5031  dnnl::convert_to_c(flags)),
5032  "could not create a GRU forward descriptor");
5033  }
5034  };
5035 
5037  struct primitive_desc : public rnn_primitive_desc_base {
5038  primitive_desc() = default;
5039 
5040  primitive_desc(
5041  const desc &desc, const engine &e, bool allow_empty = false)
5042  : rnn_primitive_desc_base(
5043  &desc.data, nullptr, e, nullptr, allow_empty) {}
5044 
5045  primitive_desc(const desc &desc, const primitive_attr &attr,
5046  const engine &e, bool allow_empty = false)
5047  : rnn_primitive_desc_base(
5048  &desc.data, &attr, e, nullptr, allow_empty) {}
5049 
5052  primitive_desc(dnnl_primitive_desc_t pd)
5053  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
5055  dnnl::algorithm::vanilla_gru) {}
5056 
5058  memory::desc src_layer_desc() const {
5059  return query_md(query::src_md, 0);
5060  }
5061 
5066  memory::desc src_iter_desc() const {
5067  return query_md(query::src_md, 1);
5068  }
5069 
5071  memory::desc weights_layer_desc() const {
5072  return query_md(query::weights_md, 0);
5073  }
5074 
5076  memory::desc weights_iter_desc() const {
5077  return query_md(query::weights_md, 1);
5078  }
5079 
5084  memory::desc bias_desc() const {
5085  return query_md(query::weights_md, 2);
5086  }
5087 
5089  memory::desc dst_layer_desc() const {
5090  return query_md(query::dst_md, 0);
5091  }
5092 
5097  memory::desc dst_iter_desc() const {
5098  return query_md(query::dst_md, 1);
5099  }
5100 
5104  memory::desc workspace_desc() const {
5105  return query_md(query::workspace_md, 0);
5106  }
5107  };
5108 
5109  gru_forward() = default;
5110 
5111  gru_forward(const primitive_desc &pd) : primitive(pd) {}
5112 };
5113 
5117 struct gru_backward : public primitive {
5118 
5120  struct desc {
5121  dnnl_rnn_desc_t data;
5122 
5138  desc(prop_kind aprop_kind, rnn_direction direction,
5139  const memory::desc &src_layer_desc,
5140  const memory::desc &src_iter_desc,
5141  const memory::desc &weights_layer_desc,
5142  const memory::desc &weights_iter_desc,
5143  const memory::desc &bias_desc,
5144  const memory::desc &dst_layer_desc,
5145  const memory::desc &dst_iter_desc,
5146  const memory::desc &diff_src_layer_desc,
5147  const memory::desc &diff_src_iter_desc,
5148  const memory::desc &diff_weights_layer_desc,
5149  const memory::desc &diff_weights_iter_desc,
5150  const memory::desc &diff_bias_desc,
5151  const memory::desc &diff_dst_layer_desc,
5152  const memory::desc &diff_dst_iter_desc,
5153  rnn_flags flags = rnn_flags::undef) {
5156  dnnl::convert_to_c(aprop_kind),
5157  dnnl::convert_to_c(direction), &src_layer_desc.data,
5158  &src_iter_desc.data, &weights_layer_desc.data,
5159  &weights_iter_desc.data, &bias_desc.data,
5160  &dst_layer_desc.data, &dst_iter_desc.data,
5161  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
5162  &diff_weights_layer_desc.data,
5163  &diff_weights_iter_desc.data, &diff_bias_desc.data,
5164  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
5165  dnnl::convert_to_c(flags)),
5166  "could not create an GRU backward descriptor");
5167  }
5168  };
5169 
5171  struct primitive_desc : public rnn_primitive_desc_base {
5172  primitive_desc() = default;
5173 
5174  primitive_desc(const desc &desc, const engine &e,
5175  const gru_forward::primitive_desc &hint_fwd_pd,
5176  bool allow_empty = false)
5177  : rnn_primitive_desc_base(
5178  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
5179 
5180  primitive_desc(const desc &desc, const primitive_attr &attr,
5181  const engine &e, const gru_forward::primitive_desc &hint_fwd_pd,
5182  bool allow_empty = false)
5183  : rnn_primitive_desc_base(
5184  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
5185 
5188  primitive_desc(dnnl_primitive_desc_t pd)
5189  : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
5190  dnnl::algorithm::vanilla_gru) {}
5191 
5193  memory::desc src_layer_desc() const {
5194  return query_md(query::src_md, 0);
5195  }
5196 
5201  memory::desc src_iter_desc() const {
5202  return query_md(query::src_md, 1);
5203  }
5204 
5206  memory::desc weights_layer_desc() const {
5207  return query_md(query::weights_md, 0);
5208  }
5209 
5211  memory::desc weights_iter_desc() const {
5212  return query_md(query::weights_md, 1);
5213  }
5214 
5219  memory::desc bias_desc() const {
5220  return query_md(query::weights_md, 2);
5221  }
5222 
5224  memory::desc dst_layer_desc() const {
5225  return query_md(query::dst_md, 0);
5226  }
5227 
5232  memory::desc dst_iter_desc() const {
5233  return query_md(query::dst_md, 1);
5234  }
5235 
5239  memory::desc workspace_desc() const {
5240  return query_md(query::workspace_md, 0);
5241  }
5242 
5244  memory::desc diff_src_layer_desc() const {
5245  return query_md(query::diff_src_md, 0);
5246  }
5247 
5252  memory::desc diff_src_iter_desc() const {
5253  return query_md(query::diff_src_md, 1);
5254  }
5255 
5257  memory::desc diff_weights_layer_desc() const {
5258  return query_md(query::diff_weights_md, 0);
5259  }
5260 
5262  memory::desc diff_weights_iter_desc() const {
5263  return query_md(query::diff_weights_md, 1);
5264  }
5265 
5267  memory::desc diff_bias_desc() const {
5268  return query_md(query::diff_weights_md, 2);
5269  }
5270 
5272  memory::desc diff_dst_layer_desc() const {
5273  return query_md(query::diff_dst_md, 0);
5274  }
5275 
5280  memory::desc diff_dst_iter_desc() const {
5281  return query_md(query::diff_dst_md, 1);
5282  }
5283  };
5284 
5285  gru_backward() = default;
5286 
5287  // With last iteration (with and without input src_iter)
5288  gru_backward(const primitive_desc &pd) : primitive(pd) {}
5289 };
5290 
5294 struct lbr_gru_forward : public primitive {
5295 
5297  struct desc {
5298  dnnl_rnn_desc_t data;
5299 
5317  desc(prop_kind aprop_kind, rnn_direction direction,
5318  const memory::desc &src_layer_desc,
5319  const memory::desc &src_iter_desc,
5320  const memory::desc &weights_layer_desc,
5321  const memory::desc &weights_iter_desc,
5322  const memory::desc &bias_desc,
5323  const memory::desc &dst_layer_desc,
5324  const memory::desc &dst_iter_desc,
5325  rnn_flags flags = rnn_flags::undef) {
5328  dnnl::convert_to_c(aprop_kind),
5329  dnnl::convert_to_c(direction), &src_layer_desc.data,
5330  &src_iter_desc.data, &weights_layer_desc.data,
5331  &weights_iter_desc.data, &bias_desc.data,
5332  &dst_layer_desc.data, &dst_iter_desc.data,
5333  dnnl::convert_to_c(flags)),
5334  "could not create a Linear-before-reset GRU forward "
5335  "descriptor");
5336  }
5337  };
5338 
5340  struct primitive_desc : public rnn_primitive_desc_base {
5341  primitive_desc() = default;
5342 
5343  primitive_desc(
5344  const desc &desc, const engine &e, bool allow_empty = false)
5345  : rnn_primitive_desc_base(
5346  &desc.data, nullptr, e, nullptr, allow_empty) {}
5347 
5348  primitive_desc(const desc &desc, const primitive_attr &attr,
5349  const engine &e, bool allow_empty = false)
5350  : rnn_primitive_desc_base(
5351  &desc.data, &attr, e, nullptr, allow_empty) {}
5352 
5355  primitive_desc(dnnl_primitive_desc_t pd)
5356  : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
5358  dnnl::algorithm::lbr_gru) {}
5359 
5361  memory::desc src_layer_desc() const {
5362  return query_md(query::src_md, 0);
5363  }
5364 
5369  memory::desc src_iter_desc() const {
5370  return query_md(query::src_md, 1);
5371  }
5372 
5374  memory::desc weights_layer_desc() const {
5375  return query_md(query::weights_md, 0);
5376  }
5377 
5379  memory::desc weights_iter_desc() const {
5380  return query_md(query::weights_md, 1);
5381  }
5382 
5387  memory::desc bias_desc() const {
5388  return query_md(query::weights_md, 2);
5389  }
5390 
5392  memory::desc dst_layer_desc() const {
5393  return query_md(query::dst_md, 0);
5394  }
5395 
5400  memory::desc dst_iter_desc() const {
5401  return query_md(query::dst_md, 1);
5402  }
5403 
5407  memory::desc workspace_desc() const {
5408  return query_md(query::workspace_md, 0);
5409  }
5410  };
5411 
5412  lbr_gru_forward() = default;
5413 
5414  lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
5415 };
5416 
5420 struct lbr_gru_backward : public primitive {
5421 
5423  struct desc {
5424  dnnl_rnn_desc_t data;
5425 
5441  desc(prop_kind aprop_kind, rnn_direction direction,
5442  const memory::desc &src_layer_desc,
5443  const memory::desc &src_iter_desc,
5444  const memory::desc &weights_layer_desc,
5445  const memory::desc &weights_iter_desc,
5446  const memory::desc &bias_desc,
5447  const memory::desc &dst_layer_desc,
5448  const memory::desc &dst_iter_desc,
5449  const memory::desc &diff_src_layer_desc,
5450  const memory::desc &diff_src_iter_desc,
5451  const memory::desc &diff_weights_layer_desc,
5452  const memory::desc &diff_weights_iter_desc,
5453  const memory::desc &diff_bias_desc,
5454  const memory::desc &diff_dst_layer_desc,
5455  const memory::desc &diff_dst_iter_desc,
5456  rnn_flags flags = rnn_flags::undef) {
5459  dnnl::convert_to_c(aprop_kind),
5460  dnnl::convert_to_c(direction), &src_layer_desc.data,
5461  &src_iter_desc.data, &weights_layer_desc.data,
5462  &weights_iter_desc.data, &bias_desc.data,
5463  &dst_layer_desc.data, &dst_iter_desc.data,
5464  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
5465  &diff_weights_layer_desc.data,
5466  &diff_weights_iter_desc.data, &diff_bias_desc.data,
5467  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
5468  dnnl::convert_to_c(flags)),
5469  "could not create an LBR_GRU backward descriptor");
5470  }
5471  };
5472 
5474  struct primitive_desc : public rnn_primitive_desc_base {
5475  primitive_desc() = default;
5476 
5477  primitive_desc(const desc &desc, const engine &e,
5478  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
5479  bool allow_empty = false)
5480  : rnn_primitive_desc_base(
5481  &desc.data, nullptr, e, hint_fwd_pd.get(), allow_empty) {}
5482 
5483  primitive_desc(const desc &desc, const primitive_attr &attr,
5484  const engine &e,
5485  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
5486  bool allow_empty = false)
5487  : rnn_primitive_desc_base(
5488  &desc.data, &attr, e, hint_fwd_pd.get(), allow_empty) {}
5489 
5492  primitive_desc(dnnl_primitive_desc_t pd)
5493  : rnn_primitive_desc_base(
5494  pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
5495 
5497  memory::desc src_layer_desc() const {
5498  return query_md(query::src_md, 0);
5499  }
5500 
5505  memory::desc src_iter_desc() const {
5506  return query_md(query::src_md, 1);
5507  }
5508 
5510  memory::desc weights_layer_desc() const {
5511  return query_md(query::weights_md, 0);
5512  }
5513 
5515  memory::desc weights_iter_desc() const {
5516  return query_md(query::weights_md, 1);
5517  }
5518 
5523  memory::desc bias_desc() const {
5524  return query_md(query::weights_md, 2);
5525  }
5526 
5528  memory::desc dst_layer_desc() const {
5529  return query_md(query::dst_md, 0);
5530  }
5531 
5536  memory::desc dst_iter_desc() const {
5537  return query_md(query::dst_md, 1);
5538  }
5539 
5543  memory::desc workspace_desc() const {
5544  return query_md(query::workspace_md, 0);
5545  }
5546 
5548  memory::desc diff_src_layer_desc() const {
5549  return query_md(query::diff_src_md, 0);
5550  }
5551 
5556  memory::desc diff_src_iter_desc() const {
5557  return query_md(query::diff_src_md, 1);
5558  }
5559 
5561  memory::desc diff_weights_layer_desc() const {
5562  return query_md(query::diff_weights_md, 0);
5563  }
5564 
5566  memory::desc diff_weights_iter_desc() const {
5567  return query_md(query::diff_weights_md, 1);
5568  }
5569 
5571  memory::desc diff_bias_desc() const {
5572  return query_md(query::diff_weights_md, 2);
5573  }
5574 
5576  memory::desc diff_dst_layer_desc() const {
5577  return query_md(query::diff_dst_md, 0);
5578  }
5579 
5584  memory::desc diff_dst_iter_desc() const {
5585  return query_md(query::diff_dst_md, 1);
5586  }
5587  };
5588 
5589  lbr_gru_backward() = default;
5590 
5591  lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
5592 };
5593 
5595 
5602 
5605 struct shuffle_forward : public primitive {
5606 
5608  struct desc {
5609  dnnl_shuffle_desc_t data;
5610 
5614  desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis,
5615  int group_size) {
5617  dnnl::convert_to_c(aprop_kind),
5618  &data_desc.data, axis, group_size),
5619  "could not create a shuffle forward descriptor");
5620  }
5621  };
5622 
5624  struct primitive_desc : public dnnl::primitive_desc {
5625  primitive_desc() = default;
5626 
5627  primitive_desc(const desc &desc, const engine &e,
5628  const primitive_attr &aattr = primitive_attr(),
5629  bool allow_empty = false)
5630  : dnnl::primitive_desc(
5631  &desc.data, &aattr, e, nullptr, allow_empty) {}
5632 
5635  primitive_desc(dnnl_primitive_desc_t pd)
5636  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
5637  dnnl::prop_kind::forward_training,
5638  dnnl::prop_kind::forward_inference) {}
5639 
5641  memory::desc src_desc() const { return query_md(query::src_md, 0); }
5642 
5644  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
5645  };
5646 
5647  shuffle_forward() = default;
5648 
5649  shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
5650 };
5651 
5654 struct shuffle_backward : public primitive {
5655 
5656  // Descriptor for shuffle backward propagation.
5657  struct desc {
5658  dnnl_shuffle_desc_t data;
5659 
5662  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
5664  &diff_data_desc.data, axis, group_size),
5665  "could not create a shuffle backward descriptor");
5666  }
5667  };
5668 
5669  // Primitive descriptor for shuffle backward propagation.
5670  struct primitive_desc : public dnnl::primitive_desc {
5671  primitive_desc() = default;
5672 
5673  primitive_desc(const desc &desc, const engine &e,
5674  const shuffle_forward::primitive_desc &hint_fwd_pd,
5675  const primitive_attr &aattr = primitive_attr(),
5676  bool allow_empty = false)
5677  : dnnl::primitive_desc(
5678  &desc.data, &aattr, e, hint_fwd_pd.get(), allow_empty) {}
5679 
5683  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
5684  dnnl::prop_kind::backward_data) {}
5685 
5687  memory::desc diff_src_desc() const {
5688  return query_md(query::diff_src_md, 0);
5689  }
5690 
5692  memory::desc diff_dst_desc() const {
5693  return query_md(query::diff_dst_md, 0);
5694  }
5695  };
5696 
5697  shuffle_backward() = default;
5698 
5699  shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
5700 };
5701 
5703 
5710 
5713 struct binary : public primitive {
5714 
5716  struct desc {
5717  dnnl_binary_desc_t data;
5718 
5721  desc(algorithm aalgorithm, const memory::desc &src0,
5722  const memory::desc &src1, const memory::desc &dst) {
5724  dnnl_binary_desc_init(&data, dnnl::convert_to_c(aalgorithm),
5725  &src0.data, &src1.data, &dst.data),
5726  "could not create a binary descriptor");
5727  }
5728  };
5729 
5730  struct primitive_desc : public dnnl::primitive_desc {
5731  primitive_desc() = default;
5732 
5735  const desc &desc, const engine &e, bool allow_empty = false)
5736  : dnnl::primitive_desc(
5737  &desc.data, nullptr, e, nullptr, allow_empty) {}
5738 
5742  const desc &desc, const primitive_attr &attr, const engine &e)
5743  : dnnl::primitive_desc(&desc.data, &attr, e, nullptr) {}
5744 
5748  : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
5749 
5751  memory::desc src0_desc() const { return query_md(query::src_md, 0); }
5752 
5754  memory::desc src1_desc() const { return query_md(query::src_md, 1); }
5755 
5757  memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
5758  };
5759 
5760  binary() = default;
5761 
5762  binary(const primitive_desc &pd) : primitive(pd) {}
5763 };
5764 
5766 
5768 
5770 
5771 // implementation section
5772 
5774 inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
5775  dnnl_primitive_t result;
5777  "could not create a primitive");
5778  reset(result);
5779 }
5780 
5781 inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
5782 
5783 inline void primitive::execute(
5784  stream &astream, const std::unordered_map<int, memory> &args) const {
5785  std::vector<dnnl_exec_arg_t> c_args;
5786  c_args.reserve(args.size());
5787  for (const auto &a : args)
5788  c_args.push_back({a.first, a.second.get()});
5789 
5790  error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
5791  (int)c_args.size(), c_args.data()),
5792  "could not execute a primitive");
5793 }
5795 
5796 } // namespace dnnl
5797 
5798 #endif
dnnl::query::num_of_outputs_s32
number of outputs expected
dnnl_query_num_of_inputs_s32
number of inputs expected
Definition: dnnl_types.h:1563
dnnl_nt
2D RNN statistics tensor, an alias to dnnl_ba
Definition: dnnl_types.h:345
dnnl_stream_default_order
Default order execution.
Definition: dnnl_types.h:1620
dnnl::vanilla_rnn_backward::primitive_desc::diff_src_layer_desc
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:4596
dnnl_nCw16c
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b
Definition: dnnl_types.h:456
dnnl::query::op_d
op descriptor
dnnl_runtime_error
Primitive or engine failed on execution.
Definition: dnnl_types.h:61
dnnl::query::undef
no query
dnnl_data_type_undef
Undefined data type, used for empty memory descriptors.
Definition: dnnl_types.h:69
dnnl_lbr_gru
GRU cell with linear before reset.
Definition: dnnl_types.h:717
dnnl::convolution_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:2316
dnnl::primitive_attr::set_rnn_data_qparams
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: dnnl.hpp:779
dnnl::vanilla_rnn_backward::desc::desc
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Initializes an RNN descriptor for backward propagation using prop_kind, activation,...
Definition: dnnl.hpp:4487
dnnl_primitive_desc_get_attr
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
dnnl::algorithm::eltwise_sqrt
Eltwise: square root.
dnnl_backward_bias
Backward bias propagation.
Definition: dnnl_types.h:609
dnnl_cn
2D CNN activations tensor, an alias to dnnl_ba
Definition: dnnl_types.h:341
dnnl::softmax_backward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3572
dnnl::gru_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5083
dnnl::query::pooling_d
pooling descriptor
dnnl::memory::data_type::u8
8-bit unsigned integer.
dnnl_bac
permuted 3D tensor
Definition: dnnl_types.h:197
dnnl::lrn_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3095
dnnl::query::inner_product_d
inner product descriptor
dnnl_query_scratchpad_engine
scratchpad engine -- engine to be used for creating scratchpad memory
Definition: dnnl_types.h:1572
dnnl::algorithm::lrn_within_channel
LRN within a single channel.
dnnl_sum_primitive_desc_create
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_mds, int n, const float *scales, const dnnl_memory_desc_t *src_mds, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
dnnl_query_inner_product_d
inner product descriptor
Definition: dnnl_types.h:1594
dnnl::memory::desc::get_size
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area.
Definition: dnnl.hpp:1477
dnnl::primitive::kind::undef
Undefined primitive.
dnnl::vanilla_rnn_backward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4571
dnnl::memory::get_sycl_buffer
cl::sycl::buffer< T, ndims > get_sycl_buffer(size_t *offset=nullptr) const
Returns the underlying SYCL buffer object.
Definition: dnnl.hpp:1630
dnnl::lstm_backward::primitive_desc::diff_dst_iter_c_desc
memory::desc diff_dst_iter_c_desc() const
Queries diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4977
dnnl::layer_normalization_backward::primitive_desc::variance_desc
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:4006
dnnl::vanilla_rnn_backward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:4584
dnnl_sum
A sum primitive.
Definition: dnnl_types.h:624
dnnl::gru_backward::primitive_desc::diff_dst_iter_desc
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:5279
dnnl_aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:244
dnnl::error::what
const char * what() const noexcept override
Returns the explanatory string.
Definition: dnnl.hpp:68
dnnl::gru_backward::primitive_desc::diff_weights_iter_desc
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:5261
dnnl::primitive::kind::shuffle
A shuffle primitive.
dnnl::engine::kind::any
An unspecified engine.
dnnl_abcdef
plain 6D tensor
Definition: dnnl_types.h:187
dnnl_nc
2D CNN activations tensor, an alias to dnnl_ab
Definition: dnnl_types.h:339
dnnl::algorithm::convolution_direct
Direct convolution.
dnnl_pooling
A pooling primitive.
Definition: dnnl_types.h:634
dnnl_convolution_forward_desc_init
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
dnnl::pooling_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3244
dnnl::post_ops::len
int len() const
Returns the length of post operations.
Definition: dnnl.hpp:608
dnnl_query_t
dnnl_query_t
Primitive descriptor query specification.
Definition: dnnl_types.h:1557
dnnl_f32
32-bit/single-precision floating point.
Definition: dnnl_types.h:75
dnnl_io
2D CNN weights tensor, an alias to dnnl_ba
Definition: dnnl_types.h:364
dnnl::memory::format_tag::cdeba
permuted 5D tensor
dnnl_s32
32-bit signed integer.
Definition: dnnl_types.h:77
dnnl_memory
dnnl_post_ops_len
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
dnnl_primitive_desc_iterator
An opaque structure to describe a primitive descriptor iterator.
dnnl_inner_product_forward_desc_init
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...
dnnl_batch_normalization_forward_desc_init
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind (possi...
dnnl::error::error
error(dnnl_status_t astatus, const char *amessage)
Constructs an error instance.
Definition: dnnl.hpp:64
dnnl::algorithm::eltwise_elu
Eltwise: parametric exponential linear unit (elu)
dnnl_vanilla_rnn_forward_desc_init
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes an RNN descriptor rnn_desc for forward propagation using prop_kind, activation,...
dnnl_query_scratchpad_md
scratchpad memory desc
Definition: dnnl_types.h:1608
dnnl::lrn_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3158
dnnl_eltwise_abs
Eltwise: abs.
Definition: dnnl_types.h:673
dnnl::memory::desc::reshape
desc reshape(const dims &adims)
Constructs a memory descriptor by reshaping existing one.
Definition: dnnl.hpp:1467
dnnl_query_convolution_d
convolution descriptor
Definition: dnnl_types.h:1585
dnnl_primitive_attr_get_output_scales
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and a pointer to a constant floating point array of output ...
dnnl::prop_kind::backward_data
Backward data propagation.
dnnl::gru_forward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5070
dnnl_forward_training
Forward data propagation (training mode).
Definition: dnnl_types.h:593
dnnl::pooling_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3314
dnnl_cpu
CPU engine.
Definition: dnnl_types.h:1324
dnnl_cdeba
permuted 5D tensor
Definition: dnnl_types.h:204
dnnl::inner_product_backward_weights::primitive_desc::diff_weights_desc
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:4266
dnnl_s8
8-bit signed integer.
Definition: dnnl_types.h:79
dnnl_nChw8c
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b
Definition: dnnl_types.h:453
dnnl_abcd
plain 4D tensor
Definition: dnnl_types.h:185
dnnl::lstm_backward::primitive_desc::diff_weights_layer_desc
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:4949
dnnl::gru_forward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5057
dnnl::post_ops::post_ops
post_ops()
Creates an empty sequence of post operations.
Definition: dnnl.hpp:600
dnnl_wio
3D CNN weights tensor, an alias to dnnl_cba
Definition: dnnl_types.h:370
dnnl::gru_backward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5205
dnnl_stream_destroy
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
dnnl_ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition: dnnl_types.h:249
dnnl::primitive::kind::inner_product
An inner product primitive.
dnnl::handle_traits
A class that provides the destructor for a DNNL C handle.
Definition: dnnl.hpp:82
dnnl::memory::desc::data
dnnl_memory_desc_t data
The underlying C API data structure.
Definition: dnnl.hpp:1414
dnnl::lstm_forward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4758
dnnl_engine
An opaque structure to describe an engine.
dnnl::layer_normalization_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:4024
dnnl_aBcde16b
5D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:259
dnnl_acb
permuted 3D tensor
Definition: dnnl_types.h:192
dnnl::vanilla_rnn_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4451
dnnl::prop_kind::backward_bias
Backward bias propagation.
dnnl::eltwise_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3457
dnnl::stream::wait
stream & wait()
Waits for all primitives in the stream to finish.
Definition: dnnl.hpp:1033
dnnl::query::convolution_d
convolution descriptor
dnnl_query_binary_d
binary descriptor
Definition: dnnl_types.h:1597
dnnl::gru_forward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5075
dnnl::engine::get_sycl_device
cl::sycl::device DNNL_API get_sycl_device() const
Returns the underlying SYCL device object.
dnnl_iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition: dnnl_types.h:59
dnnl::convolution_backward_data::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source gradient memory descriptor.
Definition: dnnl.hpp:2410
dnnl::scratchpad_mode
scratchpad_mode
Scratchpad mode.
Definition: dnnl.hpp:263
dnnl_query_primitive_kind
primitive kind
Definition: dnnl_types.h:1561
dnnl::lstm_backward::primitive_desc::diff_dst_iter_desc
memory::desc diff_dst_iter_desc() const
Queries diff destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4972
dnnl::lbr_gru_backward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5527
dnnl::memory::format_tag::acbde
permuted 5D tensor
dnnl::lstm_backward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4893
dnnl_deconvolution_winograd
Winograd deconvolution.
Definition: dnnl_types.h:663
dnnl_engine_kind_t
dnnl_engine_kind_t
Kinds of engines.
Definition: dnnl_types.h:1320
dnnl_convolution_desc_t
A descriptor of a convolution operation.
Definition: dnnl_types.h:956
dnnl::convolution_backward_data::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2415
dnnl_eltwise_soft_relu
Eltwise: soft_relu.
Definition: dnnl_types.h:681
dnnl::convolution_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2300
DNNL_MAX_NDIMS
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition: dnnl_types.h:774
dnnl_status_t
dnnl_status_t
Status values returned by the library functions.
Definition: dnnl_types.h:49
dnnl::batch_normalization_backward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3760
dnnl::pooling_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3321
dnnl_memory_desc_init_submemory
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory_desc for a given parent_memory_desc, with dims sizes and offsets.
dnnl::deconvolution_backward_weights::primitive_desc::diff_weights_desc
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:3012
dnnl_inner_product
An inner product primitive.
Definition: dnnl_types.h:642
dnnl_primitive_attr_destroy
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Deletes an attr.
dnnl_aBcde8b
5D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:275
dnnl::query::dst_md
destination memory desc
dnnl_memory_get_memory_desc
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns a memory_desc associated with memory.
dnnl::primitive::kind::deconvolution
A deconvolution primitive.
dnnl_concat
A (out-of-place) concat primitive.
Definition: dnnl_types.h:622
dnnl_ihwo
4D CNN weights tensor, an alias to dnnl_bcda
Definition: dnnl_types.h:380
dnnl_rnn_flags_t
dnnl_rnn_flags_t
Flags for RNN cell.
Definition: dnnl_types.h:1216
dnnl::memory::data_type::f32
32-bit/single-precision floating point.
dnnl::query::layer_normalization_d
layer normalization descriptor
dnnl::algorithm::pooling_avg
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
dnnl_rnn_desc_t
A descriptor for an RNN operation.
Definition: dnnl_types.h:1234
dnnl::lstm_forward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4722
dnnl_convolution_winograd
Winograd convolution.
Definition: dnnl_types.h:657
dnnl::query::engine
execution engine
dnnl_engine_create_ocl
dnnl_status_t DNNL_API dnnl_engine_create_ocl(dnnl_engine_t *engine, dnnl_engine_kind_t kind, cl_device_id device, cl_context context)
Creates an engine of particular kind associated with a given OpenCL device and context objects.
dnnl::engine::get_count
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: dnnl.hpp:850
dnnl_layer_normalization
A layer normalization primitive.
Definition: dnnl_types.h:640
dnnl::memory::format_kind::packed
Packed weights format used in RNN.
dnnl::gru_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5103
dnnl_softmax_backward_desc_init
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc.
dnnl::query::deconvolution_d
deconvolution descriptor
dnnl::primitive_desc::next_impl
bool next_impl()
Advances to the next implementation for the given op descriptor.
Definition: dnnl.hpp:2127
dnnl_abdec
permuted 5D tensor
Definition: dnnl_types.h:191
dnnl::lbr_gru_backward::primitive_desc::diff_dst_iter_desc
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:5583
dnnl_eltwise_elu
Eltwise: parametric exponential linear unit (elu)
Definition: dnnl_types.h:669
dnnl::layer_normalization_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3896
dnnl_inner_product_desc_t
A descriptor of an inner product operation.
Definition: dnnl_types.h:1187
dnnl::memory::set_ocl_mem_object
void set_ocl_mem_object(cl_mem mem_object)
Sets the OpenCL memory object mem_object associated with the memory.
Definition: dnnl.hpp:1618
dnnl::lstm_forward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4750
dnnl::pooling_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3249
dnnl::gru_backward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes a GRU descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:5137
dnnl::gru_backward::primitive_desc::diff_src_layer_desc
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:5243
dnnl::query::weights_md
weights memory descriptor
dnnl_convolution_backward_data_desc_init
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
dnnl::memory::format_tag::nhwc
4D CNN activations tensor, an alias to dnnl::memory::format_tag::acdb
dnnl::eltwise_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3396
dnnl_primitive_desc_clone
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
dnnl_oihw
4D CNN weights tensor, an alias to dnnl_abcd
Definition: dnnl_types.h:374
dnnl_pooling_max
Max pooling.
Definition: dnnl_types.h:694
dnnl::lbr_gru_forward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5378
dnnl_binary_add
Binary add.
Definition: dnnl_types.h:719
dnnl::lstm_backward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4888
dnnl_gru_forward_desc_init
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a GRU descriptor rnn_desc for forward propagation using prop_kind, direction,...
dnnl::handle
A class for wrapping a DNNL handle.
Definition: dnnl.hpp:98
dnnl::engine::kind
kind
Kinds of engines.
Definition: dnnl.hpp:836
dnnl::lbr_gru_backward::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:5570
dnnl::batch_normalization_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3779
dnnl::vanilla_rnn_forward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4418
dnnl::lstm_backward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4870
dnnl::batch_normalization_forward::primitive_desc::variance_desc
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3681
dnnl::memory::format_tag::undef
Undefined memory format tag.
dnnl_dilated_convolution_backward_weights_desc_init
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
dnnl::lbr_gru_forward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5360
dnnl_a
plain 1D tensor
Definition: dnnl_types.h:182
dnnl_query_diff_dst_md
destination grad. memory desc
Definition: dnnl_types.h:1606
dnnl_ab
plain 2D tensor
Definition: dnnl_types.h:183
dnnl::binary::desc::desc
desc(algorithm aalgorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Initializes a binary descriptor using algorithm, memory descriptors src0_desc, src1_desc and dst_desc...
Definition: dnnl.hpp:5720
dnnl::normalization_flags::fuse_norm_relu
Fuse with ReLU.
dnnl_softmax_forward_desc_init
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are dnnl_forward_...
dnnl_query_undef
no query
Definition: dnnl_types.h:1558
dnnl::inner_product_backward_data::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:4194
dnnl::memory::format_tag::ab
plain 2D tensor
dnnl::primitive_desc_base::get_primitive_attr
primitive_attr get_primitive_attr() const
Returns the attributes.
Definition: dnnl.hpp:1774
dnnl::gru_backward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5218
dnnl_eltwise_bounded_relu
Eltwise: bounded_relu.
Definition: dnnl_types.h:679
dnnl_primitive_attr_clone
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
dnnl::algorithm::lrn_across_channels
Local response normalization (LRN) across multiple channels.
dnnl_undefined_primitive
Undefined primitive.
Definition: dnnl_types.h:616
dnnl::deconvolution_backward_data::desc::desc
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution backward propagation using aalgorithm, memory descriptors,...
Definition: dnnl.hpp:2782
dnnl::primitive_attr::set_scratchpad_mode
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition: dnnl.hpp:712
dnnl::memory::format_tag::acb
permuted 3D tensor
dnnl::pooling_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3309
dnnl::vanilla_rnn_forward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:4413
dnnl_ndhwc
5D CNN activations tensor, an alias to dnnl_acdeb
Definition: dnnl_types.h:359
dnnl::gru_backward::primitive_desc::diff_dst_layer_desc
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:5271
dnnl::gru_forward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes a GRU descriptor for forward propagation using prop_kind, direction, and memory descripto...
Definition: dnnl.hpp:5014
dnnl_rnn_direction_t
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition: dnnl_types.h:1219
dnnl::batch_normalization_forward::primitive_desc::mean_desc
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3678
dnnl::lrn_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3092
dnnl::primitive_desc_base::scratchpad_engine
engine scratchpad_engine() const
Returns the engine that owns the scratchpad memory.
Definition: dnnl.hpp:1763
dnnl::primitive::kind::softmax
A softmax primitive.
dnnl::vanilla_rnn_backward::primitive_desc::diff_src_iter_desc
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:4604
dnnl::vanilla_rnn_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4591
dnnl_memory_desc_get_size
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size (in bytes) that is required for given memory_desc.
dnnl_query_workspace_md
workspace memory desc
Definition: dnnl_types.h:1607
dnnl::query::scratchpad_engine
scratchpad engine
dnnl::handle::get
T get(bool allow_emtpy=false) const
Returns the value of the underlying C handle.
Definition: dnnl.hpp:137
dnnl::algorithm::eltwise_bounded_relu
Eltwise: bounded_relu.
dnnl::vanilla_rnn_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4431
dnnl_query_op_d
op descriptor
Definition: dnnl_types.h:1584
dnnl::lstm_forward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LSTM descriptor for forward propagation using prop_kind, direction,...
Definition: dnnl.hpp:4668
dnnl::query::scratchpad_md
scratchpad memory desc
dnnl_oidhw
5D CNN weights tensor, an alias to dnnl_abcde
Definition: dnnl_types.h:384
dnnl_pooling_desc_t
A descriptor of a pooling operation.
Definition: dnnl_types.h:1068
dnnl_giohw
5D CNN weights tensor (incl. groups), an alias to dnnl_acbde
Definition: dnnl_types.h:399
dnnl_query_batch_normalization_d
batch normalization descriptor
Definition: dnnl_types.h:1592
dnnl::convolution_backward_weights::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2567
dnnl::deconvolution_backward_weights::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3009
dnnl_dilated_convolution_backward_data_desc_init
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
dnnl::engine::get_kind
kind get_kind() const
Returns the kind of the engine.
Definition: dnnl.hpp:905
dnnl::vanilla_rnn_backward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4576
dnnl_lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition: dnnl_types.h:701
dnnl::query::diff_dst_md
destination grad. memory desc
dnnl::primitive_desc::primitive_desc
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &e, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Creates a primitive descriptor from given op_desc, attr, engine, and optionally a hint primitive desc...
Definition: dnnl.hpp:2107
dnnl_primitive_desc_iterator_create
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
dnnl::primitive::kind::pooling
A pooling primitive.
dnnl_query_impl_info_str
implementation name
Definition: dnnl_types.h:1575
dnnl_format_tag_undef
Undefined memory format tag.
Definition: dnnl_types.h:171
dnnl_shuffle_backward_desc_init
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc,...
dnnl::prop_kind::forward_inference
Forward data propagation (inference mode).
dnnl::vanilla_rnn_forward::desc::desc
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Initializes an RNN descriptor for forward propagation using prop_kind, activation,...
Definition: dnnl.hpp:4360
dnnl::lrn_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3163
dnnl::algorithm::vanilla_lstm
LSTM cell.
dnnl_dilated_convolution_forward_desc_init
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
dnnl_batch_normalization
A batch normalization primitive.
Definition: dnnl_types.h:638
dnnl::engine::get_ocl_device
cl_device_id get_ocl_device() const
Returns the OpenCL device associated with the engine.
Definition: dnnl.hpp:922
dnnl_rnn
A rnn primitive.
Definition: dnnl_types.h:644
dnnl::inner_product_backward_data::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source gradient memory descriptor.
Definition: dnnl.hpp:4184
dnnl::gru_backward::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:5266
dnnl_hwio
4D CNN weights tensor, an alias to dnnl_cdba
Definition: dnnl_types.h:376
dnnl::query::src_md
source memory desc
dnnl::memory::data_type::f16
16-bit/half-precision floating point.
dnnl::algorithm
algorithm
Kinds of algorithms.
Definition: dnnl.hpp:306
dnnl_pooling_avg_include_padding
Average pooling include padding.
Definition: dnnl_types.h:696
dnnl::gru_forward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5096
dnnl::memory::set_sycl_buffer
void set_sycl_buffer(cl::sycl::buffer< T, ndims > &buf)
Sets the underlying buffer to the given SYCL buffer.
Definition: dnnl.hpp:1654
dnnl::memory::format_kind::undef
Undefined memory format kind, used for empty memory descriptors.
dnnl_primitive_desc_iterator_destroy
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
dnnl::lbr_gru_forward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5373
dnnl::memory::data_type::bf16
Non-standard 16-bit floating point (bfloat16, with a 7-bit mantissa).
dnnl_cdba
permuted 4D tensor
Definition: dnnl_types.h:203
dnnl::lbr_gru_forward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5399
dnnl::inner_product_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:4113
dnnl::memory::format_kind::wino
Weights format used in 8bit Winograd convolution.
dnnl_success
The operation was successful.
Definition: dnnl_types.h:51
dnnl_dilated_deconvolution_forward_desc_init
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
dnnl_lrn_forward_desc_init
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are dnnl_forward_tra...
dnnl_use_scaleshift
Use scale and shift parameters.
Definition: dnnl_types.h:750
dnnl_f16
16-bit/half-precision floating point.
Definition: dnnl_types.h:71
dnnl::query::reorder_src_engine
reorder source engine
dnnl_aBcd4b
4D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:236
dnnl::vanilla_rnn_backward::primitive_desc::diff_weights_iter_desc
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:4614
dnnl::layer_normalization_backward::primitive_desc::mean_desc
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:4003
dnnl::gru_backward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5223
dnnl::pooling_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3241
dnnl::primitive::kind::eltwise
An element-wise primitive.
dnnl::algorithm::deconvolution_winograd
Winograd deconvolution.
dnnl_eltwise_gelu
Eltwise: gelu.
Definition: dnnl_types.h:690
dnnl::memory::format_tag::abcde
plain 5D tensor
dnnl::algorithm::convolution_winograd
Winograd convolution.
dnnl_invalid_arguments
The operation failed because of incorrect function arguments.
Definition: dnnl_types.h:55
dnnl_backward
Backward propagation (with respect to all parameters).
Definition: dnnl_types.h:603
dnnl_batch_normalization_backward_desc_init
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
dnnl_nwc
3D CNN activations tensor, an alias to dnnl_acb
Definition: dnnl_types.h:349
dnnl::memory::format_tag::abdec
permuted 5D tensor
dnnl::batch_normalization_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3660
dnnl::inner_product_backward_data::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:4189
dnnl_softmax_desc_t
A descriptor of a Softmax operation.
Definition: dnnl_types.h:1052
dnnl::prop_kind::undef
Undefined propagation kind.
dnnl::pooling_forward::desc::desc
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a pooling descriptor for forward propagation using aprop_kind (possible values are dnnl::...
Definition: dnnl.hpp:3202
dnnl_engine_get_ocl_device
dnnl_status_t DNNL_API dnnl_engine_get_ocl_device(dnnl_engine_t engine, cl_device_id *device)
Returns an OpenCL device associated with an engine.
dnnl_deconvolution_backward_weights_desc_init
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
dnnl_deconvolution_direct
Direct deconvolution.
Definition: dnnl_types.h:661
dnnl::stream::flags::default_order
Default order execution.
dnnl::pooling_backward::desc::desc
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a pooling descriptor for backward propagation using aalgorithm, memory descriptors,...
Definition: dnnl.hpp:3268
dnnl::memory::data_type
data_type
Data type specification.
Definition: dnnl.hpp:1081
dnnl_primitive_desc_iterator_next
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
dnnl_query_diff_src_md
source gradient memory desc
Definition: dnnl_types.h:1602
dnnl_primitive_desc_query
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries primitive descriptor.
dnnl::engine::get_ocl_context
cl_context get_ocl_context() const
Returns the OpenCL context associated with the engine.
Definition: dnnl.hpp:914
dnnl_eltwise_swish
Eltwise: swish.
Definition: dnnl_types.h:692
dnnl::batch_normalization_backward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3776
dnnl::prop_kind::forward_scoring
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
dnnl::lbr_gru_backward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LBR_GRU descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:5440
dnnl_memory_destroy
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Deletes a memory.
dnnl_primitive_desc
An opaque structure to describe a primitive descriptor.
dnnl::memory::format_tag::chwn
4D CNN activations tensor, an alias to dnnl::memory::format_tag::bcda
dnnl_stream_out_of_order
Out-of-order execution.
Definition: dnnl_types.h:1624
dnnl_deconvolution
A deconvolution primitive.
Definition: dnnl_types.h:628
dnnl::primitive::kind::sum
A sum primitive.
dnnl_primitive_desc_destroy
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
dnnl_query_src_md
source memory desc
Definition: dnnl_types.h:1601
dnnl_inner_product_backward_data_desc_init
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
dnnl::lstm_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4770
dnnl_backward_weights
Backward weights propagation.
Definition: dnnl_types.h:607
dnnl::query::memory_consumption_s64
memory consumption (bytes)
dnnl::lstm_forward::primitive_desc::dst_iter_c_desc
memory::desc dst_iter_c_desc() const
Queries destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4763
dnnl_post_ops_get_params_sum
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
dnnl::lstm_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4926
dnnl::layer_normalization_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:4019
dnnl::memory::data_type::s32
32-bit signed integer.
dnnl_lrn
An LRN primitive.
Definition: dnnl_types.h:636
dnnl_query_engine
execution engine
Definition: dnnl_types.h:1560
dnnl::memory::get_desc
desc get_desc() const
Returns the descriptor of the memory.
Definition: dnnl.hpp:1542
dnnl::shuffle_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:5643
dnnl_iohw
4D CNN weights tensor, an alias to dnnl_bacd
Definition: dnnl_types.h:382
dnnl::layer_normalization_forward::primitive_desc::mean_desc
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3899
dnnl_stream_get_ocl_command_queue
dnnl_status_t DNNL_API dnnl_stream_get_ocl_command_queue(dnnl_stream_t stream, cl_command_queue *queue)
Returns the OpenCL command queue associated with an execution stream.
dnnl::algorithm::convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
dnnl_primitive_attr_set_rnn_weights_qparams
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
dnnl_query_reorder_src_engine
source engine
Definition: dnnl_types.h:1577
dnnl_lbr_gru_backward_desc_init
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes an LBR GRU descriptor rnn_desc for backward propagation using prop_kind,...
dnnl::memory::format_tag
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1117
dnnl::memory::unmap_data
void unmap_data(void *mapped_ptr) const
Unmaps the previously mapped data for the memory.
Definition: dnnl.hpp:1603
dnnl_memory_get_data_handle
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
For a memory, returns the data handle.
dnnl::gru_forward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5065
dnnl::algorithm::pooling_avg_exclude_padding
Average pooling exclude padding.
dnnl_memory_desc_equal
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
dnnl::primitive_attr::set_post_ops
void set_post_ops(post_ops ops)
Sets post_ops for future use.
Definition: dnnl.hpp:766
dnnl_backward_data
Backward data propagation.
Definition: dnnl_types.h:605
dnnl::query::num_of_inputs_s32
number of inputs expected
dnnl_query_weights_md
weights memory desc
Definition: dnnl_types.h:1603
dnnl::convolution_forward::desc::desc
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution forward propagation with bias using aprop_kind (possible ...
Definition: dnnl.hpp:2178
dnnl::handle::handle
handle()=default
Empty constructor.
dnnl::batch_normalization_backward::primitive_desc::mean_desc
memory::desc mean_desc() const
Queries mean memory descriptor.
Definition: dnnl.hpp:3763
dnnl_eltwise_linear
Eltwise: linear.
Definition: dnnl_types.h:677
dnnl::deconvolution_backward_data::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2857
dnnl::query::time_estimate_f64
runtime estimation (seconds), unimplemented
dnnl::deconvolution_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:2740
dnnl::lbr_gru_backward::primitive_desc::diff_dst_layer_desc
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:5575
dnnl_memory_desc_init_by_tag
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and format tag.
dnnl_ldnc
4D RNN states tensor in the format (num_layers, num_directions, batch, state channels).
Definition: dnnl_types.h:409
dnnl_query_prop_kind
propagation kind
Definition: dnnl_types.h:1580
dnnl::algorithm::eltwise_relu
Eltwise: ReLU.
dnnl::vanilla_rnn_backward::primitive_desc::diff_weights_layer_desc
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:4609
dnnl_primitive
dnnl::deconvolution_backward_data::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2862
dnnl::query
query
Primitive descriptor query specification.
Definition: dnnl.hpp:493
dnnl::handle::reset
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: dnnl.hpp:132
dnnl::lrn_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3100
dnnl_aBcd16b
4D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:230
dnnl_post_ops_get_params_eltwise
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops.
dnnl_query_num_of_outputs_s32
number of outputs expected
Definition: dnnl_types.h:1564
dnnl_engine_destroy
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
dnnl::query::binary_d
binary descriptor
dnnl::query::eltwise_d
eltwise descriptor
dnnl::stream::get_sycl_queue
cl::sycl::queue DNNL_API get_sycl_queue() const
Returns the underlying SYCL queue object.
dnnl::vanilla_rnn_backward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4545
dnnl_format_kind_rnn_packed
Packed weights format used in RNN.
Definition: dnnl_types.h:98
dnnl_memory_create
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory object for the given memory_desc and engine.
dnnl::algorithm::eltwise_exp
Eltwise: exponent.
dnnl::memory::format_tag::acdb
permuted 4D tensor
dnnl_memory_map_data
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
For a memory, maps the data of the memory to mapped_ptr.
dnnl::layer_normalization_forward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Initializes a layer normalization descriptor for forward propagation using prop_kind (possible values...
Definition: dnnl.hpp:3843
dnnl::memory::format_tag::aBcd8b
4D tensor blocked by 2nd dimension with block size 8
dnnl::gru_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5238
dnnl::eltwise_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3393
dnnl::memory::format_tag::cba
permuted 3D tensor
dnnl_primitive_attr_get_scratchpad_mode
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the scratchpad mode set in the attribute attr.
dnnl::query::workspace_md
workspace memory desc
dnnl::softmax_forward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Initializes a softmax descriptor for forward propagation using prop_kind (possible values are dnnl::f...
Definition: dnnl.hpp:3487
dnnl_binary_desc_init
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a binary descriptor binary_desc, alg_kind (possible values are dnnl_binary_add and dnnl_b...
dnnl::vanilla_rnn_backward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:4553
dnnl::gru_backward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5210
dnnl_engine_create
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
dnnl::lstm_forward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4714
dnnl::eltwise_backward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3449
dnnl_format_kind_undef
Undefined memory format kind, used for empty memory descriptors.
Definition: dnnl_types.h:87
dnnl_use_global_stats
Use global statistics.
Definition: dnnl_types.h:737
dnnl_acbde
permuted 5D tensor
Definition: dnnl_types.h:193
dnnl::lbr_gru_forward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LBR GRU descriptor for forward propagation using prop_kind, direction,...
Definition: dnnl.hpp:5316
dnnl::algorithm::eltwise_square
Eltwise: square.
dnnl::engine::kind::gpu
GPU engine.
dnnl_primitive_attr_set_scratchpad_mode
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets scratchpad mode.
dnnl::softmax_backward::desc::desc
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Initializes a softmax descriptor for backward propagation using memory descriptors diff_desc and data...
Definition: dnnl.hpp:3539
dnnl_lrn_desc_t
A descriptor of a Local Response Normalization (LRN) operation.
Definition: dnnl_types.h:1101
dnnl_query_deconvolution_d
deconvolution descriptor
Definition: dnnl_types.h:1586
dnnl_deconvolution_forward_desc_init
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
dnnl_shuffle_desc_t
A descriptor of a shuffle operation.
Definition: dnnl_types.h:999
dnnl::lstm_backward::primitive_desc::src_iter_c_desc
memory::desc src_iter_c_desc() const
Queries source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4883
dnnl_query_rnn_d
rnn descriptor
Definition: dnnl_types.h:1595
dnnl::algorithm::lbr_gru
GRU cell with linear before reset.
dnnl_normalization_flags_t
dnnl_normalization_flags_t
Flags for the normalization primitives (batch normalization and layer normalization).
Definition: dnnl_types.h:725
dnnl_batch_normalization_desc_t
A descriptor of a Batch Normalization operation.
Definition: dnnl_types.h:1127
dnnl_query_memory_consumption_s64
memory consumption – extra
Definition: dnnl_types.h:1567
dnnl::vanilla_rnn_backward::primitive_desc::diff_dst_iter_desc
memory::desc diff_dst_iter_desc() const
Queries diff destination iteration memory descriptor.
Definition: dnnl.hpp:4632
dnnl_primitive_attr_set_rnn_data_qparams
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
dnnl::query::shuffle_d
shuffle descriptor
dnnl_any_engine
An unspecified engine.
Definition: dnnl_types.h:1322
dnnl_eltwise_desc_t
A descriptor of an element-wise operation.
Definition: dnnl_types.h:1016
dnnl::query::impl_info_str
implementation name
dnnl::layer_normalization_forward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3891
dnnl::vanilla_rnn_forward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:4444
dnnl::convolution_backward_weights::primitive_desc::diff_weights_desc
memory::desc diff_weights_desc() const
Queries diff weights memory descriptor.
Definition: dnnl.hpp:2570
dnnl::memory::data_type::undef
Undefined data type, used for empty memory descriptors.
dnnl::convolution_backward_weights::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2580
dnnl_gpu
GPU engine.
Definition: dnnl_types.h:1326
dnnl_lstm_backward_desc_init
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes an LSTM descriptor rnn_desc for backward propagation using prop_kind, direction,...
dnnl_query_shuffle_d
shuffle descriptor
Definition: dnnl_types.h:1587
dnnl_u8
8-bit unsigned integer.
Definition: dnnl_types.h:81
dnnl_ldgoi
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels,...
Definition: dnnl_types.h:423
dnnl::layer_normalization_backward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:4011
dnnl::query::primitive_kind
primitive kind
dnnl_memory_desc_t::ndims
int ndims
Number of dimensions.
Definition: dnnl_types.h:885
dnnl::algorithm::eltwise_swish
Eltwise: x*sigmoid(a*x)
dnnl_layer_normalization_desc_t
A descriptor of a Layer Normalization operation.
Definition: dnnl_types.h:1155
dnnl::algorithm::binary_mul
Binary mul.
dnnl_memory_get_engine
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns an engine associated with memory.
dnnl::primitive::get_primitive_desc
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: dnnl.hpp:249
dnnl::batch_normalization_forward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &src_desc, float epsilon, normalization_flags flags)
Initializes a batch normalization descriptor for forward propagation using prop_kind (possible values...
Definition: dnnl.hpp:3626
dnnl::deconvolution_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:2751
dnnl_tn
2D RNN statistics tensor, an alias to dnnl_ab
Definition: dnnl_types.h:343
dnnl_engine_get_ocl_context
dnnl_status_t DNNL_API dnnl_engine_get_ocl_context(dnnl_engine_t engine, cl_context *context)
Returns an OpenCL context associated with an engine.
dnnl_nCdhw16c
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b
Definition: dnnl_types.h:438
dnnl::lbr_gru_backward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5535
dnnl::error
DNNL exception class.
Definition: dnnl.hpp:56
dnnl::layer_normalization_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3907
dnnl_vanilla_rnn
RNN cell.
Definition: dnnl_types.h:705
dnnl_binary_mul
Binary mul.
Definition: dnnl_types.h:721
dnnl::memory::format_tag::abc
plain 3D tensor
dnnl_primitive_desc_query_md
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for memory descriptor.
dnnl::layer_normalization_forward::primitive_desc::variance_desc
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3902
dnnl_vanilla_rnn_backward_desc_init
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes an RNN descriptor rnn_desc for backward propagation using prop_kind, activation,...
dnnl::gru_backward::primitive_desc::diff_weights_layer_desc
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:5256
dnnl_stream_wait
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish.
dnnl::post_ops::get_params_eltwise
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Gets the eltwise parameters of the post operation with index index.
Definition: dnnl.hpp:665
dnnl_aBcdef4b
6D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:292
dnnl::primitive::kind::lrn
An LRN primitive.
dnnl_memory_set_data_handle
dnnl_status_t DNNL_API dnnl_memory_set_data_handle(dnnl_memory_t memory, void *handle)
For a memory, sets the data handle.
dnnl_alg_kind_t
dnnl_alg_kind_t
Kinds of algorithms.
Definition: dnnl_types.h:652
dnnl::inner_product_backward_weights::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:4276
dnnl_nchw
4D CNN activations tensor, an alias to dnnl_abcd
Definition: dnnl_types.h:351
dnnl::lstm_backward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4914
dnnl::prop_kind::backward_weights
Backward weights propagation.
dnnl_reorder_primitive_desc_create
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_md, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_md, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Initializes a reorder_primitive_desc using the description of the source (src_engine and src_md) and ...
dnnl::primitive_attr::get_output_scales
void get_output_scales(int &mask, std::vector< float > &scales) const
Gets the correspondence scale mask and the constant floating-point vector of output scales previously set b...
Definition: dnnl.hpp:720
dnnl::layer_normalization_backward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:4000
dnnl::algorithm::eltwise_gelu
Eltwise: gelu.
dnnl_engine_get_kind
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
dnnl::lstm_backward::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4959
dnnl::deconvolution_backward_data::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source gradient memory descriptor.
Definition: dnnl.hpp:2852
dnnl::inner_product_forward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:4116
dnnl_eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: dnnl_types.h:667
dnnl_lrn_within_channel
LRN within a single channel.
Definition: dnnl_types.h:703
dnnl_inner_product_backward_weights_desc_init
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
dnnl::primitive_desc
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition: dnnl.hpp:2097
dnnl_layer_normalization_backward_desc_init
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a layer normalization descriptor lnrm_desc for backward propagation with respect to data ...
dnnl_aBcde4b
5D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:266
dnnl::memory::desc::submemory_desc
desc submemory_desc(const dims &adims, const dims &offsets)
Constructs a sub-memory descriptor.
Definition: dnnl.hpp:1458
dnnl::softmax_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3521
dnnl_ldigo
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates,...
Definition: dnnl_types.h:416
dnnl::lbr_gru_backward::primitive_desc::diff_weights_iter_desc
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:5565
dnnl::vanilla_rnn_backward::primitive_desc::diff_dst_layer_desc
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:4624
dnnl::stream::flags
flags
Stream flags.
Definition: dnnl.hpp:979
dnnl::batch_normalization_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3786
dnnl_primitive_create
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive using a primitive_desc descriptor.
dnnl::post_ops::append_sum
void append_sum(float scale=1.)
Appends accumulation (sum) post operation.
Definition: dnnl.hpp:638
dnnl_nCw8c
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b
Definition: dnnl_types.h:462
dnnl::primitive::kind
kind
Kinds of primitives.
Definition: dnnl.hpp:193
dnnl::lrn_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3170
dnnl::primitive::kind::concat
A (out-of-place) concat primitive.
dnnl::eltwise_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3452
dnnl::layer_normalization_backward::primitive_desc::diff_weights_desc
memory::desc diff_weights_desc() const
Queries diff weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:4029
dnnl_primitive_attr
An opaque structure for primitive descriptor attributes.
dnnl::algorithm::eltwise_abs
Eltwise: abs.
dnnl::inner_product_backward_weights::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4271
dnnl_query_time_estimate_f64
runtime estimation (seconds)
Definition: dnnl_types.h:1566
dnnl_data_type_t
dnnl_data_type_t
Data type specification.
Definition: dnnl_types.h:67
dnnl_query_reorder_dst_engine
destination engine
Definition: dnnl_types.h:1578
dnnl_shuffle_forward_desc_init
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc,...
dnnl::vanilla_rnn_backward::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:4619
dnnl_unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition: dnnl_types.h:1221
dnnl_lbr_gru_forward_desc_init
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes an LBR GRU descriptor rnn_desc for forward propagation using prop_kind,...
dnnl_bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition: dnnl_types.h:1226
dnnl::lbr_gru_backward::primitive_desc::diff_weights_layer_desc
memory::desc diff_weights_layer_desc() const
Queries diff weights layer memory descriptor.
Definition: dnnl.hpp:5560
dnnl::gru_backward::primitive_desc::diff_src_iter_desc
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:5251
dnnl::algorithm::deconvolution_direct
Direct deconvolution.
dnnl::shuffle_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:5640
dnnl_ncdhw
5D CNN activations tensor, an alias to dnnl_abcde
Definition: dnnl_types.h:357
dnnl_primitive_get_primitive_desc
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of a given primitive.
dnnl::primitive::kind::reorder
A reorder primitive.
dnnl::lbr_gru_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5542
dnnl_pooling_avg_exclude_padding
Average pooling exclude padding.
Definition: dnnl_types.h:698
dnnl_post_ops_create
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
dnnl::inner_product_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:4129
dnnl::lbr_gru_backward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5504
dnnl_chwn
4D CNN activations tensor, an alias to dnnl_bcda
Definition: dnnl_types.h:355
dnnl::lbr_gru_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:5406
dnnl::memory::format_tag::bacd
permuted 4D tensor
dnnl_lrn_backward_desc_init
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc and dif...
dnnl::softmax_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3518
dnnl_cba
permuted 3D tensor
Definition: dnnl_types.h:202
dnnl::normalization_flags::use_scale_shift
Use scale and shift parameters.
dnnl_post_ops
An opaque structure for a chain of post operations.
dnnl::memory::format_tag::cdba
permuted 4D tensor
dnnl_shuffle
A shuffle primitive.
Definition: dnnl_types.h:620
dnnl_oiw
3D CNN weights tensor, an alias to dnnl_abc
Definition: dnnl_types.h:366
dnnl::deconvolution_backward_weights::desc::desc
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution weight update with bias using aalgorithm,...
Definition: dnnl.hpp:2890
dnnl::error::wrap_c_api
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to the C API.
Definition: dnnl.hpp:75
dnnl_post_ops_get_kind
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of post operation with index index in given post_ops.
dnnl::memory::format_tag::abcd
plain 4D tensor
dnnl_bcda
permuted 4D tensor
Definition: dnnl_types.h:200
dnnl_concat_primitive_desc_create
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_md, int n, int concat_dimension, const dnnl_memory_desc_t *src_mds, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
dnnl_eltwise_sqrt
Eltwise: square root.
Definition: dnnl_types.h:675
dnnl_eltwise
An element-wise primitive.
Definition: dnnl_types.h:630
dnnl::lstm_backward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4906
dnnl::batch_normalization_backward::primitive_desc::variance_desc
memory::desc variance_desc() const
Queries variance memory descriptor.
Definition: dnnl.hpp:3766
dnnl_abc
plain 3D tensor
Definition: dnnl_types.h:184
dnnl_query_dst_md
destination memory desc
Definition: dnnl_types.h:1605
dnnl::memory::format_tag::acdeb
permuted 5D tensor
dnnl::batch_normalization_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:3668
dnnl_acdeb
permuted 5D tensor
Definition: dnnl_types.h:195
dnnl::lrn_backward::desc::desc
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Initializes a descriptor for backward propagation using aalgorithm, memory descriptors data_desc and ...
Definition: dnnl.hpp:3122
dnnl::algorithm::vanilla_rnn
RNN cell.
dnnl::algorithm::pooling_max
Max pooling.
dnnl::primitive_desc_base::query_s64
memory::dim query_s64(query q) const
Queries the memory::dim value (same as int64_t).
Definition: dnnl.hpp:1733
dnnl::lstm_forward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4732
dnnl_goiw
4D CNN weights tensor (incl. groups), an alias to dnnl_abcd
Definition: dnnl_types.h:393
dnnl::query::lrn_d
lrn descriptor
dnnl::query::batch_normalization_d
batch normalization descriptor
dnnl::engine::get_sycl_context
cl::sycl::context DNNL_API get_sycl_context() const
Returns the underlying SYCL context object.
dnnl::primitive::kind::layer_normalization
A layer normalization primitive.
dnnl_format_tag_any
Undefined memory format tag.
Definition: dnnl_types.h:174
dnnl::query::diff_weights_md
weights grad. memory desc
dnnl::vanilla_rnn_backward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:4558
dnnl_pooling_backward_desc_init
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind,...
dnnl::lbr_gru_backward::primitive_desc::weights_layer_desc
memory::desc weights_layer_desc() const
Queries weights layer memory descriptor.
Definition: dnnl.hpp:5509
dnnl::batch_normalization_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3791
dnnl::memory::format_tag::bac
permuted 3D tensor
dnnl_memory_desc_t
Memory descriptor.
Definition: dnnl_types.h:883
dnnl_vanilla_gru
GRU cell.
Definition: dnnl_types.h:709
dnnl_convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: dnnl_types.h:659
dnnl_format_kind_wino
Weights format used in 8bit Winograd convolution.
Definition: dnnl_types.h:96
dnnl::query::diff_src_md
source gradient memory desc
dnnl::vanilla_rnn_forward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4423
dnnl::lstm_forward::primitive_desc::src_iter_c_desc
memory::desc src_iter_c_desc() const
Queries source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4727
dnnl_query_diff_weights_md
weights grad. memory desc
Definition: dnnl_types.h:1604
dnnl_dim_t
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition: dnnl_types.h:777
dnnl_eltwise_relu
Eltwise: ReLU.
Definition: dnnl_types.h:665
dnnl_goihw
5D CNN weights tensor (incl. groups), an alias to dnnl_abcde
Definition: dnnl_types.h:395
dnnl::post_ops::append_eltwise
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Appends eltwise post operation.
Definition: dnnl.hpp:658
dnnl_eltwise_forward_desc_init
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are dnnl_forward...
dnnl::primitive_attr::set_output_scales
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scales for primitive operations.
Definition: dnnl.hpp:749
dnnl::convolution_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:2311
dnnl::inner_product_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4124
dnnl_binary
A binary primitive.
Definition: dnnl_types.h:648
dnnl_memory_get_ocl_mem_object
dnnl_status_t DNNL_API dnnl_memory_get_ocl_mem_object(const_dnnl_memory_t memory, cl_mem *mem_object)
For a memory, returns the OpenCL memory object associated with it.
dnnl_oi
2D CNN weights tensor, an alias to dnnl_ab
Definition: dnnl_types.h:362
dnnl::convolution_backward_weights::desc::desc
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution weight update with bias using aalgorithm,...
Definition: dnnl.hpp:2448
dnnl::deconvolution_forward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:2756
dnnl::lstm_backward::primitive_desc::diff_src_layer_desc
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:4931
dnnl::lbr_gru_backward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5522
dnnl::inner_product_backward_weights::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:4263
dnnl_aBc16b
3D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:212
dnnl.h
dnnl_convolution_backward_weights_desc_init
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
dnnl::primitive_desc_base::scratchpad_desc
memory::desc scratchpad_desc() const
Queries scratchpad memory descriptor.
Definition: dnnl.hpp:1758
dnnl::post_ops::get_params_sum
void get_params_sum(int index, float &scale) const
Gets the parameters of the accumulation (sum) post operation with index index.
Definition: dnnl.hpp:645
dnnl::primitive::kind::convolution
A convolution primitive.
dnnl::memory::format_tag::abcdef
plain 6D tensor
dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:2575
dnnl_primitive_destroy
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Deletes a primitive.
dnnl::memory::format_tag::ba
permuted 2D tensor
dnnl_softmax
A softmax primitive.
Definition: dnnl_types.h:632
dnnl_dilated_deconvolution_backward_weights_desc_init
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
dnnl_primitive_attr_get_post_ops
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns post_ops for given attr.
dnnl_ntc
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: dnnl_types.h:406
dnnl::lbr_gru_backward::primitive_desc::diff_src_layer_desc
memory::desc diff_src_layer_desc() const
Queries diff source layer memory descriptor.
Definition: dnnl.hpp:5547
dnnl::memory::format_tag::any
Placeholder memory format tag.
dnnl::batch_normalization_forward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:3673
dnnl_stream_default_flags
Default stream configuration.
Definition: dnnl_types.h:1626
dnnl_nChw16c
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b
Definition: dnnl_types.h:447
dnnl::lstm_backward::primitive_desc::diff_dst_layer_desc
memory::desc diff_dst_layer_desc() const
Queries diff destination layer memory descriptor.
Definition: dnnl.hpp:4964
dnnl::primitive_attr::get_post_ops
const post_ops get_post_ops() const
Returns post_ops previously set by set_post_ops.
Definition: dnnl.hpp:756
dnnl::vanilla_rnn_forward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:4405
dnnl_dilated_deconvolution_backward_data_desc_init
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
dnnl::gru_forward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5088
dnnl_format_tag_t
dnnl_format_tag_t
Memory format tag specification.
Definition: dnnl_types.h:169
dnnl::convolution_forward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2303
dnnl::primitive::kind::binary
A binary primitive.
dnnl_stream_create
dnnl_status_t DNNL_API dnnl_stream_create(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags)
Creates an execution stream for engine with the given flags.
dnnl_primitive_attr_set_post_ops
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
dnnl_scratchpad_mode_library
The library manages scratchpad (default)
Definition: dnnl_types.h:1379
dnnl_ldgo
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: dnnl_types.h:430
dnnl::layer_normalization_backward::primitive_desc::dst_desc
memory::desc dst_desc() const
Queries destination memory descriptor.
Definition: dnnl.hpp:4016
dnnl_hwigo
5D CNN weights tensor (incl. groups), an alias to dnnl_decab
Definition: dnnl_types.h:397
dnnl::algorithm::eltwise_linear
Eltwise: linear.
dnnl::memory::get_engine
engine get_engine() const
Returns the engine of the memory.
Definition: dnnl.hpp:1550
dnnl_eltwise_logistic
Eltwise: logistic.
Definition: dnnl_types.h:683
dnnl_eltwise_backward_desc_init
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff...
dnnl::memory::format_kind
format_kind
Memory format kind.
Definition: dnnl.hpp:1099
dnnl_primitive_attr_set_output_scales
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scales for primitive operations.
dnnl::prop_kind::forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
dnnl_layer_normalization_forward_desc_init
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a layer normalization descriptor lnrm_desc for forward propagation using prop_kind (possi...
dnnl_aBc4b
3D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:216
dnnl::algorithm::eltwise_soft_relu
Eltwise: soft_relu.
dnnl::memory::desc::is_zero
bool is_zero() const
Returns true if the memory descriptor describes an empty memory.
Definition: dnnl.hpp:1480
dnnl_memory_unmap_data
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
For a memory, unmaps a mapped pointer to the data of the memory.
dnnl::memory::format_tag::bcda
permuted 4D tensor
dnnl::lstm_backward::primitive_desc::diff_weights_iter_desc
memory::desc diff_weights_iter_desc() const
Queries diff weights iteration memory descriptor.
Definition: dnnl.hpp:4954
dnnl::lstm_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4745
dnnl::primitive::kind::batch_normalization
A batch normalization primitive.
dnnl::lbr_gru_forward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:5386
dnnl::prop_kind::forward_training
Forward data propagation (training mode).
dnnl::memory::format_tag::nchw
4D CNN activations tensor, an alias to dnnl::memory::format_tag::abcd
dnnl_stream
dnnl::memory::format_kind::blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
dnnl_forward_scoring
Forward data propagation (alias for dnnl_forward_inference).
Definition: dnnl_types.h:599
dnnl::lbr_gru_forward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:5391
dnnl_binary_desc_t
A descriptor of a binary operation.
Definition: dnnl_types.h:1301
dnnl_bacd
permuted 4D tensor
Definition: dnnl_types.h:198
dnnl_aBcdef16b
6D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:288
dnnl::lbr_gru_backward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:5514
dnnl::memory::get_data_handle
void * get_data_handle() const
Returns a handle of the data contained in the memory.
Definition: dnnl.hpp:1560
dnnl_abcde
plain 5D tensor
Definition: dnnl_types.h:186
dnnl_bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition: dnnl_types.h:1229
dnnl::lbr_gru_backward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5496
dnnl::memory::format_tag::a
plain 1D tensor
dnnl_post_ops_append_sum
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
dnnl_prop_kind_t
dnnl_prop_kind_t
Kinds of propagation.
Definition: dnnl_types.h:587
dnnl::stream::get_ocl_command_queue
cl_command_queue get_ocl_command_queue() const
Returns the OpenCL command queue associated with the stream.
Definition: dnnl.hpp:1013
dnnl_convolution_direct
Direct convolution.
Definition: dnnl_types.h:655
dnnl_engine_get_count
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
dnnl_primitive_attr_create
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
dnnl_primitive_desc_iterator_fetch
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor.
dnnl_query_lrn_d
lrn descriptor
Definition: dnnl_types.h:1591
dnnl::memory::format_tag::decab
permuted 5D tensor
dnnl_forward_inference
Forward data propagation (inference mode).
Definition: dnnl_types.h:597
dnnl_primitive_kind_t
dnnl_primitive_kind_t
Kinds of primitives.
Definition: dnnl_types.h:614
dnnl::memory::get_ocl_mem_object
cl_mem get_ocl_mem_object() const
Returns the OpenCL memory object associated with the memory.
Definition: dnnl.hpp:1610
dnnl_unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition: dnnl_types.h:1223
dnnl_prop_kind_undef
Undefined propagation type.
Definition: dnnl_types.h:590
dnnl::post_ops::kind
primitive::kind kind(int index) const
Returns the kind of post operation with index index.
Definition: dnnl.hpp:611
dnnl::lrn_forward::desc::desc
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Initializes a descriptor for forward propagation using prop_kind (possible values are dnnl::forward_t...
Definition: dnnl.hpp:3057
dnnl::lbr_gru_forward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5368
dnnl::lstm_backward::primitive_desc::dst_iter_c_desc
memory::desc dst_iter_c_desc() const
Queries destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:4919
dnnl::layer_normalization_backward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Initializes a layer normalization descriptor for backward propagation with respect to data and scale-...
Definition: dnnl.hpp:3950
dnnl::batch_normalization_backward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3771
dnnl_scratchpad_mode_t
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition: dnnl_types.h:1377
dnnl::deconvolution_backward_weights::primitive_desc::diff_bias_desc
memory::desc diff_bias_desc() const
Queries diff bias memory descriptor.
Definition: dnnl.hpp:3017
dnnl::normalization_flags::use_global_stats
Use global statistics.
dnnl::lbr_gru_backward::primitive_desc::diff_src_iter_desc
memory::desc diff_src_iter_desc() const
Queries diff source iteration memory descriptor.
Definition: dnnl.hpp:5555
dnnl::memory::data_type::s8
8-bit signed integer.
dnnl::prop_kind
prop_kind
Propagation kind.
Definition: dnnl.hpp:275
dnnl::query::softmax_d
softmax descriptor
dnnl_forward
Forward data propagation (alias for dnnl_forward_training).
Definition: dnnl_types.h:601
dnnl_aBc8b
3D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:222
dnnl::algorithm::vanilla_gru
GRU cell.
dnnl_goidhw
6D CNN weights tensor (incl. groups), an alias to dnnl_abcdef
Definition: dnnl_types.h:401
dnnl_stream_create_ocl
dnnl_status_t DNNL_API dnnl_stream_create_ocl(dnnl_stream_t *stream, dnnl_engine_t engine, cl_command_queue queue)
Creates an execution stream for a given engine associated with an OpenCL command queue.
dnnl::stream::flags::in_order
In-order execution.
dnnl_reorder
A reorder primitive.
Definition: dnnl_types.h:618
dnnl::memory::map_data
T * map_data() const
Maps the data of the memory.
Definition: dnnl.hpp:1588
dnnl::convolution_backward_data::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:2420
dnnl::lstm_backward::primitive_desc::diff_src_iter_desc
memory::desc diff_src_iter_desc() const
Queries diff source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4939
dnnl_query_layer_normalization_d
layer normalization descriptor
Definition: dnnl_types.h:1593
dnnl::algorithm::eltwise_logistic
Eltwise: logistic.
dnnl_query_softmax_d
softmax descriptor
Definition: dnnl_types.h:1589
dnnl::primitive_attr::get_scratchpad_mode
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition: dnnl.hpp:703
dnnl::algorithm::eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
dnnl_eltwise_square
Eltwise: square.
Definition: dnnl_types.h:671
dnnl_ba
permuted 2D tensor
Definition: dnnl_types.h:196
dnnl_ncw
3D CNN activations tensor, an alias to dnnl_abc
Definition: dnnl_types.h:347
dnnl_acdb
permuted 4D tensor
Definition: dnnl_types.h:194
dnnl::memory::desc::desc
desc()
Constructs a zero memory descriptor.
Definition: dnnl.hpp:1417
dnnl::stream::flags::out_of_order
Out-of-order execution.
dnnl_query_pooling_d
pooling descriptor
Definition: dnnl_types.h:1590
dnnl::primitive_attr::primitive_attr
primitive_attr()
Creates default primitive attributes.
Definition: dnnl.hpp:689
dnnl::primitive::kind::rnn
An RNN primitive.
dnnl_nhwc
4D CNN activations tensor, an alias to dnnl_acdb
Definition: dnnl_types.h:353
dnnl_gru_backward_desc_init
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a GRU descriptor rnn_desc for backward propagation using prop_kind, direction,...
dnnl::scratchpad_mode::library
The library manages scratchpad (default)
dnnl_nCw4c
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b
Definition: dnnl_types.h:459
dnnl::primitive_desc_base::get_engine
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition: dnnl.hpp:1721
dnnl_scratchpad_mode_user
A user shall query and provide the scratchpad memory to primitives.
Definition: dnnl_types.h:1381
dnnl::scratchpad_mode::user
A user shall query and provide the scratchpad memory to primitives.
dnnl::gru_backward::primitive_desc::src_layer_desc
memory::desc src_layer_desc() const
Queries source layer memory descriptor.
Definition: dnnl.hpp:5192
dnnl::primitive_desc_base::impl_info_str
const char * impl_info_str() const
Returns implementation name.
Definition: dnnl.hpp:1724
dnnl::deconvolution_forward::desc::desc
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for deconvolution forward propagation with bias using prop_kind (possible va...
Definition: dnnl.hpp:2618
dnnl::lstm_forward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4737
dnnl::query::reorder_dst_engine
reorder destination engine
dnnl::batch_normalization_forward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3663
dnnl::primitive_desc_base::query_md
memory::desc query_md(query what, int idx=0) const
Queries and returns requested memory descriptor.
Definition: dnnl.hpp:1741
dnnl_memory_set_ocl_mem_object
dnnl_status_t DNNL_API dnnl_memory_set_ocl_mem_object(dnnl_memory_t memory, cl_mem mem_object)
Sets the OpenCL memory object associated with a memory object.
dnnl::normalization_flags
normalization_flags
Flags for batch normalization primitive.
Definition: dnnl.hpp:382
dnnl::batch_normalization_backward::primitive_desc::diff_weights_desc
memory::desc diff_weights_desc() const
Queries diff weights (scale and shift) memory descriptor.
Definition: dnnl.hpp:3796
dnnl::batch_normalization_backward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Initializes a batch normalization descriptor for backward propagation with respect to data and scale-...
Definition: dnnl.hpp:3722
dnnl_post_ops_append_eltwise
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg, float alpha, float beta)
Appends an eltwise post operation to post_ops with the given parameters scale, alg, alpha, and beta.
dnnl_nCdhw8c
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b
Definition: dnnl_types.h:444
dnnl::deconvolution_forward::primitive_desc::weights_desc
memory::desc weights_desc() const
Queries weights memory descriptor.
Definition: dnnl.hpp:2743
dnnl_memory_desc_reshape
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes out_memory_desc with new ndims and dims from an in_memory_desc.
dnnl_fuse_norm_relu
Fuse with ReLU.
Definition: dnnl_types.h:763
dnnl::layer_normalization_backward::primitive_desc::workspace_desc
memory::desc workspace_desc() const
Queries workspace memory descriptor.
Definition: dnnl.hpp:4036
dnnl_deconvolution_backward_data_desc_init
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
dnnl::memory::format_tag::nc
2D CNN activations tensor, an alias to dnnl::memory::format_tag::ab
dnnl::stream::flags::default_flags
Default stream configuration.
dnnl_format_kind_any
Unspecified format kind.
Definition: dnnl_types.h:90
dnnl::algorithm::binary_add
Binary add.
dnnl_eltwise_exp
Eltwise: exponent.
Definition: dnnl_types.h:685
const_dnnl_op_desc_t
const typedef void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: dnnl_types.h:953
dnnl_tnc
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: dnnl_types.h:404
dnnl_nCdhw4c
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b
Definition: dnnl_types.h:441
dnnl_query_eltwise_d
eltwise descriptor
Definition: dnnl_types.h:1588
dnnl::gru_backward::primitive_desc::dst_iter_desc
memory::desc dst_iter_desc() const
Queries destination iteration memory descriptor.
Definition: dnnl.hpp:5231
dnnl::softmax_backward::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3580
dnnl::vanilla_rnn_forward::primitive_desc::dst_layer_desc
memory::desc dst_layer_desc() const
Queries destination layer memory descriptor.
Definition: dnnl.hpp:4436
dnnl::lstm_backward::desc::desc
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Initializes an LSTM descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:4807
dnnl_bf16
Non-standard 16-bit floating point (bfloat16 with a 7-bit mantissa).
Definition: dnnl_types.h:73
dnnl::softmax_backward::primitive_desc::diff_src_desc
memory::desc diff_src_desc() const
Queries diff source memory descriptor.
Definition: dnnl.hpp:3575
dnnl_blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: dnnl_types.h:94
dnnl::shuffle_forward::desc::desc
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Initializes a shuffle descriptor for forward propagation using prop_kind, memory descriptor data_desc...
Definition: dnnl.hpp:5613
dnnl::lstm_backward::primitive_desc::diff_src_iter_c_desc
memory::desc diff_src_iter_c_desc() const
Queries diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:4944
dnnl::engine::kind::cpu
CPU engine.
dnnl_memory_desc_init_by_strides
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and strides.
dnnl_x
1D tensor, an alias to dnnl_a
Definition: dnnl_types.h:337
dnnl_format_tag_last
Just a sentinel, not real memory format tag.
Definition: dnnl_types.h:332
dnnl::query::rnn_d
rnn descriptor
dnnl::gru_backward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source iteration memory descriptor.
Definition: dnnl.hpp:5200
dnnl_post_ops_destroy
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Deletes a post_ops sequence.
dnnl_lstm_forward_desc_init
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes an LSTM descriptor rnn_desc for forward propagation using prop_kind, direction,...
dnnl_dhwio
5D CNN weights tensor, an alias to dnnl_cdeba
Definition: dnnl_types.h:386
dnnl_primitive_execute
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive using a stream and nargs arguments args.
dnnl::memory::format_kind::any
Unspecified format kind.
dnnl_pooling_forward_desc_init
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
dnnl::deconvolution_backward_weights::primitive_desc::diff_dst_desc
memory::desc diff_dst_desc() const
Queries diff destination memory descriptor.
Definition: dnnl.hpp:3022
dnnl::layer_normalization_forward::primitive_desc::src_desc
memory::desc src_desc() const
Queries source memory descriptor.
Definition: dnnl.hpp:3888
dnnl::algorithm::pooling_avg_include_padding
Average pooling include padding.
dnnl::convolution_backward_data::desc::desc
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Initializes a descriptor for convolution backward propagation using aalgorithm, memory descriptors,...
Definition: dnnl.hpp:2342
dnnl::primitive_attr::set_rnn_weights_qparams
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scales for RNN weights tensors.
Definition: dnnl.hpp:808
dnnl_unimplemented
The operation failed because requested functionality is not implemented.
Definition: dnnl_types.h:57
dnnl_decab
permuted 5D tensor
Definition: dnnl_types.h:205
dnnl::lstm_backward::primitive_desc::bias_desc
memory::desc bias_desc() const
Queries bias memory descriptor.
Definition: dnnl.hpp:4901
dnnl::memory::format_tag::ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
dnnl::lstm_backward::primitive_desc::src_iter_desc
memory::desc src_iter_desc() const
Queries source recurrent hidden state memory descriptor.
Definition: dnnl.hpp:4878
dnnl::prop_kind::backward
Backward propagation (with respect to all parameters).
dnnl_vanilla_lstm
LSTM cell.
Definition: dnnl_types.h:707
dnnl_nChw4c
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b
Definition: dnnl_types.h:450
dnnl::vanilla_rnn_backward::primitive_desc::weights_iter_desc
memory::desc weights_iter_desc() const
Queries weights iteration memory descriptor.
Definition: dnnl.hpp:4563
dnnl_convolution
A convolution primitive.
Definition: dnnl_types.h:626