Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.17
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Traits that tell handle<> how to release a C API handle of type @p T.
///
/// The primary template is deliberately empty: each wrapped C handle type
/// provides an explicit specialization with a static @c destructor member,
/// which handle::reset() hands to std::shared_ptr as the deleter. It is
/// declared with the same class-key (@c struct) as those specializations so
/// compilers do not emit mismatched-tag warnings.
template <typename T> struct handle_traits {};

/// A reference-counted wrapper around an MKL-DNN C API handle.
///
/// Copies share the same underlying C object. An owning handle releases the
/// C object through @c traits::destructor when the last copy goes away; a
/// weak handle never releases it.
///
/// @tparam T The C API handle type (a pointer type).
/// @tparam traits Supplies the static @c destructor used to free a @p T.
template <typename T, typename traits = handle_traits<T>> class handle {
private:
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;
    // Construction and assignment from rvalue handles are deleted.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;
protected:
    /// Compares the wrapped C handle with a raw C handle.
    bool operator==(const T other) const { return other == _data.get(); }
    bool operator!=(const T other) const { return !(*this == other); }
public:
    /// Constructs a handle.
    ///
    /// @param t The C API handle to wrap; may be null (empty handle).
    /// @param weak If true, the handle does not own @p t and never
    ///     destroys it.
    handle(T t = nullptr, bool weak = false) { reset(t, weak); }

    handle(const handle &other): _data(other._data) {}
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Replaces the wrapped C handle, dropping this handle's share of the
    /// previously wrapped one.
    ///
    /// @param t The new C API handle; may be null.
    /// @param weak If true, @p t is not owned and will never be destroyed.
    void reset(T t, bool weak = false) {
        if (t == nullptr) {
            // Nothing to own: keep the shared_ptr empty instead of
            // allocating a control block whose deleter would later be
            // invoked with a null pointer.
            _data.reset();
        } else if (weak) {
            // Non-owning wrapper: the deleter intentionally does nothing.
            _data.reset(t, [](T) {});
        } else {
            _data.reset(t, traits::destructor);
        }
    }

    /// Returns the wrapped C API handle (null for an empty handle).
    T get() const { return _data.get(); }

    /// Two handles are equal when they wrap the same C API handle.
    bool operator==(const handle &other) const {
        return other._data.get() == _data.get();
    }
    bool operator!=(const handle &other) const { return !(*this == other); }
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
114  undefined_primitive = mkldnn_undefined_primitive,
116  view = mkldnn_view,
119  concat_inplace = mkldnn_concat_inplace,
120  sum = mkldnn_sum,
121  convolution = mkldnn_convolution,
122  deconvolution = mkldnn_deconvolution,
123  shuffle = mkldnn_shuffle,
124  eltwise = mkldnn_eltwise,
125  relu = mkldnn_relu,
126  softmax = mkldnn_softmax,
127  pooling = mkldnn_pooling,
128  lrn = mkldnn_lrn,
129  batch_normalization = mkldnn_batch_normalization,
130  inner_product = mkldnn_inner_product,
131  convolution_relu = mkldnn_convolution_relu,
132  rnn = mkldnn_rnn,
133  };
134 
136  struct at {
144 
145  at(const primitive &aprimitive, size_t at = 0)
146  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
148  inline operator primitive() const;
149  };
150 
152  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
153  // TODO: use the C++ API wrapper structure.
154 };
155 
157  return static_cast<mkldnn_primitive_kind_t>(akind);
158 }
163 struct error: public std::exception {
165  std::string message;
167 
174 
175  error(mkldnn_status_t astatus, std::string amessage,
176  mkldnn_primitive_t aerror_primitive = 0)
177  : status(astatus)
178  , message(amessage)
179  , error_primitive(aerror_primitive, true)
180  {}
181 
189 
190  static void wrap_c_api(mkldnn_status_t status,
191  const std::string &message,
192  mkldnn_primitive_t *error_primitive = 0)
193  {
194  if (status != mkldnn_success) {
195  if (nullptr != error_primitive)
196  throw error(status, message, *error_primitive);
197  else
198  throw error(status, message, nullptr);
199  }
200  }
201 };
202 
203 inline primitive::at::operator primitive() const {
206  mkldnn_primitive_get_output(data.primitive,
207  data.output_index, &output),
208  "could not get an output primitive");
209  return primitive(const_cast<mkldnn_primitive_t>(output), true);
210 }
211 
215  "could not get primitive descriptor by primitive");
216  return pd;
217 }
219 
224 
228 };
229 
231  return static_cast<mkldnn_round_mode_t>(mode);
232 }
233 
236 };
237 
239  return static_cast<mkldnn_padding_kind_t>(kind);
240 }
241 
242 enum prop_kind {
251 };
252 
254  return static_cast<mkldnn_prop_kind_t>(kind);
255 }
256 
257 enum algorithm {
283 };
284 
286  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
287 }
288 
294 };
295 
297  batch_normalization_flag aflag) {
298  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
299 }
300 
307 };
308 
310  return static_cast<mkldnn_rnn_direction_t>(adir);
311 }
312 
313 enum query {
315 
318 
321 
324 
326 
341 
351 };
352 
354  return static_cast<mkldnn_query_t>(aquery);
355 }
356 
358 
364 
365 #ifndef DOXYGEN_SHOULD_SKIP_THIS
366 template <> struct handle_traits<mkldnn_post_ops_t> {
367  static constexpr auto destructor = &mkldnn_post_ops_destroy;
368 };
369 #endif
370 
371 struct post_ops: public handle<mkldnn_post_ops_t> {
373  mkldnn_post_ops_t result;
375  "could not create post operation sequence");
376  reset(result);
377  }
378 
379  int len() const { return mkldnn_post_ops_len(get()); }
380 
381  primitive::kind kind(int index) const {
383  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
384  "post_ops index is out of range");
385  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
386  index));
387  }
388 
389  void append_sum(float scale = 1.) {
391  "could not append sum");
392  }
393 
394  void get_params_sum(int index, float &scale) const {
396  "could not get sum params");
397  }
398 
399  void append_eltwise(float scale, algorithm alg, float alpha,
400  float beta) {
402  convert_to_c(alg), alpha, beta),
403  "could not append eltwise");
404  }
405 
406  void get_params_eltwise(int index, float &scale, algorithm &alg,
407  float &alpha, float &beta) const {
408  mkldnn_alg_kind_t c_alg;
410  &scale, &c_alg, &alpha, &beta),
411  "could not get eltwise params");
412  alg = static_cast<algorithm>(c_alg);
413  }
414 };
415 
416 #ifndef DOXYGEN_SHOULD_SKIP_THIS
417 template <> struct handle_traits<mkldnn_primitive_attr_t> {
418  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
419 };
420 #endif
421 
422 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
424  mkldnn_primitive_attr_t result;
426  "could not create a primitive attr");
427  reset(result);
428  }
429 
431  mkldnn_round_mode_t result;
433  get(), &result), "could not get int output round mode");
434  return round_mode(result);
435  }
436 
439  get(), mkldnn::convert_to_c(mode)),
440  "could not set int output round mode");
441  }
442 
443  void get_output_scales(int &mask, std::vector<float> &scales) const
444  {
445  int count, c_mask;
446  const float *c_scales;
448  &count, &c_mask, &c_scales),
449  "could not get int output scales");
450  scales.resize(count);
451 
452  mask = c_mask;
453  for (int c = 0; c < count; ++c)
454  scales[c] = c_scales[c];
455  }
456 
457  void set_output_scales(int mask, const std::vector<float> &scales)
458  {
460  (int)scales.size(), mask, &scales[0]),
461  "could not set int output scales");
462  }
463 
464  const post_ops get_post_ops() const {
465  post_ops result;
466  const_mkldnn_post_ops_t c_result;
468  "could not get post operation sequence");
469  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
470  return result;
471  }
472 
473  void set_post_ops(post_ops ops) {
475  "could not set post operation sequence");
476  }
477 };
478 
480 
486 
487 #ifndef DOXYGEN_SHOULD_SKIP_THIS
488 template <> struct handle_traits<mkldnn_engine_t> {
489  static constexpr auto destructor = &mkldnn_engine_destroy;
490 };
491 #endif
492 
494 struct engine: public handle<mkldnn_engine_t> {
495  friend class primitive;
496  // gcc bug??? using handle::handle;
497 
499  enum kind {
503  cpu = mkldnn_cpu,
504  };
505 
509 
510  static size_t get_count(kind akind) {
511  return mkldnn_engine_get_count(convert_to_c(akind));
512  }
513 
519 
520  engine(kind akind, size_t index) {
521  mkldnn_engine_t aengine;
523  mkldnn_engine_create(&aengine,
524  convert_to_c(akind), index),
525  "could not create an engine");
526  reset(aengine);
527  }
528 
529  explicit engine(const mkldnn_engine_t& aengine)
530  : handle(aengine, true) {}
531 
533  mkldnn_engine_t engine_q;
536  mkldnn::convert_to_c(eengine), 0, &engine_q),
537  "could not get engine from primitive_desc");
538  reset(engine_q, true);
539  }
540 
541  template <class primitive_desc>
542  static engine query(const primitive_desc &pd) {
543  mkldnn_engine_t engine_q;
546  mkldnn::convert_to_c(eengine), 0, &engine_q),
547  "could not get engine from primitive_desc");
548 
549  return engine(engine_q);
550  }
551 
552 private:
553  static mkldnn_engine_kind_t convert_to_c(kind akind) {
554  return static_cast<mkldnn_engine_kind_t>(akind);
555  }
556 };
557 
559 
562 
568 
570 struct memory: public primitive {
571  private:
572  std::shared_ptr<char> _handle;
573 
574  public:
575  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
576 
577  template <typename T> static void validate_dims(std::vector<T> v) {
578  if (v.size() > TENSOR_MAX_DIMS)
580  "invalid dimensions");
581  }
582 
585  enum data_type {
587  f32 = mkldnn_f32,
588  s32 = mkldnn_s32,
589  s16 = mkldnn_s16,
590  s8 = mkldnn_s8,
591  u8 = mkldnn_u8,
592  };
593 
596  enum format {
597  format_undef = mkldnn_format_undef,
598  any = mkldnn_any,
599  blocked = mkldnn_blocked,
600  x = mkldnn_x,
601  nc = mkldnn_nc,
602  ncw = mkldnn_ncw,
603  nwc = mkldnn_nwc,
604  nCw16c = mkldnn_nCw16c,
605  nchw = mkldnn_nchw,
606  nhwc = mkldnn_nhwc,
607  chwn = mkldnn_chwn,
608  nCw8c = mkldnn_nCw8c,
609  nChw8c = mkldnn_nChw8c,
610  nChw16c = mkldnn_nChw16c,
611  ncdhw = mkldnn_ncdhw,
612  ndhwc = mkldnn_ndhwc,
613  nCdhw8c = mkldnn_nCdhw8c,
614  nCdhw16c = mkldnn_nCdhw16c,
615  oi = mkldnn_oi,
616  io = mkldnn_io,
617  oiw = mkldnn_oiw,
618  wio = mkldnn_wio,
619  Owi8o = mkldnn_Owi8o,
620  OIw8o8i = mkldnn_OIw8o8i,
621  OIw8i8o = mkldnn_OIw8i8o,
622  OIw16i16o = mkldnn_OIw16i16o,
623  OIw16o16i = mkldnn_OIw16o16i,
624  Oiw16o = mkldnn_Oiw16o,
625  Owi16o = mkldnn_Owi16o,
626  OIw8i16o2i = mkldnn_OIw8i16o2i,
627  OIw8o16i2o = mkldnn_OIw8o16i2o,
628  IOw16o16i = mkldnn_IOw16o16i,
629  oihw = mkldnn_oihw,
630  ihwo = mkldnn_ihwo,
631  hwio = mkldnn_hwio,
632  hwio_s8s8 = mkldnn_hwio_s8s8,
633  dhwio = mkldnn_dhwio,
634  oidhw = mkldnn_oidhw,
635  OIdhw8i8o = mkldnn_OIdhw8i8o,
636  OIdhw8o8i = mkldnn_OIdhw8o8i,
637  Odhwi8o = mkldnn_Odhwi8o,
638  OIdhw16i16o = mkldnn_OIdhw16i16o,
639  OIdhw16o16i = mkldnn_OIdhw16o16i,
640  Oidhw16o = mkldnn_Oidhw16o,
641  Odhwi16o = mkldnn_Odhwi16o,
642  oIhw8i = mkldnn_oIhw8i,
643  oIhw16i = mkldnn_oIhw16i,
644  oIdhw8i = mkldnn_oIdhw8i,
645  oIdhw16i = mkldnn_oIdhw16i,
646  OIhw8i8o = mkldnn_OIhw8i8o,
647  OIhw16i16o = mkldnn_OIhw16i16o,
648  OIhw8o8i = mkldnn_OIhw8o8i,
649  OIhw16o16i = mkldnn_OIhw16o16i,
650  IOhw16o16i = mkldnn_IOhw16o16i,
651  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
652  OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
653  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
654  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
655  OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
656  Oihw8o = mkldnn_Oihw8o,
657  Oihw16o = mkldnn_Oihw16o,
658  Ohwi8o = mkldnn_Ohwi8o,
659  Ohwi16o = mkldnn_Ohwi16o,
660  OhIw16o4i = mkldnn_OhIw16o4i,
661  goiw = mkldnn_goiw,
662  gOwi8o = mkldnn_gOwi8o,
663  gOIw8o8i = mkldnn_gOIw8o8i,
664  gOIw8i8o = mkldnn_gOIw8i8o,
665  gOIw16i16o = mkldnn_gOIw16i16o,
666  gOIw16o16i = mkldnn_gOIw16o16i,
667  gOiw16o = mkldnn_gOiw16o,
668  gOwi16o = mkldnn_gOwi16o,
669  gOIw8i16o2i = mkldnn_gOIw8i16o2i,
670  gIOw16o16i = mkldnn_gIOw16o16i,
671  gOIw8o16i2o = mkldnn_gOIw8o16i2o,
672  goihw = mkldnn_goihw,
673  hwigo = mkldnn_hwigo,
674  hwigo_s8s8 = mkldnn_hwigo_s8s8,
675  gOIdhw8i8o = mkldnn_gOIdhw8i8o,
676  gOIdhw8o8i = mkldnn_gOIdhw8o8i,
677  gOdhwi8o = mkldnn_gOdhwi8o,
678  gOIhw8i8o = mkldnn_gOIhw8i8o,
679  gOIhw16i16o = mkldnn_gOIhw16i16o,
680  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
681  gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
682  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
683  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
684  gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
685  gOihw8o = mkldnn_gOihw8o,
686  gOihw16o = mkldnn_gOihw16o,
687  gOhwi8o = mkldnn_gOhwi8o,
688  gOhwi16o = mkldnn_gOhwi16o,
689  Goihw8g = mkldnn_Goihw8g,
690  Goihw16g = mkldnn_Goihw16g,
691  gOIhw8o8i = mkldnn_gOIhw8o8i,
692  gOIhw16o16i = mkldnn_gOIhw16o16i,
693  gIOhw16o16i = mkldnn_gIOhw16o16i,
694  gOhIw16o4i = mkldnn_gOhIw16o4i,
695  goidhw = mkldnn_goidhw,
696  gOIdhw16i16o = mkldnn_gOIdhw16i16o,
697  gOIdhw16o16i = mkldnn_gOIdhw16o16i,
698  gOidhw16o = mkldnn_gOidhw16o,
699  gOdhwi16o = mkldnn_gOdhwi16o,
700  ntc = mkldnn_ntc,
701  tnc = mkldnn_tnc,
702  ldsnc = mkldnn_ldsnc,
703  ldigo = mkldnn_ldigo,
704  ldigo_p = mkldnn_ldigo_p,
705  ldgoi = mkldnn_ldgoi,
706  ldgoi_p = mkldnn_ldgoi_p,
707  ldgo = mkldnn_ldgo,
708  wino_fmt = mkldnn_wino_fmt,
709  format_last = mkldnn_format_last,
710  };
711 
713  struct desc {
714  friend struct memory;
717 
723  desc(dims adims, data_type adata_type,
724  format aformat) {
725  validate_dims(adims);
727  mkldnn_memory_desc_init(&data, (int)adims.size(),
728  adims.size() == 0 ? nullptr : &adims[0],
729  convert_to_c(adata_type), convert_to_c(aformat)),
730  "could not initialize a memory descriptor");
731  }
732 
736  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
737  };
738 
740  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
741  friend struct memory;
742 
743  // TODO: make private
745 
747  primitive_desc(const desc &adesc, const engine &aengine) {
748  mkldnn_primitive_desc_t result;
751  &adesc.data, aengine.get()),
752  "could not initialize a memory primitive descriptor");
753  reset(result);
754  }
755 
759  return memory::desc(*memory_d); }
760 
763  size_t get_size() const {
765  }
766 
767  bool operator==(const primitive_desc &other) const {
768  return (0 == mkldnn_memory_primitive_desc_equal(get(),
769  other.get())) ? false : true;
770  }
771 
772  bool operator!=(const primitive_desc &other) const {
773  return !operator==(other);
774  }
775 
776  engine get_engine() { return engine::query(*this); }
777  };
778 
782  memory(const primitive &aprimitive): primitive(aprimitive) {}
786  memory(const primitive_desc &adesc) {
787  mkldnn_primitive_t result;
789  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
790  "could not create a memory primitive");
791  reset(result);
792  auto _malloc = [](size_t size, int alignment) {
793  void *ptr;
794 #ifdef _WIN32
795  ptr = _aligned_malloc(size, alignment);
796  int rc = ((ptr)? 0 : errno);
797 #else
798  int rc = ::posix_memalign(&ptr, alignment, size);
799 #endif /* _WIN32 */
800  return (rc == 0) ? (char*)ptr : nullptr;
801  };
802  auto _free = [](char* p) {
803 #ifdef _WIN32
804  _aligned_free((void*)p);
805 #else
806  ::free((void*)p);
807 #endif /* _WIN32 */
808  };
809  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
810  set_data_handle(_handle.get());
811  }
812 
813  memory(const primitive_desc &adesc, void *ahandle) {
814  mkldnn_primitive_t result;
816  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
817  "could not create a memory primitive");
818  reset(result);
819  set_data_handle(ahandle);
820  }
821 
824  primitive_desc adesc;
827  &cdesc),
828  "could not get primitive descriptor from a memory primitive");
829  /* FIXME: no const_cast should be here */
830  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
831  return adesc;
832  }
833 
836  inline void *get_data_handle() const {
837  void *handle;
839  "could not get native handle");
840  return handle;
841  }
842 
843  inline void set_data_handle(void *handle) const {
845  "could not set native handle");
846  }
847 
848  // Must go away or be private:
850  return static_cast<mkldnn_data_type_t>(adata_type);
851  }
853  return static_cast<mkldnn_memory_format_t>(aformat);
854  }
855 };
856 
858  auto zero = mkldnn_memory_desc_t();
859  zero.primitive_kind = mkldnn_memory;
860  return memory::desc(zero);
861 }
862 
863 inline memory null_memory(engine eng) {
865  return memory({zero, eng}, nullptr);
866 }
867 
869  &aprimitive_desc, int n_inputs, int n_outputs,
870  const std::string &prim_name) {
871  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
872  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
873  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
874  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
875  if (n_outputs_expected > n_outputs ) {
876  std::string message = "could not create " + prim_name +
877  " primitive, not enought output parameters";
878  throw error(mkldnn_invalid_arguments, message, nullptr);
879  }
880  if (n_inputs_expected > n_inputs ) {
881  std::string message = "could not create " + prim_name +
882  " primitive, not enought input parameters";
883  throw error(mkldnn_invalid_arguments, message, nullptr);
884  }
885 }
886 
887 
888 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
889  const_mkldnn_primitive_desc_t aprimitive_pd;
890  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
892  aprimitive_pd);
893 
894  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
895 }
896 
898  return a == memory::convert_to_c(b);
899 }
901  return !(a == b);
902 }
904  return b == a;
905 }
907  return !(a == b);
908 }
909 
911  return a == memory::convert_to_c(b);
912 }
914  return !(a == b);
915 }
917  return b == a;
918 }
920  return !(a == b);
921 }
922 
924 
930 
931 struct reorder : public primitive {
932  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
934  const memory::primitive_desc &output) {
935  mkldnn_primitive_desc_t result;
937  &result, input.get(), output.get()),
938  "could not create a reorder primitive descriptor");
939  reset(result);
940  }
941 
943  const memory::primitive_desc &output,
944  const primitive_attr &aattr) {
945  mkldnn_primitive_desc_t result;
947  &result, input.get(), output.get(), aattr.get()),
948  "could not create a reorder primitive descriptor");
949  reset(result);
950  }
951 
952  engine get_engine() { return engine::query(*this); }
953  };
954 
955  reorder(const primitive_desc &aprimitive_desc,
956  const primitive::at &input, const memory &output) {
957  mkldnn_primitive_t result;
958  mkldnn_primitive_at_t inputs[] = { input.data };
959  const_mkldnn_primitive_t outputs[] = { output.get() };
961  aprimitive_desc.get(), inputs, outputs),
962  "could not create a reorder primitive");
963  reset(result);
964  }
965 
966  reorder(const primitive::at &input, const memory &output) {
967  auto input_mpd = memory(input).get_primitive_desc();
968  auto output_mpd = output.get_primitive_desc();
969 
970  auto reorder_d = primitive_desc(input_mpd, output_mpd);
971 
972  mkldnn_primitive_t result;
973  mkldnn_primitive_at_t inputs[] = { input.data };
974  const_mkldnn_primitive_t outputs[] = { output.get() };
976  reorder_d.get(), inputs, outputs),
977  "could not create a reorder primitive");
978  reset(result);
979  }
980 };
981 
983 
989 
990 struct view : public primitive {
991  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
993  memory::dims offsets) {
994  mkldnn_primitive_desc_t result;
995 
997  &result, input.get(), &dims[0], &offsets[0]),
998  "could not create a view primitive descriptor");
999  reset(result);
1000  }
1001 
1003  memory::primitive_desc adesc;
1004  mkldnn_primitive_desc_t cdesc;
1005  const_mkldnn_primitive_desc_t const_cdesc =
1009  const_cdesc),
1010  "could not clone a dst primitive descriptor");
1011  adesc.reset(cdesc);
1012  return adesc;
1013  }
1014 
1015  engine get_engine() { return engine::query(*this); }
1016  };
1017 
1018  view(const primitive_desc &view_pd, primitive::at input) {
1019  mkldnn_primitive_t result;
1020  mkldnn_primitive_at_t inputs[] = { input.data };
1022  view_pd.get(), inputs, nullptr),
1023  "could not create a view primitive");
1024  reset(result);
1025  }
1026 
1027  view(memory input, memory::dims dims, memory::dims offsets) {
1028  mkldnn_primitive_t result;
1029  primitive_desc view_pd(input.get_primitive_desc(), dims,
1030  offsets);
1031  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1033  view_pd.get(), inputs, nullptr),
1034  "could not create a view primitive");
1035  reset(result);
1036  }
1037 };
1038 
1040 
1046 
1047 struct concat : public primitive {
1048  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1049  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1050  std::vector<memory::primitive_desc> inputs) {
1051  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1052  c_api_inputs.reserve(inputs.size());
1053  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1054  std::transform(inputs.begin(), inputs.end(),
1055  std::back_inserter(c_api_inputs), convert_to_c);
1056  return c_api_inputs;
1057  }
1058 
1059  primitive_desc(const memory::desc &output, int concat_dimension,
1060  std::vector<memory::primitive_desc> inputs) {
1061  mkldnn_primitive_desc_t result;
1062 
1063  auto c_api_inputs = cpp_to_c(inputs);
1064 
1066  &result, &output.data, (int)c_api_inputs.size(),
1067  concat_dimension, &c_api_inputs[0]),
1068  "could not create a concat primitive descriptor");
1069  reset(result);
1070  }
1071 
1072  primitive_desc(int concat_dimension,
1073  std::vector<memory::primitive_desc> inputs) {
1074  mkldnn_primitive_desc_t result;
1075 
1076  auto c_api_inputs = cpp_to_c(inputs);
1077 
1079  &result, nullptr, (int)c_api_inputs.size(),
1080  concat_dimension, &c_api_inputs[0]),
1081  "could not create a concat primitive descriptor");
1082  reset(result);
1083  }
1084 
1086  memory::primitive_desc adesc;
1087  mkldnn_primitive_desc_t cdesc;
1088  const_mkldnn_primitive_desc_t const_cdesc =
1091  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1092  "could not clone a dst primitive descriptor");
1093  adesc.reset(cdesc);
1094  return adesc;
1095  }
1096 
1097  engine get_engine() { return engine::query(*this); }
1098  };
1099 
1100  concat(const primitive_desc &concat_pd,
1101  std::vector<primitive::at> &inputs, const memory &output) {
1102  mkldnn_primitive_t result;
1103 
1104  std::vector<mkldnn_primitive_at_t> p_inputs;
1105  for (size_t i = 0; i < inputs.size(); i++)
1106  p_inputs.push_back(inputs[i].data);
1107  const_mkldnn_primitive_t outputs[] = { output.get() };
1108 
1110  concat_pd.get(), &p_inputs[0], outputs),
1111  "could not create a concat primitive");
1112  reset(result);
1113  }
1114 };
1115 
1117 
1123 
1124 struct sum : public primitive {
1125  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1126  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1127  std::vector<memory::primitive_desc> inputs) {
1128  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1129  c_api_inputs.reserve(inputs.size());
1130  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1131  std::transform(inputs.begin(), inputs.end(),
1132  std::back_inserter(c_api_inputs), convert_to_c);
1133  return c_api_inputs;
1134  }
1135 
1137  const std::vector<float> &scales,
1138  std::vector<memory::primitive_desc> inputs) {
1139  mkldnn_primitive_desc_t result;
1140 
1141  auto c_api_inputs = cpp_to_c(inputs);
1142 
1144  scales.size() == inputs.size() ? mkldnn_success
1146  "number of scales not equal to number of inputs");
1147 
1149  &result, &output.data, (int)c_api_inputs.size(),
1150  &scales[0], &c_api_inputs[0]),
1151  "could not create a sum primitive descriptor");
1152  reset(result);
1153  }
1154 
1155  primitive_desc(const std::vector<float> &scales,
1156  std::vector<memory::primitive_desc> inputs) {
1157  mkldnn_primitive_desc_t result;
1158 
1159  auto c_api_inputs = cpp_to_c(inputs);
1160 
1162  scales.size() == inputs.size() ? mkldnn_success
1164  "number of scales not equal to number of inputs");
1165 
1167  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1168  &c_api_inputs[0]),
1169  "could not create a sum primitive descriptor");
1170  reset(result);
1171  }
1172 
1174  MKLDNN_DEPRECATED
1175  primitive_desc(const memory::desc &output, std::vector<double> scale,
1176  std::vector<memory::primitive_desc> inputs) {
1177  mkldnn_primitive_desc_t result;
1178 
1179  auto c_api_inputs = cpp_to_c(inputs);
1180  auto scale_f = scale_to_float(scale);
1181 
1183  &result, &output.data, (int)c_api_inputs.size(),
1184  &scale_f[0], &c_api_inputs[0]),
1185  "could not create a sum primitive descriptor");
1186  reset(result);
1187  }
1188 
1190  MKLDNN_DEPRECATED
1191  primitive_desc(std::vector<double> scale,
1192  std::vector<memory::primitive_desc> inputs) {
1193  mkldnn_primitive_desc_t result;
1194 
1195  auto c_api_inputs = cpp_to_c(inputs);
1196  auto scale_f = scale_to_float(scale);
1197 
1199  &result, nullptr, (int)c_api_inputs.size(), &scale_f[0],
1200  &c_api_inputs[0]),
1201  "could not create a sum primitive descriptor");
1202  reset(result);
1203  }
1204 
1206  memory::primitive_desc adesc;
1207  mkldnn_primitive_desc_t cdesc;
1208  const_mkldnn_primitive_desc_t const_cdesc =
1212  const_cdesc),
1213  "could not clone a dst primitive descriptor");
1214  adesc.reset(cdesc);
1215  return adesc;
1216  }
1217 
1218  engine get_engine() { return engine::query(*this); }
1219  };
1220 
1221  sum(const primitive_desc &sum_pd,
1222  std::vector<primitive::at> &inputs, const memory &output) {
1223  mkldnn_primitive_t result;
1224 
1225  std::vector<mkldnn_primitive_at_t> p_inputs;
1226  for (size_t i = 0; i < inputs.size(); i++)
1227  p_inputs.push_back(inputs[i].data);
1228  const_mkldnn_primitive_t outputs[] = { output.get() };
1229 
1231  sum_pd.get(), &p_inputs[0], outputs),
1232  "could not create a sum primitive");
1233  reset(result);
1234  }
1235 
1236 private:
/// Narrows double-precision scales to float, preserving their order.
///
/// @param vd Scales to convert.
/// @returns A vector of the same length holding each value cast to float.
static std::vector<float> scale_to_float(const std::vector<double> &vd) {
    std::vector<float> vf;
    vf.reserve(vd.size());
    for (double x : vd)
        vf.push_back(static_cast<float>(x));
    return vf;
}
1243 };
1244 
1246 
1248 
1251 
1254 
1256 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1258  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1259  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1261  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1262  hint_fwd_pd);
1263  error::wrap_c_api(status,
1264  "could not create a primitive descriptor iterator");
1265  pd_iterator.reset(iterator);
1266  fetch_impl();
1267  }
1268 
1269  engine get_engine() { return engine::query(*this); }
1270 
1272  const char *impl_info_str() const {
1273  const char *res;
1275  mkldnn_query_impl_info_str, 0, &res),
1276  "could not query implementation info string");
1277  return res;
1278  }
1279 
1286  bool next_impl() {
1288  pd_iterator.get());
1289  if (status == mkldnn_iterator_ends) return false;
1290  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1291 
1292  fetch_impl();
1293  return true;
1294  }
1295 
1297  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1298  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1300  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1301  [=](query q) { return what == q; }))
1302  throw error(mkldnn_invalid_arguments, "invalid memory query");
1303 
1304  const_mkldnn_primitive_desc_t const_cdesc
1306  mkldnn::convert_to_c(what), idx);
1307 
1308  // TODO: is there a better way to inform about this?
1309  if (const_cdesc == nullptr)
1310  throw error(mkldnn_not_required, "queried memory is not required");
1311 
1312  mkldnn_primitive_desc_t cdesc;
1313  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1314  "could not clone a memory primitive descriptor");
1315 
1317  ret.reset(cdesc);
1318  return ret;
1319  }
1320 
1321  // register specialized queries, e.g. src_primitive_desc()
1322 # define REG_QUERY_MPD(name, what, idx) \
1323  memory::primitive_desc name ## _primitive_desc() const \
1324  { return query_mpd(what ## _pd, idx); }
1325 
1326  private:
1327  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1328  void fetch_impl() {
1329  mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1330  pd_iterator.get());
1332  "could not fetch a primitive descriptor from the iterator");
1333  reset(pd);
1334  }
1335 };
1336 
1338 
1344 
1346  struct desc {
1348  desc(prop_kind aprop_kind, algorithm aalgorithm,
1349  const memory::desc &src_desc,
1350  const memory::desc &weights_desc,
1351  const memory::desc &bias_desc,
1352  const memory::desc &dst_desc,
1353  const memory::dims strides,
1354  const memory::dims padding_l,
1355  const memory::dims padding_r,
1356  const padding_kind apadding_kind) {
1357  memory::validate_dims(strides);
1358  memory::validate_dims(padding_l);
1359  memory::validate_dims(padding_r);
1361  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1362  &src_desc.data, &weights_desc.data, &bias_desc.data,
1363  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1364  mkldnn::convert_to_c(apadding_kind)),
1365  "could not create a convolution forward descriptor");
1366  }
1367  desc(prop_kind aprop_kind, algorithm aalgorithm,
1368  const memory::desc &src_desc,
1369  const memory::desc &weights_desc,
1370  const memory::desc &dst_desc,
1371  const memory::dims strides,
1372  const memory::dims padding_l,
1373  const memory::dims padding_r,
1374  const padding_kind apadding_kind) {
1375  memory::validate_dims(strides);
1376  memory::validate_dims(padding_l);
1377  memory::validate_dims(padding_r);
1379  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1380  &src_desc.data, &weights_desc.data, nullptr,
1381  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1382  mkldnn::convert_to_c(apadding_kind)),
1383  "could not create a convolution forward descriptor");
1384  }
1385  desc(prop_kind aprop_kind, algorithm aalgorithm,
1386  const memory::desc &src_desc,
1387  const memory::desc &weights_desc,
1388  const memory::desc &bias_desc,
1389  const memory::desc &dst_desc,
1390  const memory::dims strides,
1391  const memory::dims dilates,
1392  const memory::dims padding_l,
1393  const memory::dims padding_r,
1394  const padding_kind apadding_kind) {
1395  memory::validate_dims(strides);
1396  memory::validate_dims(dilates);
1397  memory::validate_dims(padding_l);
1398  memory::validate_dims(padding_r);
1401  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1402  &src_desc.data, &weights_desc.data, &bias_desc.data,
1403  &dst_desc.data, &strides[0], &dilates[0],
1404  &padding_l[0], &padding_r[0],
1405  mkldnn::convert_to_c(apadding_kind)),
1406  "could not create a dilated convolution forward descriptor");
1407  }
1408  desc(prop_kind aprop_kind, algorithm aalgorithm,
1409  const memory::desc &src_desc,
1410  const memory::desc &weights_desc,
1411  const memory::desc &dst_desc,
1412  const memory::dims strides,
1413  const memory::dims dilates,
1414  const memory::dims padding_l,
1415  const memory::dims padding_r,
1416  const padding_kind apadding_kind) {
1417  memory::validate_dims(strides);
1418  memory::validate_dims(dilates);
1419  memory::validate_dims(padding_l);
1420  memory::validate_dims(padding_r);
1423  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1424  &src_desc.data, &weights_desc.data, nullptr,
1425  &dst_desc.data, &strides[0], &dilates[0],
1426  &padding_l[0], &padding_r[0],
1427  mkldnn::convert_to_c(apadding_kind)),
1428  "could not create a dilated convolution forward descriptor");
1429  }
1430  };
1431 
1433  primitive_desc(const desc &desc, const engine &e)
1434  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1435 
1436  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1437  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1438 
1439  REG_QUERY_MPD(src, src, 0);
1440  REG_QUERY_MPD(weights, weights, 0);
1441  REG_QUERY_MPD(bias, weights, 1);
1442  REG_QUERY_MPD(dst, dst, 0);
1443  };
1444 
1445  convolution_forward(const primitive_desc &aprimitive_desc,
1446  const primitive::at &src, const primitive::at &weights,
1447  const primitive::at &bias, const memory &dst) {
1448  mkldnn_primitive_t result;
1449  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1450  bias.data };
1451  const_mkldnn_primitive_t outputs[] = { dst.get() };
1453  aprimitive_desc.get(), inputs, outputs),
1454  "could not create a convolution forward bias primitive");
1455  reset(result);
1456  }
1457 
1458  convolution_forward(const primitive_desc &aprimitive_desc,
1459  const primitive::at &src, const primitive::at &weights,
1460  const memory &dst) {
1461  mkldnn_primitive_t result;
1462  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1463  const_mkldnn_primitive_t outputs[] = { dst.get() };
1464  check_num_parameters(aprimitive_desc.get(), 2, 1,
1465  "convolution forward");
1467  aprimitive_desc.get(), inputs, outputs),
1468  "could not create a convolution forward primitive");
1469  reset(result);
1470  }
1471 };
1472 
1474  struct desc {
1476  desc(algorithm aalgorithm,
1477  const memory::desc &diff_src_desc,
1478  const memory::desc &weights_desc,
1479  const memory::desc &diff_dst_desc,
1480  const memory::dims strides,
1481  const memory::dims padding_l,
1482  const memory::dims padding_r,
1483  const padding_kind apadding_kind) {
1484  memory::validate_dims(strides);
1485  memory::validate_dims(padding_l);
1486  memory::validate_dims(padding_r);
1488  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1489  &weights_desc.data, &diff_dst_desc.data,
1490  &strides[0], &padding_l[0], &padding_r[0],
1491  mkldnn::convert_to_c(apadding_kind)),
1492  "could not create a convolution backward data descriptor");
1493  }
1494  desc(algorithm aalgorithm,
1495  const memory::desc &diff_src_desc,
1496  const memory::desc &weights_desc,
1497  const memory::desc &diff_dst_desc,
1498  const memory::dims strides,
1499  const memory::dims dilates,
1500  const memory::dims padding_l,
1501  const memory::dims padding_r,
1502  const padding_kind apadding_kind) {
1503  memory::validate_dims(strides);
1504  memory::validate_dims(dilates);
1505  memory::validate_dims(padding_l);
1506  memory::validate_dims(padding_r);
1509  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1510  &weights_desc.data, &diff_dst_desc.data,
1511  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1512  mkldnn::convert_to_c(apadding_kind)),
1513  "could not create a convolution backward data descriptor");
1514  }
1515  };
1516 
1518  primitive_desc(const desc &desc, const engine &e,
1519  const convolution_forward::primitive_desc &hint_fwd_pd)
1520  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1521 
1522  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1523  const convolution_forward::primitive_desc &hint_fwd_pd)
1524  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1525 
1526  REG_QUERY_MPD(diff_src, diff_src, 0);
1527  REG_QUERY_MPD(weights, weights, 0);
1528  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1529  };
1530 
1532  const primitive::at &diff_dst, const primitive::at &weights,
1533  const memory &diff_src) {
1534  mkldnn_primitive_t result;
1535  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1536  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1537  check_num_parameters(aprimitive_desc.get(), 2, 1,
1538  "convolution backward data");
1540  aprimitive_desc.get(), inputs, outputs),
1541  "could not create a convolution backward data primitive");
1542  reset(result);
1543  }
1544 };
1545 
1547  struct desc {
1549  desc(algorithm aalgorithm,
1550  const memory::desc &src_desc,
1551  const memory::desc &diff_weights_desc,
1552  const memory::desc &diff_bias_desc,
1553  const memory::desc &diff_dst_desc,
1554  const memory::dims strides,
1555  const memory::dims padding_l,
1556  const memory::dims padding_r,
1557  const padding_kind apadding_kind) {
1558  memory::validate_dims(strides);
1559  memory::validate_dims(padding_l);
1560  memory::validate_dims(padding_r);
1562  &data, convert_to_c(aalgorithm), &src_desc.data,
1563  &diff_weights_desc.data, &diff_bias_desc.data,
1564  &diff_dst_desc.data,
1565  &strides[0], &padding_l[0], &padding_r[0],
1566  mkldnn::convert_to_c(apadding_kind)),
1567  "could not create a convolution backward weights descriptor");
1568  }
1569  desc(algorithm aalgorithm,
1570  const memory::desc &src_desc,
1571  const memory::desc &diff_weights_desc,
1572  const memory::desc &diff_dst_desc,
1573  const memory::dims strides,
1574  const memory::dims padding_l,
1575  const memory::dims padding_r,
1576  const padding_kind apadding_kind) {
1577  memory::validate_dims(strides);
1578  memory::validate_dims(padding_l);
1579  memory::validate_dims(padding_r);
1581  &data, convert_to_c(aalgorithm), &src_desc.data,
1582  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1583  &strides[0], &padding_l[0], &padding_r[0],
1584  mkldnn::convert_to_c(apadding_kind)),
1585  "could not create a convolution backward weights descriptor");
1586  }
1587  desc(algorithm aalgorithm,
1588  const memory::desc &src_desc,
1589  const memory::desc &diff_weights_desc,
1590  const memory::desc &diff_bias_desc,
1591  const memory::desc &diff_dst_desc,
1592  const memory::dims strides,
1593  const memory::dims dilates,
1594  const memory::dims padding_l,
1595  const memory::dims padding_r,
1596  const padding_kind apadding_kind) {
1597  memory::validate_dims(strides);
1598  memory::validate_dims(dilates);
1599  memory::validate_dims(padding_l);
1600  memory::validate_dims(padding_r);
1602  &data, convert_to_c(aalgorithm), &src_desc.data,
1603  &diff_weights_desc.data, &diff_bias_desc.data,
1604  &diff_dst_desc.data,
1605  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1606  mkldnn::convert_to_c(apadding_kind)),
1607  "could not create a convolution backward weights descriptor");
1608  }
1609  desc(algorithm aalgorithm,
1610  const memory::desc &src_desc,
1611  const memory::desc &diff_weights_desc,
1612  const memory::desc &diff_dst_desc,
1613  const memory::dims strides,
1614  const memory::dims dilates,
1615  const memory::dims padding_l,
1616  const memory::dims padding_r,
1617  const padding_kind apadding_kind) {
1618  memory::validate_dims(strides);
1619  memory::validate_dims(dilates);
1620  memory::validate_dims(padding_l);
1621  memory::validate_dims(padding_r);
1623  &data, convert_to_c(aalgorithm), &src_desc.data,
1624  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1625  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1626  mkldnn::convert_to_c(apadding_kind)),
1627  "could not create a convolution backward weights descriptor");
1628  }
1629 
1630  };
1631 
1633  primitive_desc(const desc &desc, const engine &e,
1634  const convolution_forward::primitive_desc &hint_fwd_pd)
1635  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1636 
1637  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1638  const convolution_forward::primitive_desc &hint_fwd_pd)
1639  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1640 
1641  REG_QUERY_MPD(src, src, 0);
1642  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1643  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1644  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1645  };
1646 
1648  const primitive::at &src, const primitive::at &diff_dst,
1649  const memory &diff_weights, const memory &diff_bias) {
1650  mkldnn_primitive_t result;
1651  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1652  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1653  diff_bias.get() };
1654  check_num_parameters(aprimitive_desc.get(), 2, 2,
1655  "convolution backward weights");
1657  aprimitive_desc.get(), inputs, outputs),
1658  "could not create a convolution backward weights primitive");
1659  reset(result);
1660  }
1662  const primitive::at &src, const primitive::at &diff_dst,
1663  const memory &diff_weights) {
1664  mkldnn_primitive_t result;
1665  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1666  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1667  check_num_parameters(aprimitive_desc.get(), 2, 1,
1668  "convolution backward weights");
1670  aprimitive_desc.get(), inputs, outputs),
1671  "could not create a convolution backward weights primitive");
1672  reset(result);
1673  }
1674 };
1675 
1681  struct desc {
1683 
1685  const float negative_slope) {
1687  &conv_desc.data, negative_slope),
1688  "could not create a convolution_relu_forward descriptor");
1689  }
1690  };
1691 
1693  primitive_desc(const desc &desc, const engine &e)
1694  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1695 
1696  REG_QUERY_MPD(src, src, 0);
1697  REG_QUERY_MPD(weights, weights, 0);
1698  REG_QUERY_MPD(bias, weights, 1);
1699  REG_QUERY_MPD(dst, dst, 0);
1700  };
1701 
1703  MKLDNN_DEPRECATED
1705  const primitive::at &src, const primitive::at &weights,
1706  const primitive::at &bias, const memory &dst) {
1707  mkldnn_primitive_t result;
1708  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1709  bias.data };
1710  const_mkldnn_primitive_t outputs[] = { dst.get() };
1711  check_num_parameters(aprimitive_desc.get(), 3, 1,
1712  "convolution relu forward");
1714  aprimitive_desc.get(), inputs, outputs),
1715  "could not create a convolution relu forward primitive");
1716  reset(result);
1717  }
1718 
1720  MKLDNN_DEPRECATED
1722  const primitive::at &src, const primitive::at &weights,
1723  const memory &dst) {
1724  mkldnn_primitive_t result;
1725  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1726  const_mkldnn_primitive_t outputs[] = { dst.get() };
1727  check_num_parameters(aprimitive_desc.get(), 2, 1,
1728  "convolution relu forward");
1730  aprimitive_desc.get(), inputs, outputs),
1731  "could not create a convolution relu forward primitive");
1732  reset(result);
1733  }
1734 };
1735 
1737 //
1743 
1745  struct desc {
1747  desc(prop_kind aprop_kind, algorithm aalgorithm,
1748  const memory::desc &src_desc,
1749  const memory::desc &weights_desc,
1750  const memory::desc &bias_desc,
1751  const memory::desc &dst_desc,
1752  const memory::dims strides,
1753  const memory::dims padding_l,
1754  const memory::dims padding_r,
1755  const padding_kind apadding_kind) {
1756  memory::validate_dims(strides);
1757  memory::validate_dims(padding_l);
1758  memory::validate_dims(padding_r);
1760  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1761  &src_desc.data, &weights_desc.data, &bias_desc.data,
1762  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1763  mkldnn::convert_to_c(apadding_kind)),
1764  "could not create a deconvolution forward descriptor");
1765  }
1766  desc(prop_kind aprop_kind, algorithm aalgorithm,
1767  const memory::desc &src_desc,
1768  const memory::desc &weights_desc,
1769  const memory::desc &dst_desc,
1770  const memory::dims strides,
1771  const memory::dims padding_l,
1772  const memory::dims padding_r,
1773  const padding_kind apadding_kind) {
1774  memory::validate_dims(strides);
1775  memory::validate_dims(padding_l);
1776  memory::validate_dims(padding_r);
1778  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1779  &src_desc.data, &weights_desc.data, nullptr,
1780  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1781  mkldnn::convert_to_c(apadding_kind)),
1782  "could not create a deconvolution forward descriptor");
1783  }
1784  desc(prop_kind aprop_kind, algorithm aalgorithm,
1785  const memory::desc &src_desc,
1786  const memory::desc &weights_desc,
1787  const memory::desc &bias_desc,
1788  const memory::desc &dst_desc,
1789  const memory::dims strides,
1790  const memory::dims dilates,
1791  const memory::dims padding_l,
1792  const memory::dims padding_r,
1793  const padding_kind apadding_kind) {
1794  memory::validate_dims(strides);
1795  memory::validate_dims(dilates);
1796  memory::validate_dims(padding_l);
1797  memory::validate_dims(padding_r);
1799  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1800  &src_desc.data, &weights_desc.data, &bias_desc.data,
1801  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1802  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1803  "could not create a dilated deconvolution forward descriptor");
1804  }
1805  desc(prop_kind aprop_kind, algorithm aalgorithm,
1806  const memory::desc &src_desc,
1807  const memory::desc &weights_desc,
1808  const memory::desc &dst_desc,
1809  const memory::dims strides,
1810  const memory::dims dilates,
1811  const memory::dims padding_l,
1812  const memory::dims padding_r,
1813  const padding_kind apadding_kind) {
1814  memory::validate_dims(strides);
1815  memory::validate_dims(dilates);
1816  memory::validate_dims(padding_l);
1817  memory::validate_dims(padding_r);
1819  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1820  &src_desc.data, &weights_desc.data, nullptr,
1821  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1822  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1823  "could not create a dilated deconvolution forward descriptor");
1824  }
1825  };
1826 
1828  primitive_desc(const desc &desc, const engine &e)
1829  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1830 
1831  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1832  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1833 
1834  REG_QUERY_MPD(src, src, 0);
1835  REG_QUERY_MPD(weights, weights, 0);
1836  REG_QUERY_MPD(bias, weights, 1);
1837  REG_QUERY_MPD(dst, dst, 0);
1838  };
1839 
1840  deconvolution_forward(const primitive_desc &aprimitive_desc,
1841  const primitive::at &src, const primitive::at &weights,
1842  const primitive::at &bias, const memory &dst) {
1843  mkldnn_primitive_t result;
1844  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1845  bias.data };
1846  const_mkldnn_primitive_t outputs[] = { dst.get() };
1847  check_num_parameters(aprimitive_desc.get(), 3, 1,
1848  "deconvolution forward");
1850  aprimitive_desc.get(), inputs, outputs),
1851  "could not create a deconvolution forward bias primitive");
1852  reset(result);
1853  }
1854 
1855  deconvolution_forward(const primitive_desc &aprimitive_desc,
1856  const primitive::at &src, const primitive::at &weights,
1857  const memory &dst) {
1858  mkldnn_primitive_t result;
1859  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1860  const_mkldnn_primitive_t outputs[] = { dst.get() };
1861  check_num_parameters(aprimitive_desc.get(), 2, 1,
1862  "deconvolution forward");
1864  aprimitive_desc.get(), inputs, outputs),
1865  "could not create a deconvolution forward primitive");
1866  reset(result);
1867  }
1868 };
1869 
1871  struct desc {
1873  desc(algorithm aalgorithm,
1874  const memory::desc &diff_src_desc,
1875  const memory::desc &weights_desc,
1876  const memory::desc &diff_dst_desc,
1877  const memory::dims strides,
1878  const memory::dims padding_l,
1879  const memory::dims padding_r,
1880  const padding_kind apadding_kind) {
1881  memory::validate_dims(strides);
1882  memory::validate_dims(padding_l);
1883  memory::validate_dims(padding_r);
1885  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1886  &weights_desc.data, &diff_dst_desc.data,
1887  &strides[0], &padding_l[0], &padding_r[0],
1888  mkldnn::convert_to_c(apadding_kind)),
1889  "could not create a deconvolution backward data descriptor");
1890  }
1891  desc(algorithm aalgorithm,
1892  const memory::desc &diff_src_desc,
1893  const memory::desc &weights_desc,
1894  const memory::desc &diff_dst_desc,
1895  const memory::dims strides,
1896  const memory::dims dilates,
1897  const memory::dims padding_l,
1898  const memory::dims padding_r,
1899  const padding_kind apadding_kind) {
1900  memory::validate_dims(strides);
1901  memory::validate_dims(dilates);
1902  memory::validate_dims(padding_l);
1903  memory::validate_dims(padding_r);
1905  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1906  &weights_desc.data, &diff_dst_desc.data,
1907  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1908  mkldnn::convert_to_c(apadding_kind)),
1909  "could not create a dilated deconvolution backward data descriptor");
1910  }
1911  };
1912 
1914  primitive_desc(const desc &desc, const engine &e,
1915  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1916  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1917 
1918  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1919  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1920  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1921 
1922  REG_QUERY_MPD(diff_src, diff_src, 0);
1923  REG_QUERY_MPD(weights, weights, 0);
1924  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1925  };
1926 
1928  const primitive::at &diff_dst, const primitive::at &weights,
1929  const memory &diff_src) {
1930  mkldnn_primitive_t result;
1931  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1932  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1933  check_num_parameters(aprimitive_desc.get(), 2, 1,
1934  "deconvolution backward data");
1936  aprimitive_desc.get(), inputs, outputs),
1937  "could not create a deconvolution backward data primitive");
1938  reset(result);
1939  }
1940 };
1941 
1943  struct desc {
1945  desc(algorithm aalgorithm,
1946  const memory::desc &src_desc,
1947  const memory::desc &diff_weights_desc,
1948  const memory::desc &diff_bias_desc,
1949  const memory::desc &diff_dst_desc,
1950  const memory::dims strides,
1951  const memory::dims padding_l,
1952  const memory::dims padding_r,
1953  const padding_kind apadding_kind) {
1954  memory::validate_dims(strides);
1955  memory::validate_dims(padding_l);
1956  memory::validate_dims(padding_r);
1958  &data, convert_to_c(aalgorithm), &src_desc.data,
1959  &diff_weights_desc.data, &diff_bias_desc.data,
1960  &diff_dst_desc.data,
1961  &strides[0], &padding_l[0], &padding_r[0],
1962  mkldnn::convert_to_c(apadding_kind)),
1963  "could not create a deconvolution backward weights descriptor");
1964  }
1965  desc(algorithm aalgorithm,
1966  const memory::desc &src_desc,
1967  const memory::desc &diff_weights_desc,
1968  const memory::desc &diff_dst_desc,
1969  const memory::dims strides,
1970  const memory::dims padding_l,
1971  const memory::dims padding_r,
1972  const padding_kind apadding_kind) {
1973  memory::validate_dims(strides);
1974  memory::validate_dims(padding_l);
1975  memory::validate_dims(padding_r);
1977  &data, convert_to_c(aalgorithm), &src_desc.data,
1978  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1979  &strides[0], &padding_l[0], &padding_r[0],
1980  mkldnn::convert_to_c(apadding_kind)),
1981  "could not create a deconvolution backward weights descriptor");
1982  }
1983  desc(algorithm aalgorithm,
1984  const memory::desc &src_desc,
1985  const memory::desc &diff_weights_desc,
1986  const memory::desc &diff_bias_desc,
1987  const memory::desc &diff_dst_desc,
1988  const memory::dims strides,
1989  const memory::dims dilates,
1990  const memory::dims padding_l,
1991  const memory::dims padding_r,
1992  const padding_kind apadding_kind) {
1993  memory::validate_dims(strides);
1994  memory::validate_dims(dilates);
1995  memory::validate_dims(padding_l);
1996  memory::validate_dims(padding_r);
1998  &data, convert_to_c(aalgorithm), &src_desc.data,
1999  &diff_weights_desc.data, &diff_bias_desc.data,
2000  &diff_dst_desc.data,
2001  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2002  mkldnn::convert_to_c(apadding_kind)),
2003  "could not create a dilated deconvolution backward weights descriptor");
2004  }
2005  desc(algorithm aalgorithm,
2006  const memory::desc &src_desc,
2007  const memory::desc &diff_weights_desc,
2008  const memory::desc &diff_dst_desc,
2009  const memory::dims strides,
2010  const memory::dims dilates,
2011  const memory::dims padding_l,
2012  const memory::dims padding_r,
2013  const padding_kind apadding_kind) {
2014  memory::validate_dims(strides);
2015  memory::validate_dims(dilates);
2016  memory::validate_dims(padding_l);
2017  memory::validate_dims(padding_r);
2019  &data, convert_to_c(aalgorithm), &src_desc.data,
2020  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
2021  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
2022  mkldnn::convert_to_c(apadding_kind)),
2023  "could not create a dilated deconvolution backward weights descriptor");
2024  }
2025  };
2026 
2028  primitive_desc(const desc &desc, const engine &e,
2029  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2030  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2031 
2032  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2033  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2034  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2035 
2036  REG_QUERY_MPD(src, src, 0);
2037  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2038  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2039  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2040  };
2041 
2043  const primitive::at &src, const primitive::at &diff_dst,
2044  const memory &diff_weights, const memory &diff_bias) {
2045  mkldnn_primitive_t result;
2046  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2047  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2048  diff_bias.get() };
2049  check_num_parameters(aprimitive_desc.get(), 2, 2,
2050  "deconvolution backward weights");
2052  aprimitive_desc.get(), inputs, outputs),
2053  "could not create a deconvolution backward weights primitive");
2054  reset(result);
2055  }
2057  const primitive::at &src, const primitive::at &diff_dst,
2058  const memory &diff_weights) {
2059  mkldnn_primitive_t result;
2060  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2061  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2062  check_num_parameters(aprimitive_desc.get(), 2, 1,
2063  "deconvolution backward weights");
2065  aprimitive_desc.get(), inputs, outputs),
2066  "could not create a deconvolution backward weights primitive");
2067  reset(result);
2068  }
2069 };
2070 
2072 
2079 
2080 struct lrn_forward : public primitive {
2081  struct desc {
2083  desc(prop_kind aprop_kind, algorithm aalgorithm,
2084  const memory::desc &src_desc,
2085  int local_size, float alpha, float beta, float k)
2086  {
2088  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2089  &src_desc.data, local_size, alpha, beta, k),
2090  "could not create a lrn forward descriptor");
2091  }
2092  desc(prop_kind aprop_kind, algorithm aalgorithm,
2093  const memory::desc &src_desc,
2094  int local_size, float alpha, float beta)
2095  {
2097  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2098  &src_desc.data, local_size, alpha, beta, float(1.0)),
2099  "could not create a lrn forward descriptor");
2100  }
2101  };
2102 
2104  primitive_desc(const desc &desc, const engine &e)
2105  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2106 
2107  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2108  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2109 
2110  REG_QUERY_MPD(src, src, 0);
2111  REG_QUERY_MPD(dst, dst, 0);
2112  REG_QUERY_MPD(workspace, workspace, 0);
2113  };
2114 
2115  lrn_forward(const primitive_desc &aprimitive_desc,
2116  const primitive::at &src, const memory &workspace,
2117  const memory &dst) {
2118  mkldnn_primitive_t result;
2119  mkldnn_primitive_at_t inputs[] = { src.data };
2120  const_mkldnn_primitive_t outputs[] = { dst.get(),
2121  workspace.get() };
2122  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2124  aprimitive_desc.get(), inputs, outputs),
2125  "could not create a lrn forward primitive");
2126  reset(result);
2127  }
2128 
2129  lrn_forward(const primitive_desc &aprimitive_desc,
2130  const primitive::at &src, const memory &dst) {
2131  mkldnn_primitive_t result;
2132  mkldnn_primitive_at_t inputs[] = { src.data };
2133  const_mkldnn_primitive_t outputs[] = { dst.get() };
2134  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2136  aprimitive_desc.get(), inputs, outputs),
2137  "could not create a lrn forward primitive");
2138  reset(result);
2139  }
2140 };
2141 
2142 struct lrn_backward : public primitive {
2143  struct desc {
2145  desc(algorithm aalgorithm,
2146  const memory::desc &data_desc,
2147  const memory::desc &diff_data_desc,
2148  int local_size, float alpha, float beta, float k)
2149  {
2151  convert_to_c(aalgorithm), &diff_data_desc.data,
2152  &data_desc.data, local_size, alpha, beta, k),
2153  "could not create a lrn backward descriptor");
2154  }
2155  desc(algorithm aalgorithm,
2156  const memory::desc &data_desc,
2157  const memory::desc &diff_data_desc,
2158  int local_size, float alpha, float beta)
2159  {
2161  convert_to_c(aalgorithm), &diff_data_desc.data,
2162  &data_desc.data, local_size, alpha, beta, float(1.0)),
2163  "could not create a lrn backward descriptor");
2164  }
2165  };
2166 
2168  primitive_desc(const desc &desc, const engine &e,
2169  const lrn_forward::primitive_desc &hint_fwd_pd)
2170  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2171 
2172  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2173  const lrn_forward::primitive_desc &hint_fwd_pd)
2174  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2175 
2176  REG_QUERY_MPD(diff_src, diff_src, 0);
2177  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2178  REG_QUERY_MPD(workspace, workspace, 0);
2179  };
2180 
2181  lrn_backward(const primitive_desc &aprimitive_desc,
2182  const primitive::at &src, const primitive::at &diff_dst,
2183  const primitive::at &workspace, const memory &diff_src) {
2184  mkldnn_primitive_t result;
2185  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2186  workspace.data };
2187  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2188  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2190  aprimitive_desc.get(), inputs, outputs),
2191  "could not create a lrn backward primitive");
2192  reset(result);
2193  }
2194 
2195  lrn_backward(const primitive_desc &aprimitive_desc,
2196  const primitive::at &src, const primitive::at &diff_dst,
2197  const memory &diff_src) {
2198  mkldnn_primitive_t result;
2199  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2200  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2201  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2203  aprimitive_desc.get(), inputs, outputs),
2204  "could not create a lrn backward primitive");
2205  reset(result);
2206  }
2207 };
2208 
2210 
2216 
2217 struct pooling_forward : public primitive {
2218  struct desc {
2220  desc(prop_kind aprop_kind, algorithm aalgorithm,
2221  const memory::desc &src_desc,
2222  const memory::desc &dst_desc,
2223  const memory::dims strides,
2224  const memory::dims kernel,
2225  const memory::dims padding_l,
2226  const memory::dims padding_r,
2227  const padding_kind apadding_kind) {
2228  memory::validate_dims(strides);
2229  memory::validate_dims(kernel);
2230  memory::validate_dims(padding_l);
2231  memory::validate_dims(padding_r);
2233  mkldnn::convert_to_c(aprop_kind),
2234  convert_to_c(aalgorithm),
2235  &src_desc.data, &dst_desc.data,
2236  &strides[0], &kernel[0],
2237  &padding_l[0], &padding_r[0],
2238  mkldnn::convert_to_c(apadding_kind)),
2239  "could not init a forward pooling descriptor");
2240  }
2241  };
2242 
2244  primitive_desc(const desc &desc, const engine &e)
2245  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2246 
2247  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2248  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2249 
2250  REG_QUERY_MPD(src, src, 0);
2251  REG_QUERY_MPD(dst, dst, 0);
2252  REG_QUERY_MPD(workspace, workspace, 0);
2253  };
2254 
2255  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2256  const memory &dst) {
2257  mkldnn_primitive_t result;
2258  mkldnn_primitive_at_t inputs[] = { src.data };
2259  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2260  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2262  aprimitive_desc.get(), inputs, outputs),
2263  "could not create a pooling forward primitive");
2264  reset(result);
2265  }
2266 
2267  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2268  const memory &dst, const memory &workspace) {
2269  mkldnn_primitive_t result;
2270  mkldnn_primitive_at_t inputs[] = { src.data };
2271  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2272  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2274  aprimitive_desc.get(), inputs, outputs),
2275  "could not create a pooling forward primitive");
2276  reset(result);
2277  }
2278 };
2279 
2280 struct pooling_backward : public primitive {
2281  struct desc {
2283  desc(algorithm aalgorithm,
2284  const memory::desc &diff_src_desc,
2285  const memory::desc &diff_dst_desc,
2286  const memory::dims &strides,
2287  const memory::dims &kernel,
2288  const memory::dims &padding_l,
2289  const memory::dims &padding_r,
2290  const padding_kind apadding_kind) {
2291  memory::validate_dims(strides);
2292  memory::validate_dims(kernel);
2293  memory::validate_dims(padding_l);
2294  memory::validate_dims(padding_r);
2296  convert_to_c(aalgorithm),
2297  &diff_src_desc.data, &diff_dst_desc.data,
2298  &strides[0], &kernel[0],
2299  &padding_l[0], &padding_r[0],
2300  mkldnn::convert_to_c(apadding_kind)),
2301  "could not init a backward pooling descriptor");
2302  }
2303  };
2304 
2306  primitive_desc(const desc &desc, const engine &e,
2307  const pooling_forward::primitive_desc &hint_fwd_pd)
2308  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2309 
2310  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2311  const pooling_forward::primitive_desc &hint_fwd_pd)
2312  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2313 
2314  REG_QUERY_MPD(diff_src, diff_src, 0);
2315  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2316  REG_QUERY_MPD(workspace, workspace, 0);
2317  };
2318 
2319  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2320  const memory &diff_src) {
2321  mkldnn_primitive_t result;
2322  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2323  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2324  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2326  aprimitive_desc.get(), inputs, outputs),
2327  "could not create a pooling backward primitive");
2328  reset(result);
2329  }
2330 
2331  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2332  const primitive::at &workspace, const memory &diff_src) {
2333  mkldnn_primitive_t result;
2334  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2335  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2336  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2338  aprimitive_desc.get(), inputs, outputs),
2339  "could not create a pooling backward primitive");
2340  reset(result);
2341  }
2342 };
2343 
2345 
2352 
2353 struct eltwise_forward : public primitive {
2354  struct desc {
2356  template <typename T>
2357  desc(prop_kind aprop_kind, algorithm alg_kind,
2358  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2360  mkldnn::convert_to_c(aprop_kind),
2361  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2362  static_cast<float>(alpha), static_cast<float>(beta)),
2363  "could not create a eltwise forward descriptor");
2364  }
2365 
2367  template <typename T>
2368  MKLDNN_DEPRECATED
2369  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2370  T negative_slope)
2371  : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
2372  };
2373 
2375  primitive_desc(const desc &desc, const engine &e)
2376  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2377 
2378  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2379  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2380 
2381  REG_QUERY_MPD(src, src, 0);
2382  REG_QUERY_MPD(dst, dst, 0);
2383  };
2384 
2385  eltwise_forward(const primitive_desc &aprimitive_desc,
2386  const primitive::at &src, const memory &dst) {
2387  mkldnn_primitive_t result;
2388  mkldnn_primitive_at_t inputs[] = { src.data };
2389  const_mkldnn_primitive_t outputs[] = { dst.get() };
2390  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2392  aprimitive_desc.get(), inputs, outputs),
2393  "could not create a eltwise forward primitive");
2394  reset(result);
2395  }
2396 };
2397 
2399 
2400 struct eltwise_backward : public primitive {
2401  struct desc {
2403 
2404  template <typename T>
2405  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2406  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2408  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2409  &data_desc.data, static_cast<float>(alpha),
2410  static_cast<float>(beta)),
2411  "could not create a eltwise backward descriptor");
2412  }
2413 
2415  template <typename T>
2416  MKLDNN_DEPRECATED
2417  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
2418  T negative_slope): desc(eltwise_relu, diff_data_desc, data_desc,
2419  negative_slope) {}
2420  };
2421 
2423  primitive_desc(const desc &desc, const engine &e,
2424  const eltwise_forward::primitive_desc &hint_fwd_pd)
2425  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2426 
2427  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2428  const eltwise_forward::primitive_desc &hint_fwd_pd)
2429  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2430 
2431  REG_QUERY_MPD(src, src, 0);
2432  REG_QUERY_MPD(diff_src, diff_src, 0);
2433  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2434  };
2435 
2436  eltwise_backward(const primitive_desc &aprimitive_desc,
2437  const primitive::at &src, const primitive::at &diff_dst,
2438  const memory &diff_src) {
2439  mkldnn_primitive_t result;
2440  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2441  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2442  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2444  aprimitive_desc.get(), inputs, outputs),
2445  "could not create a eltwise backward primitive");
2446  reset(result);
2447  }
2448 };
2449 
2451 
2453 
2459 
2460 struct softmax_forward : public primitive {
2461  struct desc {
2463  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2464  int softmax_axis) {
2466  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2467  softmax_axis),
2468  "could not create a softmax forward descriptor");
2469  }
2470  };
2471 
2473  primitive_desc(const desc &desc, const engine &e)
2474  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2475 
2476  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2477  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2478 
2479  REG_QUERY_MPD(src, src, 0);
2480  REG_QUERY_MPD(dst, dst, 0);
2481  };
2482 
2483  softmax_forward(const primitive_desc &aprimitive_desc,
2484  const primitive::at &src, const memory &dst) {
2485  mkldnn_primitive_t result;
2486  mkldnn_primitive_at_t inputs[] = { src.data };
2487  const_mkldnn_primitive_t outputs[] = { dst.get() };
2488  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2490  aprimitive_desc.get(), inputs, outputs),
2491  "could not create a softmax forward primitive");
2492  reset(result);
2493  }
2494 };
2495 
2496 struct softmax_backward : public primitive {
2497  struct desc {
2499  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2500  int softmax_axis) {
2502  &diff_desc.data, &data_desc.data, softmax_axis),
2503  "could not init a backward softmax descriptor");
2504  }
2505  };
2506 
2508  primitive_desc(const desc &desc, const engine &e,
2509  const softmax_forward::primitive_desc &hint_fwd_pd)
2510  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2511 
2512  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2513  const softmax_forward::primitive_desc &hint_fwd_pd)
2514  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2515 
2516  REG_QUERY_MPD(dst, dst, 0);
2517  REG_QUERY_MPD(diff_src, diff_src, 0);
2518  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2519  REG_QUERY_MPD(workspace, workspace, 0);
2520  };
2521 
2522  softmax_backward(const primitive_desc &aprimitive_desc,
2523  const primitive::at &dst, const primitive::at &diff_dst,
2524  const memory &diff_src) {
2525  mkldnn_primitive_t result;
2526  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2527  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2529  aprimitive_desc.get(), inputs, outputs),
2530  "could not create a softmax backward primitive");
2531  reset(result);
2532  }
2533 };
2534 
2536 
2542 
2544  struct desc {
2546  template <typename T>
2547  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2548  unsigned flags) {
2551  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2552  static_cast<float>(epsilon), flags),
2553  "could not create a batch normalization forward descriptor");
2554  }
2555  };
2556 
2558  primitive_desc(const desc &desc, const engine &e)
2559  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2560 
2561  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2562  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2563 
2564  REG_QUERY_MPD(src, src, 0);
2565  REG_QUERY_MPD(weights, weights, 0);
2566  REG_QUERY_MPD(dst, dst, 0);
2567  REG_QUERY_MPD(workspace, workspace, 0);
2568 
2570  { return stat_primitive_desc(mean); }
2572  { return stat_primitive_desc(var); }
2573 
2574  private:
2575  enum { mean = 1, var = 2, };
2576  memory::primitive_desc stat_primitive_desc(int kind) const {
2580  "could not get a batch-normalization descriptor");
2581  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2582  }
2583  };
2584 
2586  const primitive::at &src, const primitive::at &mean,
2587  const primitive::at &variance, const primitive::at &weights,
2588  const memory &dst) {
2589  mkldnn_primitive_t result;
2590  mkldnn_primitive_at_t inputs[] = { src.data,
2591  mean.data, variance.data, weights.data };
2592  const_mkldnn_primitive_t outputs[] = { dst.get() };
2593  check_num_parameters(aprimitive_desc.get(), 4, 1,
2594  "batch normalization forward");
2596  aprimitive_desc.get(), inputs, outputs),
2597  "could not create a batch normalization forward primitive");
2598  reset(result);
2599  }
2600 
2602  const primitive::at &src, const primitive::at &mean,
2603  const primitive::at &variance, const memory &dst) {
2604  mkldnn_primitive_t result;
2605  mkldnn_primitive_at_t inputs[] = { src.data,
2606  mean.data, variance.data };
2607  const_mkldnn_primitive_t outputs[] = { dst.get() };
2608  check_num_parameters(aprimitive_desc.get(), 3, 1,
2609  "batch normalization forward");
2611  aprimitive_desc.get(), inputs, outputs),
2612  "could not create a batch normalization forward primitive");
2613  reset(result);
2614  }
2615 
2624  const primitive::at &src, const primitive::at &weights,
2625  const memory &dst, const memory &mean, const memory &variance) {
2626  mkldnn_primitive_t result;
2627  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2628  const_mkldnn_primitive_t outputs[] = { dst.get(),
2629  mean.get(), variance.get() };
2630  check_num_parameters(aprimitive_desc.get(), 2, 3,
2631  "batch normalization forward");
2633  aprimitive_desc.get(), inputs, outputs),
2634  "could not create a batch normalization forward primitive");
2635  reset(result);
2636  }
2637 
2639  const primitive::at &src, const primitive::at &weights,
2640  const memory &dst, const memory &mean, const memory &variance,
2641  const memory &workspace) {
2642  mkldnn_primitive_t result;
2643  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2644  const_mkldnn_primitive_t outputs[] = { dst.get(),
2645  mean.get(), variance.get(), workspace.get() };
2646  check_num_parameters(aprimitive_desc.get(), 2, 4,
2647  "batch normalization forward");
2649  aprimitive_desc.get(), inputs, outputs),
2650  "could not create a batch normalization forward primitive");
2651  reset(result);
2652  }
2653 
2655  const primitive::at &src, const memory &dst, const memory &mean,
2656  const memory &variance) {
2657  mkldnn_primitive_t result;
2658  mkldnn_primitive_at_t inputs[] = { src.data };
2659  const_mkldnn_primitive_t outputs[] = { dst.get(),
2660  mean.get(), variance.get() };
2661  check_num_parameters(aprimitive_desc.get(), 1, 3,
2662  "batch normalization forward");
2664  aprimitive_desc.get(), inputs, outputs),
2665  "could not create a batch normalization forward primitive");
2666  reset(result);
2667  }
2668 
2681  const primitive::at &src, const memory &dst, const memory &mean,
2682  const memory &variance, const memory &workspace) {
2683  mkldnn_primitive_t result;
2684  mkldnn_primitive_at_t inputs[2] = { src.data };
2685  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2686  mean.get(), variance.get(), workspace.get() };
2687 
2688  if (1) { // check whether this is the `wrong` constructor
2689  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2690  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2691  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2692  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2693  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2694  // shift parameters, get rid of workspace, and add weights...
2695  auto _weights = dst;
2696  inputs[1] = {_weights.get(), 0};
2697 
2698  auto _dst = mean, _mean = variance, _variance = workspace;
2699  outputs[0] = _dst.get();
2700  outputs[1] = _mean.get();
2701  outputs[2] = _variance.get();
2702  outputs[3] = nullptr;
2703  }
2704  }
2706  aprimitive_desc.get(), inputs, outputs),
2707  "could not create a batch normalization forward primitive");
2708  reset(result);
2709  }
2710 
2712  const primitive::at &src, const primitive::at &weights,
2713  const memory &dst) {
2714  mkldnn_primitive_t result;
2715  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2716  const_mkldnn_primitive_t outputs[] = { dst.get() };
2717  check_num_parameters(aprimitive_desc.get(), 2, 1,
2718  "batch normalization forward");
2720  aprimitive_desc.get(), inputs, outputs),
2721  "could not create a batch normalization forward primitive");
2722  reset(result);
2723  }
2724 
2726  const primitive::at &src, const memory &dst) {
2727  mkldnn_primitive_t result;
2728  mkldnn_primitive_at_t inputs[] = { src.data };
2729  const_mkldnn_primitive_t outputs[] = { dst.get() };
2730  check_num_parameters(aprimitive_desc.get(), 1, 1,
2731  "batch normalization forward");
2733  aprimitive_desc.get(), inputs, outputs),
2734  "could not create a batch normalization forward primitive");
2735  reset(result);
2736  }
2737 };
2738 
2740  struct desc {
2742  template <typename T>
2743  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2744  const memory::desc &data_desc, T epsilon, unsigned flags) {
2747  mkldnn::convert_to_c(aprop_kind),
2748  &diff_data_desc.data, &data_desc.data,
2749  static_cast<float>(epsilon), flags),
2750  "could not create a batch normalization backward descriptor");
2751  }
2752  };
2753 
2755  primitive_desc(const desc &desc, const engine &e,
2757  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2758 
2759  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2761  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2762 
2763  REG_QUERY_MPD(src, src, 0);
2764  REG_QUERY_MPD(mean, src, 1);
2765  REG_QUERY_MPD(variance, src, 2);
2766  REG_QUERY_MPD(weights, weights, 0);
2767  REG_QUERY_MPD(dst, dst, 0);
2768  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2769  REG_QUERY_MPD(workspace, workspace, 0);
2770 
2771  REG_QUERY_MPD(diff_src, diff_src, 0);
2772  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2773  };
2774 
2775  // Prop_kind == backward
2777  const primitive::at &src, const primitive::at &mean,
2778  const primitive::at &variance, const primitive::at &diff_dst,
2779  const primitive::at &weights, const memory &diff_src,
2780  const memory &diff_weights) {
2781  mkldnn_primitive_t result;
2782  mkldnn_primitive_at_t inputs[] = { src.data,
2783  mean.data, variance.data, diff_dst.data, weights.data };
2784  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2785  diff_weights.get() };
2786  check_num_parameters(aprimitive_desc.get(), 5, 2,
2787  "batch normalization backward");
2789  aprimitive_desc.get(), inputs, outputs),
2790  "could not create a batch normalization backward primitive");
2791  reset(result);
2792  }
2793 
2794  // Prop_kind == backward (+ws)
2796  const primitive::at &src, const primitive::at &mean,
2797  const primitive::at &variance, const primitive::at &diff_dst,
2798  const primitive::at &weights, const primitive::at &workspace,
2799  const memory &diff_src, const memory &diff_weights) {
2800  mkldnn_primitive_t result;
2801  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2802  diff_dst.data, weights.data, workspace.data };
2803  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2804  diff_weights.get() };
2805  check_num_parameters(aprimitive_desc.get(), 6, 2,
2806  "batch normalization backward");
2808  aprimitive_desc.get(), inputs, outputs),
2809  "could not create a batch normalization backward primitive");
2810  reset(result);
2811  }
2812 
2813  // Prop_kind == backward_data (+ws or +weights)
2818  const primitive::at &src, const primitive::at &mean,
2819  const primitive::at &variance,const primitive::at &diff_dst,
2820  const primitive::at &weights_or_workspace, const memory &diff_src) {
2821  mkldnn_primitive_t result;
2822  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2823  diff_dst.data, weights_or_workspace.data };
2824  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2825  check_num_parameters(aprimitive_desc.get(), 5, 1,
2826  "batch normalization backward");
2828  aprimitive_desc.get(), inputs, outputs),
2829  "could not create a batch normalization backward primitive");
2830  reset(result);
2831  }
2832 
2833  // Prop_kind == backward_data
2835  const primitive::at &src, const primitive::at &mean,
2836  const primitive::at &variance, const primitive::at &diff_dst,
2837  const memory &diff_src) {
2838  mkldnn_primitive_t result;
2839  mkldnn_primitive_at_t inputs[] = { src.data,
2840  mean.data, variance.data, diff_dst.data };
2841  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2842  check_num_parameters(aprimitive_desc.get(), 4, 1,
2843  "batch normalization backward");
2845  aprimitive_desc.get(), inputs, outputs),
2846  "could not create a batch normalization backward primitive");
2847  reset(result);
2848  }
2849 };
2850 
2852 
2858 
2860  struct desc {
2862  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2863  const memory::desc &weights_desc,
2864  const memory::desc &bias_desc,
2865  const memory::desc &dst_desc) {
2868  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2869  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2870  "could not create a inner product forward descriptor");
2871  }
2872 
2873  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2874  const memory::desc &weights_desc,
2875  const memory::desc &dst_desc) {
2878  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2879  &weights_desc.data, nullptr, &dst_desc.data),
2880  "could not create a inner product forward descriptor");
2881  }
2882  };
2883 
2885  primitive_desc(const desc &desc, const engine &e)
2886  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2887 
2888  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2889  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2890 
2891  REG_QUERY_MPD(src, src, 0);
2892  REG_QUERY_MPD(weights, weights, 0);
2893  REG_QUERY_MPD(bias, weights, 1);
2894  REG_QUERY_MPD(dst, dst, 0);
2895  };
2896 
2897  inner_product_forward(const primitive_desc &aprimitive_desc,
2898  const primitive::at &src, const primitive::at weights,
2899  const primitive::at &bias, const memory &dst) {
2900  mkldnn_primitive_t result;
2901  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2902  bias.data };
2903  const_mkldnn_primitive_t outputs[] = { dst.get() };
2904  check_num_parameters(aprimitive_desc.get(), 3, 1,
2905  "inner product forward");
2907  aprimitive_desc.get(), inputs, outputs),
2908  "could not create a inner product forward primitive");
2909  reset(result);
2910  }
2911 
2912  inner_product_forward(const primitive_desc &aprimitive_desc,
2913  const primitive::at &src, const primitive::at weights,
2914  const memory &dst) {
2915  mkldnn_primitive_t result;
2916  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2917  const_mkldnn_primitive_t outputs[] = { dst.get() };
2918  check_num_parameters(aprimitive_desc.get(), 2, 1,
2919  "inner product forward");
2921  aprimitive_desc.get(), inputs, outputs),
2922  "could not create a inner product forward primitive");
2923  reset(result);
2924  }
2925 };
2926 
2928  struct desc {
2930  desc(const memory::desc &diff_src_desc,
2931  const memory::desc &weights_desc,
2932  const memory::desc &diff_dst_desc) {
2935  &diff_src_desc.data, &weights_desc.data,
2936  &diff_dst_desc.data),
2937  "could not create a inner product backward data descriptor");
2938  }
2939  };
2940 
2942  primitive_desc(const desc &desc, const engine &e,
2943  const inner_product_forward::primitive_desc &hint_fwd_pd)
2944  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2945 
2946  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2947  const inner_product_forward::primitive_desc &hint_fwd_pd)
2948  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2949 
2950  REG_QUERY_MPD(diff_src, diff_src, 0);
2951  REG_QUERY_MPD(weights, weights, 0);
2952  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2953  };
2954 
2956  const primitive::at &diff_dst, const primitive::at weights,
2957  const memory &diff_src) {
2958  mkldnn_primitive_t result;
2959  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2960  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2961  check_num_parameters(aprimitive_desc.get(), 2, 1,
2962  "inner product backward data");
2964  aprimitive_desc.get(), inputs, outputs),
2965  "could not create a inner product backward data primitive");
2966  reset(result);
2967  }
2968 };
2969 
2971  struct desc {
2973  desc(const memory::desc &src_desc,
2974  const memory::desc &diff_weights_desc,
2975  const memory::desc &diff_bias_desc,
2976  const memory::desc &diff_dst_desc) {
2979  &data, &src_desc.data, &diff_weights_desc.data,
2980  &diff_bias_desc.data, &diff_dst_desc.data),
2981  "could not create a inner product backward weights descriptor");
2982  }
2983  desc(const memory::desc &src_desc,
2984  const memory::desc &diff_weights_desc,
2985  const memory::desc &diff_dst_desc) {
2988  &data, &src_desc.data, &diff_weights_desc.data,
2989  nullptr, &diff_dst_desc.data),
2990  "could not create a inner product backward weights descriptor");
2991  }
2992  };
2993 
2995  primitive_desc(const desc &desc, const engine &e,
2996  const inner_product_forward::primitive_desc &hint_fwd_pd)
2997  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2998 
2999  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3000  const inner_product_forward::primitive_desc &hint_fwd_pd)
3001  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3002 
3003  REG_QUERY_MPD(src, src, 0);
3004  REG_QUERY_MPD(diff_weights, diff_weights, 0);
3005  REG_QUERY_MPD(diff_bias, diff_weights, 1);
3006  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3007  };
3008 
3010  const primitive::at &src, const primitive::at diff_dst,
3011  const memory &diff_weights) {
3012  mkldnn_primitive_t result;
3013  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3014  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
3015  check_num_parameters(aprimitive_desc.get(), 2, 1,
3016  "inner product backward weights");
3018  aprimitive_desc.get(), inputs, outputs),
3019  "could not create a inner product backward weights primitive");
3020  reset(result);
3021  }
3022 
3024  const primitive::at &src, const primitive::at diff_dst,
3025  const memory &diff_weights, const memory &diff_bias) {
3026  mkldnn_primitive_t result;
3027  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
3028  const_mkldnn_primitive_t outputs[] =
3029  { diff_weights.get(), diff_bias.get()};
3030  check_num_parameters(aprimitive_desc.get(), 2, 2,
3031  "inner product backward weights");
3033  aprimitive_desc.get(), inputs, outputs),
3034  "could not create a inner product backward weights primitive");
3035  reset(result);
3036  }
3037 };
3038 
3040 
3046 
3047 struct rnn_cell {
3048  struct desc {
3050 
3051  desc(algorithm kind, algorithm activation_f) {
3053  mkldnn::convert_to_c(kind),
3054  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3055  "could not init an rnn cell descriptor");
3056  }
3058 
3059  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3060 
3062  { return algorithm(c_rnn_cell_.cell_kind); }
3064  { return algorithm(c_rnn_cell_.activation_kind); }
3065 
3066  float get_alpha() const { return c_rnn_cell_.alpha; }
3067  void set_alpha(float alpha) {
3068  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3069  c_rnn_cell_.alpha = alpha;
3070  }
3071 
3072  float get_clipping() const { return c_rnn_cell_.clipping; }
3073  void set_clipping(float clipping) {
3074  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3075  c_rnn_cell_.clipping = clipping;
3076  }
3077 
3078  int get_gates_count() const {
3079  return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3080  }
3081  int get_state_count() const {
3082  return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3083  }
3084  };
3085 };
3086 
3087 struct rnn_forward : public primitive {
3088  struct desc {
3090  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3091  const rnn_direction direction,
3092  const memory::desc &src_layer_desc,
3093  const memory::desc &src_iter_desc,
3094  const memory::desc &weights_layer_desc,
3095  const memory::desc &weights_iter_desc,
3096  const memory::desc &bias_desc,
3097  const memory::desc &dst_layer_desc,
3098  const memory::desc &dst_iter_desc
3099  ) {
3101  mkldnn::convert_to_c(aprop_kind), cell,
3102  mkldnn::convert_to_c(direction),
3103  &src_layer_desc.data, &src_iter_desc.data,
3104  &weights_layer_desc.data, &weights_iter_desc.data,
3105  &bias_desc.data,
3106  &dst_layer_desc.data, &dst_iter_desc.data),
3107  "could not create an RNN forward descriptor");
3108  }
3109 
3110  };
3111 
3113  primitive_desc(const desc &desc, const engine &e)
3114  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3115 
3116  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3117  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3118 
3119  REG_QUERY_MPD(src_layer, src, 0);
3120  REG_QUERY_MPD(src_iter, src, 1);
3121  REG_QUERY_MPD(weights_layer, weights, 0);
3122  REG_QUERY_MPD(weights_iter, weights, 1);
3123  REG_QUERY_MPD(bias, weights, 2);
3124  REG_QUERY_MPD(dst_layer, dst, 0);
3125  REG_QUERY_MPD(dst_iter, dst, 1);
3126  REG_QUERY_MPD(workspace, workspace, 0);
3127  };
3128 
3129  rnn_forward(const primitive_desc &aprimitive_desc,
3130  const primitive::at &src_layer, const primitive::at &src_iter,
3131  const primitive::at &weights_layer,
3132  const primitive::at &weights_iter, const primitive::at &bias,
3133  const memory &dst_layer, const memory &dst_iter,
3134  const memory &workspace) {
3135  mkldnn_primitive_t result;
3136  mkldnn_primitive_at_t inputs[5];
3137  const_mkldnn_primitive_t outputs[3];
3138  int idx=0;
3139  inputs[idx++] = src_layer.data;
3140  if (!is_null_memory(src_iter.data.primitive))
3141  inputs[idx++] = src_iter.data;
3142  inputs[idx++] = weights_layer.data;
3143  inputs[idx++] = weights_iter.data;
3144  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3145 
3146  idx=0;
3147  outputs[idx++] = dst_layer.get();
3148  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3149  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3150 
3152  aprimitive_desc.get(), inputs, outputs),
3153  "could not create an RNN forward primitive");
3154  reset(result);
3155  }
3156 };
3157 
3158 struct rnn_backward : public primitive {
3159  struct desc {
3161  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3162  const rnn_direction direction,
3163  const memory::desc &src_layer_desc,
3164  const memory::desc &src_iter_desc,
3165  const memory::desc &weights_layer_desc,
3166  const memory::desc &weights_iter_desc,
3167  const memory::desc &bias_desc,
3168  const memory::desc &dst_layer_desc,
3169  const memory::desc &dst_iter_desc,
3170  const memory::desc &diff_src_layer_desc,
3171  const memory::desc &diff_src_iter_desc,
3172  const memory::desc &diff_weights_layer_desc,
3173  const memory::desc &diff_weights_iter_desc,
3174  const memory::desc &diff_bias_desc,
3175  const memory::desc &diff_dst_layer_desc,
3176  const memory::desc &diff_dst_iter_desc) {
3178  mkldnn::convert_to_c(aprop_kind), cell,
3179  mkldnn::convert_to_c(direction),
3180  &src_layer_desc.data, &src_iter_desc.data,
3181  &weights_layer_desc.data, &weights_iter_desc.data,
3182  &bias_desc.data,
3183  &dst_layer_desc.data, &dst_iter_desc.data,
3184  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3185  &diff_weights_layer_desc.data,
3186  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3187  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3188  "could not create an RNN backward descriptor");
3189  }
3190 
3191  };
3192 
3194  MKLDNN_DEPRECATED
3195  primitive_desc(const desc &desc, const engine &e)
3196  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3197 
3198  primitive_desc(const desc &desc, const engine &e,
3199  const rnn_forward::primitive_desc &hint_fwd_pd)
3200  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3201 
3202  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3203  const rnn_forward::primitive_desc &hint_fwd_pd)
3204  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3205 
3206  REG_QUERY_MPD(src_layer, src, 0);
3207  REG_QUERY_MPD(src_iter, src, 1);
3208  REG_QUERY_MPD(weights_layer, weights, 0);
3209  REG_QUERY_MPD(weights_iter, weights, 1);
3210  REG_QUERY_MPD(bias, weights, 2);
3211  REG_QUERY_MPD(dst_layer, dst, 0);
3212  REG_QUERY_MPD(dst_iter, dst, 1);
3213  REG_QUERY_MPD(workspace, workspace, 0);
3214 
3215  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3216  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3217  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3218  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3219  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3220  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3221  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3222  };
3223 
3224  // With last iteration (with and without input src_iter)
3225  rnn_backward(const primitive_desc &aprimitive_desc,
3226  const primitive::at &src_layer,
3227  const primitive::at &src_iter,
3228  const primitive::at &weights_layer,
3229  const primitive::at &weights_iter,
3230  const primitive::at &bias,
3231  const primitive::at &dst_layer,
3232  const primitive::at &dst_iter,
3233  const memory &diff_src_layer,
3234  const memory &diff_src_iter,
3235  const memory &diff_weights_layer,
3236  const memory &diff_weights_iter,
3237  const memory &diff_bias,
3238  const primitive::at &diff_dst_layer,
3239  const primitive::at &diff_dst_iter,
3240  const primitive::at &workspace) {
3241  mkldnn_primitive_t result;
3242  mkldnn_primitive_at_t inputs[10];
3243  const_mkldnn_primitive_t outputs[5];
3244  int idx=0;
3245  inputs[idx++] = src_layer.data;
3246  if (!is_null_memory(src_iter.data.primitive))
3247  inputs[idx++] = src_iter.data;
3248  inputs[idx++] = weights_layer.data;
3249  inputs[idx++] = weights_iter.data;
3250  if (!is_null_memory(bias.data.primitive))
3251  inputs[idx++] = bias.data;
3252  inputs[idx++] = dst_layer.data;
3253  if (!is_null_memory(dst_iter.data.primitive))
3254  inputs[idx++] = dst_iter.data;
3255  inputs[idx++] = diff_dst_layer.data;
3256  if (!is_null_memory(diff_dst_iter.data.primitive))
3257  inputs[idx++] = diff_dst_iter.data;
3258  inputs[idx++] = workspace.data;
3259 
3260  idx = 0;
3261  outputs[idx++] = diff_src_layer.get();
3262  if (!is_null_memory(diff_src_iter.get()))
3263  outputs[idx++] = diff_src_iter.get();
3264  outputs[idx++] = diff_weights_layer.get();
3265  outputs[idx++] = diff_weights_iter.get();
3266  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3268  aprimitive_desc.get(), inputs, outputs),
3269  "could not create an RNN backward primitive");
3270  reset(result);
3271  }
3272 };
3273 
3275 
3281 
3282 struct shuffle_forward : public primitive {
3283  struct desc {
3285  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3286  int axis, int group_size) {
3288  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3289  axis, group_size),
3290  "could not create a shuffle forward descriptor");
3291  }
3292  };
3293 
3295  primitive_desc(const desc &desc, const engine &e)
3296  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3297 
3298  REG_QUERY_MPD(src, src, 0);
3299  REG_QUERY_MPD(dst, dst, 0);
3300  };
3301 
3302  shuffle_forward(const primitive_desc &aprimitive_desc,
3303  const primitive::at &src, const memory &dst) {
3304  mkldnn_primitive_t result;
3305  mkldnn_primitive_at_t inputs[] = { src.data };
3306  const_mkldnn_primitive_t outputs[] = { dst.get() };
3307  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3309  aprimitive_desc.get(), inputs, outputs),
3310  "could not create a shuffle forward primitive");
3311  reset(result);
3312  }
3313 };
3314 
3315 struct shuffle_backward : public primitive {
3316  struct desc {
3318  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3320  &diff_data_desc.data, axis, group_size),
3321  "could not create a shuffle backward descriptor");
3322  }
3323  };
3324 
3326  primitive_desc(const desc &desc, const engine &e,
3327  const shuffle_forward::primitive_desc &hint_fwd_pd)
3328  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3329 
3330  REG_QUERY_MPD(diff_src, diff_src, 0);
3331  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3332  };
3333 
3334  shuffle_backward(const primitive_desc &aprimitive_desc,
3335  const primitive::at &diff_dst, const memory &diff_src) {
3336  mkldnn_primitive_t result;
3337  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3338  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3339  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3341  aprimitive_desc.get(), inputs, outputs),
3342  "could not create a shuffle backward primitive");
3343  reset(result);
3344  }
3345 };
3346 
3348 
3350 
3356 
3357 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3358 template <> struct handle_traits<mkldnn_stream_t> {
3359  static constexpr auto destructor = &mkldnn_stream_destroy;
3360 };
3361 #endif
3362 
3363 struct stream: public handle<mkldnn_stream_t> {
3364  using handle::handle;
3365 
3369 
3371  return static_cast<mkldnn_stream_kind_t>(akind);
3372  }
3374  stream(kind akind) {
3375  mkldnn_stream_t astream;
3377  convert_to_c(akind)),
3378  "could not create a stream");
3379  reset(astream);
3380  }
3381 
3386  stream &submit(std::vector<primitive> primitives) {
3387  // TODO: find a proper way to convert vector<primitive> to
3388  // vector<mkldnn_primitive_t>
3389  if (primitives.size() == 0) return *this;
3390  std::vector<mkldnn_primitive_t> c_api_primitives;
3391  c_api_primitives.reserve(primitives.size());
3392  auto convert_to_c = [](primitive p) { return p.get(); };
3393  std::transform(primitives.begin(), primitives.end(),
3394  std::back_inserter(c_api_primitives), convert_to_c);
3395 
3396  mkldnn_primitive_t c_api_error_primitive;
3398  mkldnn_stream_submit(get(),
3399  c_api_primitives.size(), &c_api_primitives[0],
3400  &c_api_error_primitive),
3401  "could not submit primitives to a stream",
3402  &c_api_error_primitive);
3403 
3404  return *this;
3405  }
3406 
3413  bool wait(bool block = true) {
3414  mkldnn_primitive_t c_api_error_primitive;
3415  mkldnn_status_t status = mkldnn_stream_wait(get(),
3416  block, &c_api_error_primitive);
3417  if (status != mkldnn_success
3418  && status != mkldnn_try_again)
3419  error::wrap_c_api(status, "could not wait on a stream",
3420  &c_api_error_primitive);
3421  return (status == mkldnn_success);
3422  }
3423 
3425  mkldnn_primitive_t c_api_error_primitive;
3427  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3428  "could not rerun a stream", &c_api_error_primitive);
3429  return *this;
3430  }
3431 };
3432 
3433 #undef REG_QUERY_MPD
3434 
3436 
3438 
3439 } // namespace mkldnn
3440 
3441 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:389
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2473
Definition: mkldnn.hpp:2422
LRN within a single channel.
Definition: mkldnn_types.h:474
primitive error_primitive
Definition: mkldnn.hpp:166
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:809
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1494
Definition: mkldnn.hpp:346
blocked weights format
Definition: mkldnn_types.h:300
Definition: mkldnn.hpp:1681
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2912
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2247
Definition: mkldnn.hpp:270
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1049
blocked weights format
Definition: mkldnn_types.h:303
op descriptor
Definition: mkldnn_types.h:1163
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1059
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1637
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:327
Definition: mkldnn.hpp:3158
blocked weights format
Definition: mkldnn_types.h:287
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
blocked weights format
Definition: mkldnn_types.h:345
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
A Softmax primitive.
Definition: mkldnn_types.h:422
number of outputs expected
Definition: mkldnn_types.h:1152
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3116
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1647
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2585
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3386
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:767
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1256
Definition: mkldnn.hpp:2280
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:39
stream & rerun()
Definition: mkldnn.hpp:3424
Definition: mkldnn.hpp:2243
A descriptor of a convolution operation.
Definition: mkldnn_types.h:655
Definition: mkldnn.hpp:302
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3285
Definition: mkldnn.hpp:2218
The operation failed and should be retried.
Definition: mkldnn_types.h:45
memory null_memory(engine eng)
Definition: mkldnn.hpp:863
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
MKLDNN_DEPRECATED primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3195
blocked weights format
Definition: mkldnn_types.h:259
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:331
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1587
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1721
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:239
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:163
Definition: mkldnn.hpp:266
padding_kind
Definition: mkldnn.hpp:234
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:47
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:381
Definition: mkldnn.hpp:2081
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1549
Backward data propagation.
Definition: mkldnn_types.h:387
Definition: mkldnn.hpp:2497
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:577
Definition: mkldnn.hpp:3325
Definition: mkldnn.hpp:3315
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2463
Definition: mkldnn.hpp:275
blocked weights format
Definition: mkldnn_types.h:283
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:137
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:212
MKLDNN_DEPRECATED desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, T negative_slope)
Definition: mkldnn.hpp:2417
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1100
memory::desc desc()
Returns the memory primitive descriptor.
Definition: mkldnn.hpp:757
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2042
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:925
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:552
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:596
Definition: mkldnn.hpp:291
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:184
MKLDNN_DEPRECATED primitive_desc(std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1191
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:759
blocked weights format
Definition: mkldnn_types.h:348
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2483
blocked weights format
Definition: mkldnn_types.h:349
blocked data format
Definition: mkldnn_types.h:246
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:246
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:867
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:575
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:215
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3295
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1173
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1766
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:932
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:414
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1914
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2144
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:529
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:532
engine get_engine()
Definition: mkldnn.hpp:1269
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:723
blocked data format
Definition: mkldnn_types.h:247
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind...
Definition: mkldnn.hpp:227
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2861
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1221
An execution engine.
Definition: mkldnn.hpp:494
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:813
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2929
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha and beta (...
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:190
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2282
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:398
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:430
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:868
Round down.
Definition: mkldnn_types.h:82
Definition: mkldnn_types.h:1175
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:199
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2512
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1784
Definition: mkldnn.hpp:265
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:430
primitive_attr()
Definition: mkldnn.hpp:423
Definition: mkldnn_types.h:470
Definition: mkldnn.hpp:2400
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2508
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2498
Definition: mkldnn.hpp:2472
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:394
Definition: mkldnn.hpp:249
32-bit signed integer.
Definition: mkldnn_types.h:68
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2942
Max pooling.
Definition: mkldnn_types.h:465
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1408
memory::desc zero_md()
Definition: mkldnn.hpp:857
Definition: mkldnn.hpp:340
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:992
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forward...
blocked weights format
Definition: mkldnn_types.h:273
const post_ops get_post_ops() const
Definition: mkldnn.hpp:464
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2220
Definition: mkldnn.hpp:333
execution engine
Definition: mkldnn_types.h:1148
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3374
Definition: mkldnn.hpp:991
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:338
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2930
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2217
blocked weights format
Definition: mkldnn_types.h:280
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1693
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:852
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2427
Definition: mkldnn.hpp:322
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:910
A descriptor of a convolution followed by a ReLU operation.
Definition: mkldnn_types.h:896
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:257
input memory primitive desc
Definition: mkldnn_types.h:1180
blocked weights format
Definition: mkldnn_types.h:294
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3284
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:203
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1108
Definition: mkldnn.hpp:290
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3129
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of an element-wise operation.
Definition: mkldnn_types.h:717
rnn descriptor
Definition: mkldnn_types.h:1176
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2571
An element-wise primitive.
Definition: mkldnn_types.h:418
Definition: mkldnn.hpp:2496
destination grad.
Definition: mkldnn_types.h:1187
algorithm get_cell_kind() const
Definition: mkldnn.hpp:3061
engine get_engine()
Definition: mkldnn.hpp:1218
Definition: mkldnn.hpp:2401
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:920
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1205
blocked weights format
Definition: mkldnn_types.h:297
A descriptor for an rnn operation.
Definition: mkldnn_types.h:947
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1385
Definition: mkldnn.hpp:1047
Definition: mkldnn.hpp:278
Definition: mkldnn.hpp:260
eltwise descriptor
Definition: mkldnn_types.h:1168
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2680
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1433
Definition: mkldnn.hpp:277
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2817
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2129
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2973
batch_normalization_flag
Definition: mkldnn.hpp:289
A memory primitive.
Definition: mkldnn_types.h:400
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:928
MKLDNN_DEPRECATED desc(prop_kind aprop_kind, const memory::desc &src_desc, T negative_slope)
Definition: mkldnn.hpp:2369
blocked weights format
Definition: mkldnn_types.h:282
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3161
Eltwise: soft_relu.
Definition: mkldnn_types.h:461
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:473
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2897
Definition: mkldnn.hpp:345
Definition: mkldnn.hpp:262
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
RNN cell.
Definition: mkldnn_types.h:480
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2244
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1805
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:888
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2946
Definition: mkldnn.hpp:371
blocked weights format
Definition: mkldnn_types.h:309
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1345
Backward weights propagation.
Definition: mkldnn_types.h:389
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:437
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3089
blocked weights format
Definition: mkldnn_types.h:344
eltwise_forward relu_forward
Definition: mkldnn.hpp:2398
32-bit/single-precision floating point.
Definition: mkldnn_types.h:66
blocked weights format
Definition: mkldnn_types.h:256
blocked data format
Definition: mkldnn_types.h:245
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1569
algorithm get_activation() const
Definition: mkldnn.hpp:3063
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2255
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:172
Just a sentinel, not a real memory format.
Definition: mkldnn_types.h:359
Omit statistics.
Definition: mkldnn_types.h:532
Memory descriptor.
Definition: mkldnn_types.h:616
Definition: mkldnn.hpp:2860
Definition: mkldnn.hpp:305
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3302
mkldnn_batch_normalization_flag_t
Flags for batch-normalization primitive.
Definition: mkldnn_types.h:497
void set_clipping(float clipping)
Definition: mkldnn.hpp:3073
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1661
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2082
Definition: mkldnn.hpp:2859
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2547
Definition: mkldnn.hpp:281
pooling descriptor
Definition: mkldnn_types.h:1171
Definition: mkldnn.hpp:2281
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:242
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2219
Definition: mkldnn.hpp:268
blocked weights format
Definition: mkldnn_types.h:255
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:181
blocked weights format
Definition: mkldnn_types.h:308
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:922
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:151
blocked weights format
Definition: mkldnn_types.h:285
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1531
The operation was successful.
Definition: mkldnn_types.h:41
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:320
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2999
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1633
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:3051
blocked weights format
Definition: mkldnn_types.h:328
Definition: mkldnn.hpp:328
Definition: mkldnn.hpp:247
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1257
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:342
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3160
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:385
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:169
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:3023
softmax descriptor
Definition: mkldnn_types.h:1170
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:78
A deconvolution primitive.
Definition: mkldnn_types.h:416
Definition: mkldnn.hpp:332
Definition: mkldnn.hpp:276
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:747
Use global statistics.
Definition: mkldnn_types.h:510
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1072
blocked weights format
Definition: mkldnn_types.h:286
no query
Definition: mkldnn_types.h:1146
Definition: mkldnn.hpp:1745
blocked weights format
Definition: mkldnn_types.h:335
blocked weights format
Definition: mkldnn_types.h:298
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offset offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:350
Average pooling include padding.
Definition: mkldnn_types.h:467
Unspecified format.
Definition: mkldnn_types.h:140
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2955
Definition: mkldnn.hpp:2103
destination memory primitive desc
Definition: mkldnn_types.h:1186
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2569
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:225
GRU cell with linear before reset.
Definition: mkldnn_types.h:493
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:786
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2181
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:472
blocked weights format
Definition: mkldnn_types.h:270
GRU cell.
Definition: mkldnn_types.h:484
Eager stream.
Definition: mkldnn_types.h:1201
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:942
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:457
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:145
implementation name
Definition: mkldnn_types.h:1159
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1983
Definition: mkldnn.hpp:1346
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3318
Definition: mkldnn.hpp:3316
Definition: mkldnn.hpp:258
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2319
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and pointer to a constant floating point array of output sc...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:178
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:449
kind
Kinds of engines.
Definition: mkldnn.hpp:499
Definition: mkldnn.hpp:2143
Definition: mkldnn.hpp:2927
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2476
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:163
round_mode
Definition: mkldnn.hpp:225
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:897
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1872
Eltwise: ReLU.
Definition: mkldnn_types.h:445
Definition: mkldnn.hpp:2460
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1347
Definition: mkldnn.hpp:235
1D data tensor.
Definition: mkldnn_types.h:146
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:138
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1297
desc(const convolution_forward::desc conv_desc, const float negative_slope)
Definition: mkldnn.hpp:1684
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2759
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3202
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3326
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:190
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2402
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:135
Definition: mkldnn.hpp:990
Eltwise: square.
Definition: mkldnn_types.h:451
Definition: mkldnn.hpp:1124
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1367
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1002
Definition: mkldnn.hpp:282
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:849
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:160
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:843
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2654
Definition: mkldnn.hpp:269
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2145
Backward bias propagation.
Definition: mkldnn_types.h:391
Definition: mkldnn.hpp:931
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2092
blocked weights format
Definition: mkldnn_types.h:339
Use scale and shift parameters.
Definition: mkldnn_types.h:523
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1747
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:313
Definition: mkldnn.hpp:280
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:319
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
float get_alpha() const
Definition: mkldnn.hpp:3066
blocked weights format
Definition: mkldnn_types.h:269
blocked weights format
Definition: mkldnn_types.h:329
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:700
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:406
Definition: mkldnn_types.h:942
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2355
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2028
Definition: mkldnn.hpp:422
blocked weights format
Definition: mkldnn_types.h:337
blocked weights format
Definition: mkldnn_types.h:305
int get_gates_count() const
Definition: mkldnn.hpp:3078
int ndims
Number of dimensions.
Definition: mkldnn_types.h:621
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:955
Definition: mkldnn.hpp:2080
Definition: mkldnn.hpp:1048
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
void set_alpha(float alpha)
Definition: mkldnn.hpp:3067
A convolution primitive merged with ReLU.
Definition: mkldnn_types.h:432
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff_...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2155
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:166
Definition: mkldnn.hpp:3283
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:218
Definition: mkldnn.hpp:2167
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:763
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1546
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1840
An RNN primitive.
Definition: mkldnn_types.h:434
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
MKLDNN_DEPRECATED primitive_desc(const memory::desc &output, std::vector< double > scale, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1175
blocked weights format
Definition: mkldnn_types.h:293
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis and group number.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1944
Definition: mkldnn.hpp:3048
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2436
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:369
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:136
CPU engine.
Definition: mkldnn_types.h:998
Definition: mkldnn.hpp:293
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2405
Eltwise: square root.
Definition: mkldnn_types.h:455
blocked weights format
Definition: mkldnn_types.h:257
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1197
Definition: mkldnn.hpp:272
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:187
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1106
Winograd convolution.
Definition: mkldnn_types.h:443
Definition: mkldnn.hpp:248
A ReLU primitive.
Definition: mkldnn_types.h:420
Definition: mkldnn.hpp:347
Eltwise: linear.
Definition: mkldnn_types.h:457
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1873
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1945
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:966
Eltwise: logistic.
Definition: mkldnn_types.h:463
Definition: mkldnn.hpp:2739
Direct convolution.
Definition: mkldnn_types.h:441
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:54
Definition: mkldnn.hpp:342
Definition: mkldnn.hpp:271
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2115
source gradient memory primitive desc
Definition: mkldnn_types.h:1183
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:917
Definition: mkldnn.hpp:1474
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2741
Definition: mkldnn_types.h:934
Definition: mkldnn.hpp:314
blocked data format
Definition: mkldnn_types.h:249
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2083
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2711
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:3049
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:772
runtime estimation (seconds)
Definition: mkldnn_types.h:1154
blocked weights format
Definition: mkldnn_types.h:336
bool operator==(const T other) const
Definition: mkldnn.hpp:61
An (in-place) concat primitive.
Definition: mkldnn_types.h:410
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:823
blocked weights format
Definition: mkldnn_types.h:271
LSTM cell.
Definition: mkldnn_types.h:482
blocked weights format
Definition: mkldnn_types.h:260
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn_types.h:943
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2558
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2885
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2888
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:1870
16-bit signed integer.
Definition: mkldnn_types.h:70
Definition: mkldnn.hpp:2354
A shuffle primitive.
Definition: mkldnn_types.h:406
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:278
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3317
primitive_desc()
Definition: mkldnn.hpp:744
int len() const
Definition: mkldnn.hpp:379
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1136
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2873
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
Definition: mkldnn.hpp:244
blocked weights format
Definition: mkldnn_types.h:299
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1475
blocked weights format
Definition: mkldnn_types.h:292
An (out-of-place) concat primitive.
Definition: mkldnn_types.h:408
blocked weights format
Definition: mkldnn_types.h:306
Fuse with ReLU.
Definition: mkldnn_types.h:541
Definition: mkldnn.hpp:261
Definition: mkldnn.hpp:279
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:510
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1145
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:836
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:542
Definition: mkldnn.hpp:3087
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2056
Definition: mkldnn.hpp:292
blocked data format
Definition: mkldnn_types.h:248
A sum primitive.
Definition: mkldnn_types.h:412
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2834
Definition: mkldnn.hpp:304
blocked weights format
Definition: mkldnn_types.h:333
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2385
unsigned flags
Definition: mkldnn_types.h:863
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:261
blocked weights format
Definition: mkldnn_types.h:310
Definition: mkldnn.hpp:3047
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2522
blocked weights format
Definition: mkldnn_types.h:252
Definition: mkldnn.hpp:3088
Definition: mkldnn.hpp:259
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2375
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:338
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:175
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:1155
A batch normalization primitive.
Definition: mkldnn_types.h:428
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:439
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:520
Definition: mkldnn.hpp:2353
A descriptor of a pooling operation.
Definition: mkldnn_types.h:775
Definition: mkldnn.hpp:3363
Definition: mkldnn.hpp:273
Definition: mkldnn.hpp:274
engine get_engine()
Definition: mkldnn.hpp:776
MKLDNN_DEPRECATED convolution_relu_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1704
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:175
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2032
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1272
deconvolution descriptor
Definition: mkldnn_types.h:1166
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1126
blocked weights format
Definition: mkldnn_types.h:312
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3334
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:933
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2306
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:716
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1436
engine get_engine()
Definition: mkldnn.hpp:952
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for signed 32-bit int.
8-bit signed integer.
Definition: mkldnn_types.h:72
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:365
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2374
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2983
source memory primitive desc
Definition: mkldnn_types.h:1182
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:396
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1918
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2005
RNN packed weights (unused)
Definition: mkldnn_types.h:354
Definition: mkldnn.hpp:3294
Winograd deconvolution.
Definition: mkldnn_types.h:478
Definition: mkldnn.hpp:250
number of inputs expected
Definition: mkldnn_types.h:1151
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2462
Definition: mkldnn.hpp:349
Definition: mkldnn.hpp:3112
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2561
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2357
An unspecified engine.
Definition: mkldnn_types.h:1199
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1828
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:836
A view primitive.
Definition: mkldnn_types.h:402
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3159
Definition: mkldnn.hpp:263
Definition: mkldnn.hpp:330
Definition: mkldnn.hpp:3193
blocked weights format
Definition: mkldnn_types.h:284
Definition: mkldnn.hpp:339
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:156
Definition: mkldnn.hpp:344
Definition: mkldnn.hpp:334
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:336
Average pooling exclude padding.
Definition: mkldnn_types.h:469
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:381
Definition: mkldnn_types.h:913
Forward data propagation (inference mode).
Definition: mkldnn_types.h:379
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:211
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:193
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:585
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2601
Direct deconvolution.
Definition: mkldnn_types.h:476
Eltwise: abs.
Definition: mkldnn_types.h:453
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2623
blocked weights format
Definition: mkldnn_types.h:322
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2331
blocked weights format
Definition: mkldnn_types.h:272
A memory descriptor.
Definition: mkldnn.hpp:713
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1927
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:207
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2378
blocked weights format
Definition: mkldnn_types.h:330
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:900
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:447
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2972
mkldnn_status_t status
Definition: mkldnn.hpp:164
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1855
eltwise_backward relu_backward
Definition: mkldnn.hpp:2450
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1018
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1965
blocked weights format
Definition: mkldnn_types.h:311
2D data tensor.
Definition: mkldnn_types.h:148
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2755
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2862
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3413
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:56
memory descriptor for memory and view
Definition: mkldnn_types.h:1164
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1027
Definition: mkldnn.hpp:267
An LRN primitive.
Definition: mkldnn_types.h:426
Definition: mkldnn_types.h:939
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:363
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3225
Lazy stream.
Definition: mkldnn_types.h:1203
Definition: mkldnn.hpp:335
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2499
blocked weights format
Definition: mkldnn_types.h:334
Definition: mkldnn.hpp:306
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:443
blocked weights format
Definition: mkldnn_types.h:254
desc(algorithm kind)
Definition: mkldnn.hpp:3057
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3198
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:232
blocked weights format
Definition: mkldnn_types.h:304
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:2970
shuffle descriptor
Definition: mkldnn_types.h:1167
Forward data propagation (training mode).
Definition: mkldnn_types.h:375
Definition: mkldnn.hpp:348
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2168
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:3009
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1548
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:782
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:154
engine get_engine()
Definition: mkldnn.hpp:1097
post_ops()
Definition: mkldnn.hpp:372
An opaque structure to describe a primitive.
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2795
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:144
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1348
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:1473
Definition: mkldnn.hpp:327
Definition: mkldnn.hpp:320
convolution descriptor
Definition: mkldnn_types.h:1165
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1522
A memory primitive descriptor.
Definition: mkldnn.hpp:740
Definition: mkldnn.hpp:316
Definition: mkldnn.hpp:2507
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:295
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1518
blocked weights format
Definition: mkldnn_types.h:288
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2725
Eltwise: bounded_relu.
Definition: mkldnn_types.h:459
Definition: mkldnn.hpp:2461
#define REG_QUERY_MPD(name, what, idx)
Definition: mkldnn.hpp:1322
Definition: mkldnn_types.h:936
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1458
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:994
Definition: mkldnn_types.h:909
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:58
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3113
blocked weights format
Definition: mkldnn_types.h:347
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
Memory primitive that describes the data.
Definition: mkldnn.hpp:570
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:351
Definition: mkldnn.hpp:329
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2104
Definition: mkldnn.hpp:2142
Definition: mkldnn.hpp:303
Round nearest.
Definition: mkldnn_types.h:80
blocked weights format
Definition: mkldnn_types.h:346
Definition: mkldnn.hpp:245
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2776
Definition: mkldnn.hpp:1744
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:610
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3370
blocked weights format
Definition: mkldnn_types.h:253
blocked weights format
Definition: mkldnn_types.h:343
Definition: mkldnn.hpp:1942
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1085
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2543
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2267
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1445
A reorder primitive.
Definition: mkldnn_types.h:404
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1831
mkldnn_status_t MKLDNN_API mkldnn_convolution_relu_desc_init(mkldnn_convolution_relu_desc_t *conv_relu_desc, const mkldnn_convolution_desc_t *conv_desc, float negative_slope)
Initializes a merged convolution-relu descriptor conv_relu_desc for forward propagation (supported in...
rnn_direction
Definition: mkldnn.hpp:301
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1155
blocked weights format
Definition: mkldnn_types.h:331
blocked weights format
Definition: mkldnn_types.h:291
An unspecified engine.
Definition: mkldnn_types.h:996
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:736
blocked weights format
Definition: mkldnn_types.h:307
Definition: mkldnn.hpp:1125
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:1015
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2310
mkldnn_convolution_relu_desc_t data
Definition: mkldnn.hpp:1682
blocked weights format
Definition: mkldnn_types.h:332
blocked weights format
Definition: mkldnn_types.h:321
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:438
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2995
Definition: mkldnn.hpp:264
inner product descriptor
Definition: mkldnn_types.h:1174
A pooling primitive.
Definition: mkldnn_types.h:424
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1184
output memory primitive desc
Definition: mkldnn_types.h:1181
Definition: mkldnn.hpp:2305
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:196
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2107
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2545
Definition: mkldnn.hpp:932
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:337
std::string message
Definition: mkldnn.hpp:165
Definition: mkldnn.hpp:3282
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes an RNN descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2423
Definition: mkldnn.hpp:317
blocked weights format
Definition: mkldnn_types.h:281
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:383
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:213
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:243
lrn descriptor
Definition: mkldnn_types.h:1172
workspace memory primitive desc
Definition: mkldnn_types.h:1188
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2195
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1609
bool next_impl()
Advances to the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1286
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
blocked weights format
Definition: mkldnn_types.h:258
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1746
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2743
blocked weights format
Definition: mkldnn_types.h:296
Definition: mkldnn.hpp:226
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:268
Definition: mkldnn_types.h:1169
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2172
float get_clipping() const
Definition: mkldnn.hpp:3072
weights grad.
Definition: mkldnn_types.h:1185
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:157
Definition: mkldnn.hpp:323
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes an RNN descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:399
primitive kind
Definition: mkldnn_types.h:1149
blocked data format
Definition: mkldnn_types.h:244
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1891
RNN packed weights (unused)
Definition: mkldnn_types.h:355
int get_state_count() const
Definition: mkldnn.hpp:3081
blocked weights format
Definition: mkldnn_types.h:279
Definition: mkldnn.hpp:319
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2283
A merged convolution-relu primitive for inference mode only.
Definition: mkldnn.hpp:1680
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2638
kind
Definition: mkldnn.hpp:3366
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1476
Definition: mkldnn.hpp:343
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3090
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...